Add RAG embedding and AI chat feature
Implement end-to-end RAG pipeline: AST-aware code chunking, LiteLLM embedding generation, MongoDB vector storage with brute-force cosine similarity fallback for self-hosted instances, and a chat API with RAG-augmented responses. Add dedicated /chat/:repo_id dashboard page with embedding build controls, message history, and source reference cards. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,7 @@ impl ComplianceAgent {
|
||||
config.litellm_url.clone(),
|
||||
config.litellm_api_key.clone(),
|
||||
config.litellm_model.clone(),
|
||||
config.litellm_embed_model.clone(),
|
||||
));
|
||||
Self {
|
||||
config,
|
||||
|
||||
238
compliance-agent/src/api/handlers/chat.rs
Normal file
238
compliance-agent/src/api/handlers/chat.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
|
||||
use compliance_core::models::embedding::EmbeddingBuildRun;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::rag::pipeline::RagPipeline;
|
||||
|
||||
use super::ApiResponse;
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
/// POST /api/v1/chat/:repo_id — Send a chat message with RAG context
|
||||
pub async fn chat(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
|
||||
let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
|
||||
|
||||
// Step 1: Embed the user's message
|
||||
let query_vectors = agent
|
||||
.llm
|
||||
.embed(vec![req.message.clone()])
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to embed query: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let query_embedding = query_vectors.into_iter().next().ok_or_else(|| {
|
||||
tracing::error!("Empty embedding response");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Step 2: Vector search — retrieve top 8 chunks
|
||||
let search_results = pipeline
|
||||
.store()
|
||||
.vector_search(&repo_id, query_embedding, 8, 0.5)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Vector search failed: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Step 3: Build system prompt with code context
|
||||
let mut context_parts = Vec::new();
|
||||
let mut sources = Vec::new();
|
||||
|
||||
for (embedding, score) in &search_results {
|
||||
context_parts.push(format!(
|
||||
"--- {} ({}, {}:L{}-L{}) ---\n{}",
|
||||
embedding.qualified_name,
|
||||
embedding.kind,
|
||||
embedding.file_path,
|
||||
embedding.start_line,
|
||||
embedding.end_line,
|
||||
embedding.content,
|
||||
));
|
||||
|
||||
// Truncate snippet for the response
|
||||
let snippet: String = embedding
|
||||
.content
|
||||
.lines()
|
||||
.take(10)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
sources.push(SourceReference {
|
||||
file_path: embedding.file_path.clone(),
|
||||
qualified_name: embedding.qualified_name.clone(),
|
||||
start_line: embedding.start_line,
|
||||
end_line: embedding.end_line,
|
||||
language: embedding.language.clone(),
|
||||
snippet,
|
||||
score: *score,
|
||||
});
|
||||
}
|
||||
|
||||
let code_context = if context_parts.is_empty() {
|
||||
"No relevant code context found.".to_string()
|
||||
} else {
|
||||
context_parts.join("\n\n")
|
||||
};
|
||||
|
||||
let system_prompt = format!(
|
||||
"You are an expert code assistant for a software repository. \
|
||||
Answer the user's question based on the code context below. \
|
||||
Reference specific files and functions when relevant. \
|
||||
If the context doesn't contain enough information, say so.\n\n\
|
||||
## Code Context\n\n{code_context}"
|
||||
);
|
||||
|
||||
// Step 4: Build messages array with history
|
||||
let mut messages: Vec<(String, String)> = Vec::new();
|
||||
messages.push(("system".to_string(), system_prompt));
|
||||
|
||||
for msg in &req.history {
|
||||
messages.push((msg.role.clone(), msg.content.clone()));
|
||||
}
|
||||
messages.push(("user".to_string(), req.message));
|
||||
|
||||
// Step 5: Call LLM
|
||||
let response_text = agent
|
||||
.llm
|
||||
.chat_with_messages(messages, Some(0.3))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("LLM chat failed: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: ChatResponse {
|
||||
message: response_text,
|
||||
sources,
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/v1/chat/:repo_id/build-embeddings — Trigger embedding build
|
||||
pub async fn build_embeddings(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let agent_clone = (*agent).clone();
|
||||
tokio::spawn(async move {
|
||||
let repo = match agent_clone
|
||||
.db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
|
||||
.await
|
||||
{
|
||||
Ok(Some(r)) => r,
|
||||
_ => {
|
||||
tracing::error!("Repository {repo_id} not found for embedding build");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Get latest graph build
|
||||
let build = match agent_clone
|
||||
.db
|
||||
.graph_builds()
|
||||
.find_one(doc! { "repo_id": &repo_id })
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.await
|
||||
{
|
||||
Ok(Some(b)) => b,
|
||||
_ => {
|
||||
tracing::error!("[{repo_id}] No graph build found — build graph first");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let graph_build_id = build
|
||||
.id
|
||||
.map(|id| id.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Get nodes
|
||||
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! { "repo_id": &repo_id })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => {
|
||||
use futures_util::StreamExt;
|
||||
let mut items = Vec::new();
|
||||
let mut cursor = cursor;
|
||||
while let Some(Ok(item)) = cursor.next().await {
|
||||
items.push(item);
|
||||
}
|
||||
items
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let git_ops = crate::pipeline::git::GitOps::new(&agent_clone.config.git_clone_base_path);
|
||||
let repo_path = match git_ops.clone_or_fetch(&repo.git_url, &repo.name) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to clone repo for embedding build: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
|
||||
match pipeline
|
||||
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
|
||||
.await
|
||||
{
|
||||
Ok(run) => {
|
||||
tracing::info!(
|
||||
"[{repo_id}] Embedding build complete: {}/{} chunks",
|
||||
run.embedded_chunks,
|
||||
run.total_chunks
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Embedding build failed: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(
|
||||
serde_json::json!({ "status": "embedding_build_triggered" }),
|
||||
))
|
||||
}
|
||||
|
||||
/// GET /api/v1/chat/:repo_id/status — Get latest embedding build status
|
||||
pub async fn embedding_status(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
|
||||
let store = EmbeddingStore::new(agent.db.inner());
|
||||
let build = store.get_latest_build(&repo_id).await.map_err(|e| {
|
||||
tracing::error!("Failed to get embedding status: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: build,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod dast;
|
||||
pub mod graph;
|
||||
|
||||
|
||||
@@ -23,10 +23,7 @@ pub fn build_router() -> Router {
|
||||
.route("/api/v1/issues", get(handlers::list_issues))
|
||||
.route("/api/v1/scan-runs", get(handlers::list_scan_runs))
|
||||
// Graph API endpoints
|
||||
.route(
|
||||
"/api/v1/graph/{repo_id}",
|
||||
get(handlers::graph::get_graph),
|
||||
)
|
||||
.route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph))
|
||||
.route(
|
||||
"/api/v1/graph/{repo_id}/nodes",
|
||||
get(handlers::graph::get_nodes),
|
||||
@@ -52,14 +49,8 @@ pub fn build_router() -> Router {
|
||||
post(handlers::graph::trigger_build),
|
||||
)
|
||||
// DAST API endpoints
|
||||
.route(
|
||||
"/api/v1/dast/targets",
|
||||
get(handlers::dast::list_targets),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/dast/targets",
|
||||
post(handlers::dast::add_target),
|
||||
)
|
||||
.route("/api/v1/dast/targets", get(handlers::dast::list_targets))
|
||||
.route("/api/v1/dast/targets", post(handlers::dast::add_target))
|
||||
.route(
|
||||
"/api/v1/dast/targets/{id}/scan",
|
||||
post(handlers::dast::trigger_scan),
|
||||
@@ -68,12 +59,19 @@ pub fn build_router() -> Router {
|
||||
"/api/v1/dast/scan-runs",
|
||||
get(handlers::dast::list_scan_runs),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/dast/findings",
|
||||
get(handlers::dast::list_findings),
|
||||
)
|
||||
.route("/api/v1/dast/findings", get(handlers::dast::list_findings))
|
||||
.route(
|
||||
"/api/v1/dast/findings/{id}",
|
||||
get(handlers::dast::get_finding),
|
||||
)
|
||||
// Chat / RAG API endpoints
|
||||
.route("/api/v1/chat/{repo_id}", post(handlers::chat::chat))
|
||||
.route(
|
||||
"/api/v1/chat/{repo_id}/build-embeddings",
|
||||
post(handlers::chat::build_embeddings),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/chat/{repo_id}/status",
|
||||
get(handlers::chat::embedding_status),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ pub fn load_config() -> Result<AgentConfig, AgentError> {
|
||||
.unwrap_or_else(|| "http://localhost:4000".to_string()),
|
||||
litellm_api_key: SecretString::from(env_var_opt("LITELLM_API_KEY").unwrap_or_default()),
|
||||
litellm_model: env_var_opt("LITELLM_MODEL").unwrap_or_else(|| "gpt-4o".to_string()),
|
||||
litellm_embed_model: env_var_opt("LITELLM_EMBED_MODEL")
|
||||
.unwrap_or_else(|| "text-embedding-3-small".to_string()),
|
||||
github_token: env_secret_opt("GITHUB_TOKEN"),
|
||||
github_webhook_secret: env_secret_opt("GITHUB_WEBHOOK_SECRET"),
|
||||
gitlab_url: env_var_opt("GITLAB_URL"),
|
||||
|
||||
@@ -127,11 +127,7 @@ impl Database {
|
||||
|
||||
// dast_targets: index on repo_id
|
||||
self.dast_targets()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.create_index(IndexModel::builder().keys(doc! { "repo_id": 1 }).build())
|
||||
.await?;
|
||||
|
||||
// dast_scan_runs: compound (target_id, started_at DESC)
|
||||
@@ -152,6 +148,24 @@ impl Database {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// code_embeddings: compound (repo_id, graph_build_id)
|
||||
self.code_embeddings()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// embedding_builds: compound (repo_id, started_at DESC)
|
||||
self.embedding_builds()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "started_at": -1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Database indexes ensured");
|
||||
Ok(())
|
||||
}
|
||||
@@ -210,6 +224,17 @@ impl Database {
|
||||
self.inner.collection("dast_findings")
|
||||
}
|
||||
|
||||
// Embedding collections
|
||||
pub fn code_embeddings(&self) -> Collection<compliance_core::models::embedding::CodeEmbedding> {
|
||||
self.inner.collection("code_embeddings")
|
||||
}
|
||||
|
||||
pub fn embedding_builds(
|
||||
&self,
|
||||
) -> Collection<compliance_core::models::embedding::EmbeddingBuildRun> {
|
||||
self.inner.collection("embedding_builds")
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
|
||||
self.inner.collection(name)
|
||||
|
||||
@@ -8,6 +8,7 @@ pub struct LlmClient {
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
@@ -42,16 +43,46 @@ struct ChatResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
/// Request body for the embeddings API
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
/// Response from the embeddings API
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
/// A single embedding result
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(base_url: String, api_key: SecretString, model: String) -> Self {
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url,
|
||||
api_key,
|
||||
model,
|
||||
embed_model,
|
||||
http: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
pub async fn chat(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@@ -169,4 +200,49 @@ impl LlmClient {
|
||||
.map(|c| c.message.content.clone())
|
||||
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
let key = self.api_key.expose_secret();
|
||||
if !key.is_empty() {
|
||||
req = req.header("Authorization", format!("Bearer {key}"));
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
// Sort by index to maintain input order
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod database;
|
||||
mod error;
|
||||
mod llm;
|
||||
mod pipeline;
|
||||
mod rag;
|
||||
mod scheduler;
|
||||
#[allow(dead_code)]
|
||||
mod trackers;
|
||||
|
||||
1
compliance-agent/src/rag/mod.rs
Normal file
1
compliance-agent/src/rag/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod pipeline;
|
||||
164
compliance-agent/src/rag/pipeline.rs
Normal file
164
compliance-agent/src/rag/pipeline.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
||||
use compliance_core::models::graph::CodeNode;
|
||||
use compliance_graph::graph::chunking::extract_chunks;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::error::AgentError;
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// RAG pipeline for building embeddings and performing retrieval
|
||||
pub struct RagPipeline {
|
||||
llm: Arc<LlmClient>,
|
||||
embedding_store: EmbeddingStore,
|
||||
}
|
||||
|
||||
impl RagPipeline {
|
||||
pub fn new(llm: Arc<LlmClient>, db: &mongodb::Database) -> Self {
|
||||
Self {
|
||||
llm,
|
||||
embedding_store: EmbeddingStore::new(db),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store(&self) -> &EmbeddingStore {
|
||||
&self.embedding_store
|
||||
}
|
||||
|
||||
/// Build embeddings for all code nodes in a repository
|
||||
pub async fn build_embeddings(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
repo_path: &Path,
|
||||
graph_build_id: &str,
|
||||
nodes: &[CodeNode],
|
||||
) -> Result<EmbeddingBuildRun, AgentError> {
|
||||
let embed_model = self.llm.embed_model().to_string();
|
||||
let mut build =
|
||||
EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model);
|
||||
|
||||
// Step 1: Extract chunks
|
||||
let chunks = extract_chunks(repo_path, nodes, 2048);
|
||||
build.total_chunks = chunks.len() as u32;
|
||||
info!(
|
||||
"[{repo_id}] Extracted {} chunks for embedding",
|
||||
chunks.len()
|
||||
);
|
||||
|
||||
// Store the initial build record
|
||||
self.embedding_store
|
||||
.store_build(&build)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?;
|
||||
|
||||
if chunks.is_empty() {
|
||||
build.status = EmbeddingBuildStatus::Completed;
|
||||
build.completed_at = Some(Utc::now());
|
||||
self.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Completed,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
||||
return Ok(build);
|
||||
}
|
||||
|
||||
// Step 2: Delete old embeddings for this repo
|
||||
self.embedding_store
|
||||
.delete_repo_embeddings(repo_id)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
|
||||
|
||||
// Step 3: Batch embed (small batches to stay within model limits)
|
||||
let batch_size = 20;
|
||||
let mut all_embeddings = Vec::new();
|
||||
let mut embedded_count = 0u32;
|
||||
|
||||
for batch_start in (0..chunks.len()).step_by(batch_size) {
|
||||
let batch_end = (batch_start + batch_size).min(chunks.len());
|
||||
let batch_chunks = &chunks[batch_start..batch_end];
|
||||
|
||||
// Prepare texts: context_header + content
|
||||
let texts: Vec<String> = batch_chunks
|
||||
.iter()
|
||||
.map(|c| format!("{}\n{}", c.context_header, c.content))
|
||||
.collect();
|
||||
|
||||
match self.llm.embed(texts).await {
|
||||
Ok(vectors) => {
|
||||
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
|
||||
all_embeddings.push(CodeEmbedding {
|
||||
id: None,
|
||||
repo_id: repo_id.to_string(),
|
||||
graph_build_id: graph_build_id.to_string(),
|
||||
qualified_name: chunk.qualified_name.clone(),
|
||||
kind: chunk.kind.clone(),
|
||||
file_path: chunk.file_path.clone(),
|
||||
start_line: chunk.start_line,
|
||||
end_line: chunk.end_line,
|
||||
language: chunk.language.clone(),
|
||||
content: chunk.content.clone(),
|
||||
context_header: chunk.context_header.clone(),
|
||||
embedding,
|
||||
token_estimate: chunk.token_estimate,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
}
|
||||
embedded_count += batch_chunks.len() as u32;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("[{repo_id}] Embedding batch failed: {e}");
|
||||
build.status = EmbeddingBuildStatus::Failed;
|
||||
build.error_message = Some(e.to_string());
|
||||
build.completed_at = Some(Utc::now());
|
||||
let _ = self
|
||||
.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Failed,
|
||||
embedded_count,
|
||||
Some(e.to_string()),
|
||||
)
|
||||
.await;
|
||||
return Ok(build);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Store all embeddings
|
||||
self.embedding_store
|
||||
.store_embeddings(&all_embeddings)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
|
||||
|
||||
// Step 5: Update build status
|
||||
build.status = EmbeddingBuildStatus::Completed;
|
||||
build.embedded_chunks = embedded_count;
|
||||
build.completed_at = Some(Utc::now());
|
||||
self.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Completed,
|
||||
embedded_count,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
||||
|
||||
info!(
|
||||
"[{repo_id}] Embedding build complete: {embedded_count}/{} chunks",
|
||||
build.total_chunks
|
||||
);
|
||||
Ok(build)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ pub struct AgentConfig {
|
||||
pub litellm_url: String,
|
||||
pub litellm_api_key: SecretString,
|
||||
pub litellm_model: String,
|
||||
pub litellm_embed_model: String,
|
||||
pub github_token: Option<SecretString>,
|
||||
pub github_webhook_secret: Option<SecretString>,
|
||||
pub gitlab_url: Option<String>,
|
||||
|
||||
35
compliance-core/src/models/chat.rs
Normal file
35
compliance-core/src/models/chat.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A message in the chat history
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Request body for the chat endpoint
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub message: String,
|
||||
#[serde(default)]
|
||||
pub history: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
/// A source reference from the RAG retrieval
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SourceReference {
|
||||
pub file_path: String,
|
||||
pub qualified_name: String,
|
||||
pub start_line: u32,
|
||||
pub end_line: u32,
|
||||
pub language: String,
|
||||
pub snippet: String,
|
||||
pub score: f64,
|
||||
}
|
||||
|
||||
/// Response from the chat endpoint
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatResponse {
|
||||
pub message: String,
|
||||
pub sources: Vec<SourceReference>,
|
||||
}
|
||||
100
compliance-core/src/models/embedding.rs
Normal file
100
compliance-core/src/models/embedding.rs
Normal file
@@ -0,0 +1,100 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Status of an embedding build operation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EmbeddingBuildStatus {
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// A code embedding stored in MongoDB Atlas Vector Search
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CodeEmbedding {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<bson::oid::ObjectId>,
|
||||
pub repo_id: String,
|
||||
pub graph_build_id: String,
|
||||
pub qualified_name: String,
|
||||
pub kind: String,
|
||||
pub file_path: String,
|
||||
pub start_line: u32,
|
||||
pub end_line: u32,
|
||||
pub language: String,
|
||||
pub content: String,
|
||||
pub context_header: String,
|
||||
pub embedding: Vec<f64>,
|
||||
pub token_estimate: u32,
|
||||
#[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")]
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Tracks an embedding build operation for a repository
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingBuildRun {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<bson::oid::ObjectId>,
|
||||
pub repo_id: String,
|
||||
pub graph_build_id: String,
|
||||
pub status: EmbeddingBuildStatus,
|
||||
pub total_chunks: u32,
|
||||
pub embedded_chunks: u32,
|
||||
pub embedding_model: String,
|
||||
pub error_message: Option<String>,
|
||||
#[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")]
|
||||
pub started_at: DateTime<Utc>,
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
with = "opt_chrono_as_bson"
|
||||
)]
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl EmbeddingBuildRun {
|
||||
pub fn new(repo_id: String, graph_build_id: String, embedding_model: String) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
status: EmbeddingBuildStatus::Running,
|
||||
total_chunks: 0,
|
||||
embedded_chunks: 0,
|
||||
embedding_model,
|
||||
error_message: None,
|
||||
started_at: Utc::now(),
|
||||
completed_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde helper for Option<DateTime<Utc>> as BSON DateTime
|
||||
mod opt_chrono_as_bson {
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct BsonDt(
|
||||
#[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")] DateTime<Utc>,
|
||||
);
|
||||
|
||||
pub fn serialize<S>(value: &Option<DateTime<Utc>>, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Some(dt) => BsonDt(*dt).serialize(serializer),
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let opt: Option<BsonDt> = Option::deserialize(deserializer)?;
|
||||
Ok(opt.map(|d| d.0))
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod chat;
|
||||
pub mod cve;
|
||||
pub mod dast;
|
||||
pub mod embedding;
|
||||
pub mod finding;
|
||||
pub mod graph;
|
||||
pub mod issue;
|
||||
@@ -7,15 +9,16 @@ pub mod repository;
|
||||
pub mod sbom;
|
||||
pub mod scan;
|
||||
|
||||
pub use chat::{ChatMessage, ChatRequest, ChatResponse, SourceReference};
|
||||
pub use cve::{CveAlert, CveSource};
|
||||
pub use dast::{
|
||||
DastAuthConfig, DastEvidence, DastFinding, DastScanPhase, DastScanRun, DastScanStatus,
|
||||
DastTarget, DastTargetType, DastVulnType,
|
||||
};
|
||||
pub use embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
||||
pub use finding::{Finding, FindingStatus, Severity};
|
||||
pub use graph::{
|
||||
CodeEdge, CodeEdgeKind, CodeNode, CodeNodeKind, GraphBuildRun, GraphBuildStatus,
|
||||
ImpactAnalysis,
|
||||
CodeEdge, CodeEdgeKind, CodeNode, CodeNodeKind, GraphBuildRun, GraphBuildStatus, ImpactAnalysis,
|
||||
};
|
||||
pub use issue::{IssueStatus, TrackerIssue, TrackerType};
|
||||
pub use repository::{ScanTrigger, TrackedRepository};
|
||||
|
||||
@@ -1710,3 +1710,240 @@ tbody tr:last-child td {
|
||||
white-space: nowrap;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
/* ── AI Chat ── */
|
||||
|
||||
.chat-embedding-banner {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 12px 20px;
|
||||
background: var(--bg-card);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius);
|
||||
margin-bottom: 16px;
|
||||
font-size: 13px;
|
||||
color: var(--text-secondary);
|
||||
}
|
||||
|
||||
.chat-embedding-banner .btn-sm {
|
||||
padding: 6px 14px;
|
||||
font-size: 12px;
|
||||
background: var(--accent-muted);
|
||||
color: var(--accent);
|
||||
border: 1px solid var(--border-accent);
|
||||
border-radius: var(--radius-sm);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s var(--ease-out);
|
||||
}
|
||||
|
||||
.chat-embedding-banner .btn-sm:hover:not(:disabled) {
|
||||
background: var(--accent);
|
||||
color: var(--bg-primary);
|
||||
}
|
||||
|
||||
.chat-embedding-banner .btn-sm:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: calc(100vh - 240px);
|
||||
background: var(--bg-card);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 20px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.chat-empty {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
height: 100%;
|
||||
color: var(--text-tertiary);
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.chat-empty h3 {
|
||||
font-family: var(--font-display);
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
color: var(--text-secondary);
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.chat-empty p {
|
||||
font-size: 13px;
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
.chat-message {
|
||||
max-width: 80%;
|
||||
padding: 12px 16px;
|
||||
border-radius: var(--radius);
|
||||
font-size: 14px;
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
.chat-message-user {
|
||||
align-self: flex-end;
|
||||
background: var(--accent-muted);
|
||||
border: 1px solid var(--border-accent);
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.chat-message-assistant {
|
||||
align-self: flex-start;
|
||||
background: var(--bg-elevated);
|
||||
border: 1px solid var(--border);
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.chat-message-role {
|
||||
font-family: var(--font-display);
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
color: var(--text-tertiary);
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
.chat-message-content {
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.chat-typing {
|
||||
color: var(--text-tertiary);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.chat-sources {
|
||||
margin-top: 12px;
|
||||
border-top: 1px solid var(--border);
|
||||
padding-top: 10px;
|
||||
}
|
||||
|
||||
.chat-sources-label {
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
color: var(--text-tertiary);
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.chat-source-card {
|
||||
background: var(--bg-secondary);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius-sm);
|
||||
padding: 10px 12px;
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
.chat-source-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 6px;
|
||||
}
|
||||
|
||||
.chat-source-name {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.chat-source-location {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 11px;
|
||||
color: var(--text-tertiary);
|
||||
}
|
||||
|
||||
.chat-source-snippet {
|
||||
margin: 0;
|
||||
padding: 8px;
|
||||
background: var(--bg-primary);
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
max-height: 120px;
|
||||
}
|
||||
|
||||
.chat-source-snippet code {
|
||||
font-family: var(--font-mono);
|
||||
font-size: 11px;
|
||||
color: var(--text-secondary);
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.chat-input-area {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
padding: 16px 20px;
|
||||
border-top: 1px solid var(--border);
|
||||
background: var(--bg-secondary);
|
||||
}
|
||||
|
||||
.chat-input {
|
||||
flex: 1;
|
||||
background: var(--bg-primary);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: var(--radius-sm);
|
||||
color: var(--text-primary);
|
||||
font-family: var(--font-body);
|
||||
font-size: 14px;
|
||||
padding: 10px 14px;
|
||||
resize: none;
|
||||
min-height: 42px;
|
||||
max-height: 120px;
|
||||
outline: none;
|
||||
transition: border-color 0.2s var(--ease-out);
|
||||
}
|
||||
|
||||
.chat-input:focus {
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.chat-input:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.chat-send-btn {
|
||||
padding: 10px 20px;
|
||||
background: var(--accent);
|
||||
color: var(--bg-primary);
|
||||
border: none;
|
||||
border-radius: var(--radius-sm);
|
||||
font-family: var(--font-display);
|
||||
font-weight: 600;
|
||||
font-size: 13px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s var(--ease-out);
|
||||
align-self: flex-end;
|
||||
}
|
||||
|
||||
.chat-send-btn:hover:not(:disabled) {
|
||||
background: var(--accent-hover);
|
||||
box-shadow: var(--accent-glow);
|
||||
}
|
||||
|
||||
.chat-send-btn:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
@@ -26,6 +26,10 @@ pub enum Route {
|
||||
GraphExplorerPage { repo_id: String },
|
||||
#[route("/graph/:repo_id/impact/:finding_id")]
|
||||
ImpactAnalysisPage { repo_id: String, finding_id: String },
|
||||
#[route("/chat")]
|
||||
ChatIndexPage {},
|
||||
#[route("/chat/:repo_id")]
|
||||
ChatPage { repo_id: String },
|
||||
#[route("/dast")]
|
||||
DastOverviewPage {},
|
||||
#[route("/dast/targets")]
|
||||
|
||||
@@ -46,6 +46,11 @@ pub fn Sidebar() -> Element {
|
||||
route: Route::GraphIndexPage {},
|
||||
icon: rsx! { Icon { icon: BsDiagram3, width: 18, height: 18 } },
|
||||
},
|
||||
NavItem {
|
||||
label: "AI Chat",
|
||||
route: Route::ChatIndexPage {},
|
||||
icon: rsx! { Icon { icon: BsChatDots, width: 18, height: 18 } },
|
||||
},
|
||||
NavItem {
|
||||
label: "DAST",
|
||||
route: Route::DastOverviewPage {},
|
||||
@@ -58,7 +63,11 @@ pub fn Sidebar() -> Element {
|
||||
},
|
||||
];
|
||||
|
||||
let sidebar_class = if collapsed() { "sidebar collapsed" } else { "sidebar" };
|
||||
let sidebar_class = if collapsed() {
|
||||
"sidebar collapsed"
|
||||
} else {
|
||||
"sidebar"
|
||||
};
|
||||
|
||||
rsx! {
|
||||
nav { class: "{sidebar_class}",
|
||||
@@ -76,6 +85,7 @@ pub fn Sidebar() -> Element {
|
||||
(Route::GraphIndexPage {}, Route::GraphIndexPage {}) => true,
|
||||
(Route::GraphExplorerPage { .. }, Route::GraphIndexPage {}) => true,
|
||||
(Route::ImpactAnalysisPage { .. }, Route::GraphIndexPage {}) => true,
|
||||
(Route::ChatPage { .. }, Route::ChatIndexPage {}) => true,
|
||||
(Route::DastTargetsPage {}, Route::DastOverviewPage {}) => true,
|
||||
(Route::DastFindingsPage {}, Route::DastOverviewPage {}) => true,
|
||||
(Route::DastFindingDetailPage { .. }, Route::DastOverviewPage {}) => true,
|
||||
|
||||
126
compliance-dashboard/src/infrastructure/chat.rs
Normal file
126
compliance-dashboard/src/infrastructure/chat.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use dioxus::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Response types ──
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatApiResponse {
|
||||
pub data: ChatResponseData,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ChatResponseData {
|
||||
pub message: String,
|
||||
#[serde(default)]
|
||||
pub sources: Vec<SourceRef>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct SourceRef {
|
||||
pub file_path: String,
|
||||
pub qualified_name: String,
|
||||
pub start_line: u32,
|
||||
pub end_line: u32,
|
||||
pub language: String,
|
||||
pub snippet: String,
|
||||
pub score: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct EmbeddingStatusResponse {
|
||||
pub data: Option<EmbeddingBuildData>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct EmbeddingBuildData {
|
||||
pub repo_id: String,
|
||||
pub status: String,
|
||||
pub total_chunks: u32,
|
||||
pub embedded_chunks: u32,
|
||||
pub embedding_model: String,
|
||||
pub error_message: Option<String>,
|
||||
#[serde(default)]
|
||||
pub started_at: Option<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
pub completed_at: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ── Chat message history type ──
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatHistoryMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
// ── Server functions ──
|
||||
|
||||
#[server]
|
||||
pub async fn send_chat_message(
|
||||
repo_id: String,
|
||||
message: String,
|
||||
history: Vec<ChatHistoryMessage>,
|
||||
) -> Result<ChatApiResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
|
||||
let url = format!("{}/api/v1/chat/{repo_id}", state.agent_api_url);
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.build()
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"message": message,
|
||||
"history": history,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?;
|
||||
|
||||
let text = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(format!("Failed to read response: {e}")))?;
|
||||
|
||||
let body: ChatApiResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| ServerFnError::new(format!("Failed to parse response: {e} — body: {text}")))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
|
||||
let url = format!(
|
||||
"{}/api/v1/chat/{repo_id}/build-embeddings",
|
||||
state.agent_api_url
|
||||
);
|
||||
let client = reqwest::Client::new();
|
||||
client
|
||||
.post(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_embedding_status(
|
||||
repo_id: String,
|
||||
) -> Result<EmbeddingStatusResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
|
||||
let url = format!("{}/api/v1/chat/{repo_id}/status", state.agent_api_url);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: EmbeddingStatusResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
// Server function modules (compiled for both web and server;
|
||||
// the #[server] macro generates client stubs for the web target)
|
||||
pub mod chat;
|
||||
pub mod dast;
|
||||
pub mod findings;
|
||||
pub mod graph;
|
||||
|
||||
232
compliance-dashboard/src/pages/chat.rs
Normal file
232
compliance-dashboard/src/pages/chat.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use dioxus::prelude::*;
|
||||
|
||||
use crate::components::page_header::PageHeader;
|
||||
use crate::infrastructure::chat::{
|
||||
fetch_embedding_status, send_chat_message, trigger_embedding_build, ChatHistoryMessage,
|
||||
SourceRef,
|
||||
};
|
||||
|
||||
/// A UI-level chat message
|
||||
#[derive(Clone, Debug)]
|
||||
struct UiChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
sources: Vec<SourceRef>,
|
||||
}
|
||||
|
||||
#[component]
|
||||
pub fn ChatPage(repo_id: String) -> Element {
|
||||
let mut messages: Signal<Vec<UiChatMessage>> = use_signal(Vec::new);
|
||||
let mut input_text = use_signal(String::new);
|
||||
let mut loading = use_signal(|| false);
|
||||
let mut building = use_signal(|| false);
|
||||
|
||||
let repo_id_for_status = repo_id.clone();
|
||||
let mut embedding_status = use_resource(move || {
|
||||
let rid = repo_id_for_status.clone();
|
||||
async move { fetch_embedding_status(rid).await.ok() }
|
||||
});
|
||||
|
||||
let has_embeddings = {
|
||||
let status = embedding_status.read();
|
||||
match &*status {
|
||||
Some(Some(resp)) => resp
|
||||
.data
|
||||
.as_ref()
|
||||
.map(|d| d.status == "completed")
|
||||
.unwrap_or(false),
|
||||
_ => false,
|
||||
}
|
||||
};
|
||||
|
||||
let embedding_status_text = {
|
||||
let status = embedding_status.read();
|
||||
match &*status {
|
||||
Some(Some(resp)) => match &resp.data {
|
||||
Some(d) => match d.status.as_str() {
|
||||
"completed" => format!(
|
||||
"Embeddings ready: {}/{} chunks",
|
||||
d.embedded_chunks, d.total_chunks
|
||||
),
|
||||
"running" => format!(
|
||||
"Building embeddings: {}/{}...",
|
||||
d.embedded_chunks, d.total_chunks
|
||||
),
|
||||
"failed" => format!(
|
||||
"Embedding build failed: {}",
|
||||
d.error_message.as_deref().unwrap_or("unknown error")
|
||||
),
|
||||
s => format!("Status: {s}"),
|
||||
},
|
||||
None => "No embeddings built yet".to_string(),
|
||||
},
|
||||
Some(None) => "Failed to check embedding status".to_string(),
|
||||
None => "Checking embedding status...".to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
let repo_id_for_build = repo_id.clone();
|
||||
let on_build = move |_| {
|
||||
let rid = repo_id_for_build.clone();
|
||||
building.set(true);
|
||||
spawn(async move {
|
||||
let _ = trigger_embedding_build(rid).await;
|
||||
building.set(false);
|
||||
embedding_status.restart();
|
||||
});
|
||||
};
|
||||
|
||||
let repo_id_for_send = repo_id.clone();
|
||||
let mut do_send = move || {
|
||||
let text = input_text.read().trim().to_string();
|
||||
if text.is_empty() || *loading.read() {
|
||||
return;
|
||||
}
|
||||
|
||||
let rid = repo_id_for_send.clone();
|
||||
let user_msg = text.clone();
|
||||
|
||||
// Add user message to UI
|
||||
messages.write().push(UiChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: user_msg.clone(),
|
||||
sources: Vec::new(),
|
||||
});
|
||||
input_text.set(String::new());
|
||||
loading.set(true);
|
||||
|
||||
spawn(async move {
|
||||
// Build history from existing messages
|
||||
let history: Vec<ChatHistoryMessage> = messages
|
||||
.read()
|
||||
.iter()
|
||||
.filter(|m| m.role == "user" || m.role == "assistant")
|
||||
.rev()
|
||||
.skip(1) // skip the message we just added
|
||||
.take(10) // limit history
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.rev()
|
||||
.map(|m| ChatHistoryMessage {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
match send_chat_message(rid, user_msg, history).await {
|
||||
Ok(resp) => {
|
||||
messages.write().push(UiChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: resp.data.message,
|
||||
sources: resp.data.sources,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
messages.write().push(UiChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: format!("Error: {e}"),
|
||||
sources: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
loading.set(false);
|
||||
});
|
||||
};
|
||||
|
||||
let mut do_send_click = do_send.clone();
|
||||
|
||||
rsx! {
|
||||
PageHeader { title: "AI Chat" }
|
||||
|
||||
// Embedding status banner
|
||||
div { class: "chat-embedding-banner",
|
||||
span { "{embedding_status_text}" }
|
||||
button {
|
||||
class: "btn btn-sm",
|
||||
disabled: *building.read(),
|
||||
onclick: on_build,
|
||||
if *building.read() { "Building..." } else { "Build Embeddings" }
|
||||
}
|
||||
}
|
||||
|
||||
div { class: "chat-container",
|
||||
// Message list
|
||||
div { class: "chat-messages",
|
||||
if messages.read().is_empty() && !*loading.read() {
|
||||
div { class: "chat-empty",
|
||||
h3 { "Ask anything about your codebase" }
|
||||
p { "Build embeddings first, then ask questions about functions, architecture, patterns, and more." }
|
||||
}
|
||||
}
|
||||
for (i, msg) in messages.read().iter().enumerate() {
|
||||
{
|
||||
let class = if msg.role == "user" {
|
||||
"chat-message chat-message-user"
|
||||
} else {
|
||||
"chat-message chat-message-assistant"
|
||||
};
|
||||
let content = msg.content.clone();
|
||||
let sources = msg.sources.clone();
|
||||
rsx! {
|
||||
div { class: class, key: "{i}",
|
||||
div { class: "chat-message-role",
|
||||
if msg.role == "user" { "You" } else { "Assistant" }
|
||||
}
|
||||
div { class: "chat-message-content", "{content}" }
|
||||
if !sources.is_empty() {
|
||||
div { class: "chat-sources",
|
||||
span { class: "chat-sources-label", "Sources:" }
|
||||
for src in sources {
|
||||
div { class: "chat-source-card",
|
||||
div { class: "chat-source-header",
|
||||
span { class: "chat-source-name",
|
||||
"{src.qualified_name}"
|
||||
}
|
||||
span { class: "chat-source-location",
|
||||
"{src.file_path}:{src.start_line}-{src.end_line}"
|
||||
}
|
||||
}
|
||||
pre { class: "chat-source-snippet",
|
||||
code { "{src.snippet}" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if *loading.read() {
|
||||
div { class: "chat-message chat-message-assistant",
|
||||
div { class: "chat-message-role", "Assistant" }
|
||||
div { class: "chat-message-content chat-typing", "Thinking..." }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Input area
|
||||
div { class: "chat-input-area",
|
||||
textarea {
|
||||
class: "chat-input",
|
||||
placeholder: "Ask about your codebase...",
|
||||
value: "{input_text}",
|
||||
disabled: !has_embeddings,
|
||||
oninput: move |e| input_text.set(e.value()),
|
||||
onkeydown: move |e: Event<KeyboardData>| {
|
||||
if e.key() == Key::Enter && !e.modifiers().shift() {
|
||||
e.prevent_default();
|
||||
do_send();
|
||||
}
|
||||
},
|
||||
}
|
||||
button {
|
||||
class: "btn chat-send-btn",
|
||||
disabled: *loading.read() || !has_embeddings,
|
||||
onclick: move |_| do_send_click(),
|
||||
"Send"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
71
compliance-dashboard/src/pages/chat_index.rs
Normal file
71
compliance-dashboard/src/pages/chat_index.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use dioxus::prelude::*;
|
||||
|
||||
use crate::app::Route;
|
||||
use crate::components::page_header::PageHeader;
|
||||
use crate::infrastructure::chat::fetch_embedding_status;
|
||||
use crate::infrastructure::repositories::fetch_repositories;
|
||||
|
||||
#[component]
|
||||
pub fn ChatIndexPage() -> Element {
|
||||
let repos = use_resource(|| async { fetch_repositories(1).await.ok() });
|
||||
|
||||
rsx! {
|
||||
PageHeader {
|
||||
title: "AI Chat",
|
||||
description: "Ask questions about your codebase using RAG-augmented AI",
|
||||
}
|
||||
|
||||
match &*repos.read() {
|
||||
Some(Some(data)) => {
|
||||
let repo_list = &data.data;
|
||||
if repo_list.is_empty() {
|
||||
rsx! {
|
||||
div { class: "card",
|
||||
p { "No repositories found. Add a repository first." }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {
|
||||
div { class: "graph-index-grid",
|
||||
for repo in repo_list {
|
||||
{
|
||||
let repo_id = repo.id.map(|id| id.to_hex()).unwrap_or_default();
|
||||
let name = repo.name.clone();
|
||||
let url = repo.git_url.clone();
|
||||
let branch = repo.default_branch.clone();
|
||||
rsx! {
|
||||
Link {
|
||||
to: Route::ChatPage { repo_id },
|
||||
class: "graph-repo-card",
|
||||
div { class: "graph-repo-card-header",
|
||||
div { class: "graph-repo-card-icon", "\u{1F4AC}" }
|
||||
h3 { class: "graph-repo-card-name", "{name}" }
|
||||
}
|
||||
if !url.is_empty() {
|
||||
p { class: "graph-repo-card-url", "{url}" }
|
||||
}
|
||||
div { class: "graph-repo-card-meta",
|
||||
span { class: "graph-repo-card-tag",
|
||||
"\u{E0A0} {branch}"
|
||||
}
|
||||
span { class: "graph-repo-card-tag",
|
||||
"AI Chat"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(None) => rsx! {
|
||||
div { class: "card", p { "Failed to load repositories." } }
|
||||
},
|
||||
None => rsx! {
|
||||
div { class: "loading", "Loading repositories..." }
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
pub mod chat;
|
||||
pub mod chat_index;
|
||||
pub mod dast_finding_detail;
|
||||
pub mod dast_findings;
|
||||
pub mod dast_overview;
|
||||
@@ -13,6 +15,8 @@ pub mod repositories;
|
||||
pub mod sbom;
|
||||
pub mod settings;
|
||||
|
||||
pub use chat::ChatPage;
|
||||
pub use chat_index::ChatIndexPage;
|
||||
pub use dast_finding_detail::DastFindingDetailPage;
|
||||
pub use dast_findings::DastFindingsPage;
|
||||
pub use dast_overview::DastOverviewPage;
|
||||
|
||||
96
compliance-graph/src/graph/chunking.rs
Normal file
96
compliance-graph/src/graph/chunking.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::graph::CodeNode;
|
||||
|
||||
/// A chunk of code extracted from a source file, ready for embedding
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodeChunk {
|
||||
pub qualified_name: String,
|
||||
pub kind: String,
|
||||
pub file_path: String,
|
||||
pub start_line: u32,
|
||||
pub end_line: u32,
|
||||
pub language: String,
|
||||
pub content: String,
|
||||
pub context_header: String,
|
||||
pub token_estimate: u32,
|
||||
}
|
||||
|
||||
/// Extract embeddable code chunks from parsed CodeNodes.
|
||||
///
|
||||
/// For each node, reads the corresponding source lines from disk,
|
||||
/// builds a context header, and estimates tokens.
|
||||
pub fn extract_chunks(
|
||||
repo_path: &Path,
|
||||
nodes: &[CodeNode],
|
||||
max_chunk_tokens: u32,
|
||||
) -> Vec<CodeChunk> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
for node in nodes {
|
||||
let file = repo_path.join(&node.file_path);
|
||||
let source = match std::fs::read_to_string(&file) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let lines: Vec<&str> = source.lines().collect();
|
||||
let start = node.start_line.saturating_sub(1) as usize;
|
||||
let end = (node.end_line as usize).min(lines.len());
|
||||
if start >= end {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content: String = lines[start..end].join("\n");
|
||||
|
||||
// Skip tiny chunks
|
||||
if content.len() < 50 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Estimate tokens (~4 chars per token)
|
||||
let mut token_estimate = (content.len() / 4) as u32;
|
||||
|
||||
// Truncate if too large
|
||||
let final_content = if token_estimate > max_chunk_tokens {
|
||||
let max_chars = (max_chunk_tokens as usize) * 4;
|
||||
token_estimate = max_chunk_tokens;
|
||||
content.chars().take(max_chars).collect()
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
// Build context header: file path + containing scope hint
|
||||
let context_header = build_context_header(
|
||||
&node.file_path,
|
||||
&node.qualified_name,
|
||||
&node.kind.to_string(),
|
||||
);
|
||||
|
||||
chunks.push(CodeChunk {
|
||||
qualified_name: node.qualified_name.clone(),
|
||||
kind: node.kind.to_string(),
|
||||
file_path: node.file_path.clone(),
|
||||
start_line: node.start_line,
|
||||
end_line: node.end_line,
|
||||
language: node.language.clone(),
|
||||
content: final_content,
|
||||
context_header,
|
||||
token_estimate,
|
||||
});
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> String {
|
||||
// Extract containing module/class from qualified name
|
||||
// e.g. "src/main.rs::MyStruct::my_method" → parent is "MyStruct"
|
||||
let parts: Vec<&str> = qualified_name.split("::").collect();
|
||||
if parts.len() >= 2 {
|
||||
let parent = parts[..parts.len() - 1].join("::");
|
||||
format!("// {file_path} | {kind} in {parent}")
|
||||
} else {
|
||||
format!("// {file_path} | {kind}")
|
||||
}
|
||||
}
|
||||
238
compliance-graph/src/graph/embedding_store.rs
Normal file
238
compliance-graph/src/graph/embedding_store.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
||||
use futures_util::TryStreamExt;
|
||||
use mongodb::bson::doc;
|
||||
use mongodb::{Collection, Database, IndexModel};
|
||||
use tracing::info;
|
||||
|
||||
/// MongoDB persistence layer for code embeddings and vector search
|
||||
pub struct EmbeddingStore {
|
||||
embeddings: Collection<CodeEmbedding>,
|
||||
builds: Collection<EmbeddingBuildRun>,
|
||||
}
|
||||
|
||||
impl EmbeddingStore {
|
||||
pub fn new(db: &Database) -> Self {
|
||||
Self {
|
||||
embeddings: db.collection("code_embeddings"),
|
||||
builds: db.collection("embedding_builds"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create standard indexes. NOTE: The Atlas Vector Search index must be
|
||||
/// created via the Atlas UI or CLI with the following definition:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "fields": [
|
||||
/// { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" },
|
||||
/// { "type": "filter", "path": "repo_id" }
|
||||
/// ]
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn ensure_indexes(&self) -> Result<(), CoreError> {
|
||||
self.embeddings
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.builds
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "started_at": -1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings for a repository
|
||||
pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result<u64, CoreError> {
|
||||
let result = self
|
||||
.embeddings
|
||||
.delete_many(doc! { "repo_id": repo_id })
|
||||
.await?;
|
||||
info!(
|
||||
"Deleted {} embeddings for repo {repo_id}",
|
||||
result.deleted_count
|
||||
);
|
||||
Ok(result.deleted_count)
|
||||
}
|
||||
|
||||
/// Store embeddings in batches of 500
|
||||
pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result<u64, CoreError> {
|
||||
let mut total_inserted = 0u64;
|
||||
for batch in embeddings.chunks(500) {
|
||||
let result = self.embeddings.insert_many(batch).await?;
|
||||
total_inserted += result.inserted_ids.len() as u64;
|
||||
}
|
||||
info!("Stored {total_inserted} embeddings");
|
||||
Ok(total_inserted)
|
||||
}
|
||||
|
||||
/// Store a new build run
|
||||
pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> {
|
||||
self.builds.insert_one(build).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an existing build run
|
||||
pub async fn update_build(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
graph_build_id: &str,
|
||||
status: EmbeddingBuildStatus,
|
||||
embedded_chunks: u32,
|
||||
error_message: Option<String>,
|
||||
) -> Result<(), CoreError> {
|
||||
let mut update = doc! {
|
||||
"$set": {
|
||||
"status": mongodb::bson::to_bson(&status).unwrap_or_default(),
|
||||
"embedded_chunks": embedded_chunks as i64,
|
||||
}
|
||||
};
|
||||
|
||||
if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed {
|
||||
update
|
||||
.get_document_mut("$set")
|
||||
.unwrap()
|
||||
.insert("completed_at", mongodb::bson::DateTime::now());
|
||||
}
|
||||
|
||||
if let Some(msg) = error_message {
|
||||
update
|
||||
.get_document_mut("$set")
|
||||
.unwrap()
|
||||
.insert("error_message", msg);
|
||||
}
|
||||
|
||||
self.builds
|
||||
.update_one(
|
||||
doc! { "repo_id": repo_id, "graph_build_id": graph_build_id },
|
||||
update,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the latest embedding build for a repository
|
||||
pub async fn get_latest_build(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
) -> Result<Option<EmbeddingBuildRun>, CoreError> {
|
||||
Ok(self
|
||||
.builds
|
||||
.find_one(doc! { "repo_id": repo_id })
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Perform vector search. Tries Atlas $vectorSearch first, falls back to
|
||||
/// brute-force cosine similarity for local MongoDB instances.
|
||||
pub async fn vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: Vec<f64>,
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
match self
|
||||
.atlas_vector_search(repo_id, &query_embedding, limit, min_score)
|
||||
.await
|
||||
{
|
||||
Ok(results) => Ok(results),
|
||||
Err(e) => {
|
||||
info!(
|
||||
"Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity"
|
||||
);
|
||||
self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index)
|
||||
async fn atlas_vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: &[f64],
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
use mongodb::bson::{Bson, Document};
|
||||
|
||||
let pipeline = vec![
|
||||
doc! {
|
||||
"$vectorSearch": {
|
||||
"index": "embedding_vector_index",
|
||||
"path": "embedding",
|
||||
"queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::<Vec<_>>(),
|
||||
"numCandidates": (limit * 10) as i64,
|
||||
"limit": limit as i64,
|
||||
"filter": { "repo_id": repo_id },
|
||||
}
|
||||
},
|
||||
doc! {
|
||||
"$addFields": {
|
||||
"search_score": { "$meta": "vectorSearchScore" }
|
||||
}
|
||||
},
|
||||
doc! {
|
||||
"$match": {
|
||||
"search_score": { "$gte": min_score }
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
let mut cursor = self.embeddings.aggregate(pipeline).await?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(doc) = cursor.try_next().await? {
|
||||
let score = doc.get_f64("search_score").unwrap_or(0.0);
|
||||
let mut clean_doc: Document = doc;
|
||||
clean_doc.remove("search_score");
|
||||
if let Ok(embedding) = mongodb::bson::from_document::<CodeEmbedding>(clean_doc) {
|
||||
results.push((embedding, score));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Brute-force cosine similarity fallback for local MongoDB without Atlas
|
||||
async fn bruteforce_vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: &[f64],
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?;
|
||||
|
||||
let query_norm = dot(query_embedding, query_embedding).sqrt();
|
||||
let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new();
|
||||
|
||||
while let Some(emb) = cursor.try_next().await? {
|
||||
let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt();
|
||||
let score = if query_norm > 0.0 && doc_norm > 0.0 {
|
||||
dot(query_embedding, &emb.embedding) / (query_norm * doc_norm)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
if score >= min_score {
|
||||
scored.push((emb, score));
|
||||
}
|
||||
}
|
||||
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
scored.truncate(limit as usize);
|
||||
Ok(scored)
|
||||
}
|
||||
}
|
||||
|
||||
fn dot(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod chunking;
|
||||
pub mod community;
|
||||
pub mod embedding_store;
|
||||
pub mod engine;
|
||||
pub mod impact;
|
||||
pub mod persistence;
|
||||
|
||||
Reference in New Issue
Block a user