diff --git a/compliance-agent/src/agent.rs b/compliance-agent/src/agent.rs index 1577acb..e13271d 100644 --- a/compliance-agent/src/agent.rs +++ b/compliance-agent/src/agent.rs @@ -20,6 +20,7 @@ impl ComplianceAgent { config.litellm_url.clone(), config.litellm_api_key.clone(), config.litellm_model.clone(), + config.litellm_embed_model.clone(), )); Self { config, diff --git a/compliance-agent/src/api/handlers/chat.rs b/compliance-agent/src/api/handlers/chat.rs new file mode 100644 index 0000000..aafe290 --- /dev/null +++ b/compliance-agent/src/api/handlers/chat.rs @@ -0,0 +1,238 @@ +use std::sync::Arc; + +use axum::extract::{Extension, Path}; +use axum::http::StatusCode; +use axum::Json; +use mongodb::bson::doc; + +use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference}; +use compliance_core::models::embedding::EmbeddingBuildRun; +use compliance_graph::graph::embedding_store::EmbeddingStore; + +use crate::agent::ComplianceAgent; +use crate::rag::pipeline::RagPipeline; + +use super::ApiResponse; + +type AgentExt = Extension>; + +/// POST /api/v1/chat/:repo_id — Send a chat message with RAG context +pub async fn chat( + Extension(agent): AgentExt, + Path(repo_id): Path, + Json(req): Json, +) -> Result>, StatusCode> { + let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner()); + + // Step 1: Embed the user's message + let query_vectors = agent + .llm + .embed(vec![req.message.clone()]) + .await + .map_err(|e| { + tracing::error!("Failed to embed query: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + let query_embedding = query_vectors.into_iter().next().ok_or_else(|| { + tracing::error!("Empty embedding response"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Step 2: Vector search — retrieve top 8 chunks + let search_results = pipeline + .store() + .vector_search(&repo_id, query_embedding, 8, 0.5) + .await + .map_err(|e| { + tracing::error!("Vector search failed: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + // Step 3: Build system prompt with code context + let mut context_parts = Vec::new(); + let mut sources = Vec::new(); + + for (embedding, score) in &search_results { + context_parts.push(format!( + "--- {} ({}, {}:L{}-L{}) ---\n{}", + embedding.qualified_name, + embedding.kind, + embedding.file_path, + embedding.start_line, + embedding.end_line, + embedding.content, + )); + + // Truncate snippet for the response + let snippet: String = embedding + .content + .lines() + .take(10) + .collect::>() + .join("\n"); + sources.push(SourceReference { + file_path: embedding.file_path.clone(), + qualified_name: embedding.qualified_name.clone(), + start_line: embedding.start_line, + end_line: embedding.end_line, + language: embedding.language.clone(), + snippet, + score: *score, + }); + } + + let code_context = if context_parts.is_empty() { + "No relevant code context found.".to_string() + } else { + context_parts.join("\n\n") + }; + + let system_prompt = format!( + "You are an expert code assistant for a software repository. \ + Answer the user's question based on the code context below. \ + Reference specific files and functions when relevant. \ + If the context doesn't contain enough information, say so.\n\n\ + ## Code Context\n\n{code_context}" + ); + + // Step 4: Build messages array with history + let mut messages: Vec<(String, String)> = Vec::new(); + messages.push(("system".to_string(), system_prompt)); + + for msg in &req.history { + messages.push((msg.role.clone(), msg.content.clone())); + } + messages.push(("user".to_string(), req.message)); + + // Step 5: Call LLM + let response_text = agent + .llm + .chat_with_messages(messages, Some(0.3)) + .await + .map_err(|e| { + tracing::error!("LLM chat failed: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(ApiResponse { + data: ChatResponse { + message: response_text, + sources, + }, + total: None, + page: None, + })) +} + +/// POST /api/v1/chat/:repo_id/build-embeddings — Trigger embedding build +pub async fn build_embeddings( + Extension(agent): AgentExt, + Path(repo_id): Path, +) -> Result, StatusCode> { + let agent_clone = (*agent).clone(); + tokio::spawn(async move { + let repo = match agent_clone + .db + .repositories() + .find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() }) + .await + { + Ok(Some(r)) => r, + _ => { + tracing::error!("Repository {repo_id} not found for embedding build"); + return; + } + }; + + // Get latest graph build + let build = match agent_clone + .db + .graph_builds() + .find_one(doc! { "repo_id": &repo_id }) + .sort(doc! { "started_at": -1 }) + .await + { + Ok(Some(b)) => b, + _ => { + tracing::error!("[{repo_id}] No graph build found — build graph first"); + return; + } + }; + + let graph_build_id = build + .id + .map(|id| id.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); + + // Get nodes + let nodes: Vec = match agent_clone + .db + .graph_nodes() + .find(doc! { "repo_id": &repo_id }) + .await + { + Ok(cursor) => { + use futures_util::StreamExt; + let mut items = Vec::new(); + let mut cursor = cursor; + while let Some(Ok(item)) = cursor.next().await { + items.push(item); + } + items + } + Err(e) => { + tracing::error!("[{repo_id}] Failed to fetch nodes: {e}"); + return; + } + }; + + let git_ops = crate::pipeline::git::GitOps::new(&agent_clone.config.git_clone_base_path); + let repo_path = match git_ops.clone_or_fetch(&repo.git_url, &repo.name) { + Ok(p) => p, + Err(e) => { + tracing::error!("Failed to clone repo for embedding build: {e}"); + return; + } + }; + + let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner()); + match pipeline + .build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes) + .await + { + Ok(run) => { + tracing::info!( + "[{repo_id}] Embedding build complete: {}/{} chunks", + run.embedded_chunks, + run.total_chunks + ); + } + Err(e) => { + tracing::error!("[{repo_id}] Embedding build failed: {e}"); + } + } + }); + + Ok(Json( + serde_json::json!({ "status": "embedding_build_triggered" }), + )) +} + +/// GET /api/v1/chat/:repo_id/status — Get latest embedding build status +pub async fn embedding_status( + Extension(agent): AgentExt, + Path(repo_id): Path, +) -> Result>>, StatusCode> { + let store = EmbeddingStore::new(agent.db.inner()); + let build = store.get_latest_build(&repo_id).await.map_err(|e| { + tracing::error!("Failed to get embedding status: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + Ok(Json(ApiResponse { + data: build, + total: None, + page: None, + })) +} diff --git a/compliance-agent/src/api/handlers/mod.rs b/compliance-agent/src/api/handlers/mod.rs index 39af052..b9beccc 100644 --- a/compliance-agent/src/api/handlers/mod.rs +++ b/compliance-agent/src/api/handlers/mod.rs @@ -1,3 +1,4 @@ +pub mod chat; pub mod dast; pub mod graph; diff --git a/compliance-agent/src/api/routes.rs b/compliance-agent/src/api/routes.rs index dd2794c..8f0e72a 100644 --- a/compliance-agent/src/api/routes.rs +++ b/compliance-agent/src/api/routes.rs @@ -23,10 +23,7 @@ pub fn build_router() -> Router { .route("/api/v1/issues", get(handlers::list_issues)) .route("/api/v1/scan-runs", get(handlers::list_scan_runs)) // Graph API endpoints - .route( - "/api/v1/graph/{repo_id}", - get(handlers::graph::get_graph), - ) + .route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph)) .route( "/api/v1/graph/{repo_id}/nodes", get(handlers::graph::get_nodes), @@ -52,14 +49,8 @@ pub fn build_router() -> Router { post(handlers::graph::trigger_build), ) // DAST API endpoints - .route( - "/api/v1/dast/targets", - get(handlers::dast::list_targets), - ) - .route( - "/api/v1/dast/targets", - post(handlers::dast::add_target), - ) + .route("/api/v1/dast/targets", get(handlers::dast::list_targets)) + .route("/api/v1/dast/targets", post(handlers::dast::add_target)) .route( "/api/v1/dast/targets/{id}/scan", post(handlers::dast::trigger_scan), @@ -68,12 +59,19 @@ pub fn build_router() -> Router { "/api/v1/dast/scan-runs", get(handlers::dast::list_scan_runs), ) - .route( - "/api/v1/dast/findings", - get(handlers::dast::list_findings), - ) + .route("/api/v1/dast/findings", get(handlers::dast::list_findings)) .route( "/api/v1/dast/findings/{id}", get(handlers::dast::get_finding), ) + // Chat / RAG API endpoints + .route("/api/v1/chat/{repo_id}", post(handlers::chat::chat)) + .route( + "/api/v1/chat/{repo_id}/build-embeddings", + post(handlers::chat::build_embeddings), + ) + .route( + "/api/v1/chat/{repo_id}/status", + get(handlers::chat::embedding_status), + ) } diff --git a/compliance-agent/src/config.rs b/compliance-agent/src/config.rs index 03ede73..06bf03d 100644 --- a/compliance-agent/src/config.rs +++ b/compliance-agent/src/config.rs @@ -24,6 +24,8 @@ pub fn load_config() -> Result { .unwrap_or_else(|| "http://localhost:4000".to_string()), litellm_api_key: SecretString::from(env_var_opt("LITELLM_API_KEY").unwrap_or_default()), litellm_model: env_var_opt("LITELLM_MODEL").unwrap_or_else(|| "gpt-4o".to_string()), + litellm_embed_model: env_var_opt("LITELLM_EMBED_MODEL") + .unwrap_or_else(|| "text-embedding-3-small".to_string()), github_token: env_secret_opt("GITHUB_TOKEN"), github_webhook_secret: env_secret_opt("GITHUB_WEBHOOK_SECRET"), gitlab_url: env_var_opt("GITLAB_URL"), diff --git a/compliance-agent/src/database.rs b/compliance-agent/src/database.rs index 3f32df5..c2b0740 100644 --- a/compliance-agent/src/database.rs +++ b/compliance-agent/src/database.rs @@ -127,11 +127,7 @@ impl Database { // dast_targets: index on repo_id self.dast_targets() - .create_index( - IndexModel::builder() - .keys(doc! { "repo_id": 1 }) - .build(), - ) + .create_index(IndexModel::builder().keys(doc! { "repo_id": 1 }).build()) .await?; // dast_scan_runs: compound (target_id, started_at DESC) @@ -152,6 +148,24 @@ impl Database { ) .await?; + // code_embeddings: compound (repo_id, graph_build_id) + self.code_embeddings() + .create_index( + IndexModel::builder() + .keys(doc! { "repo_id": 1, "graph_build_id": 1 }) + .build(), + ) + .await?; + + // embedding_builds: compound (repo_id, started_at DESC) + self.embedding_builds() + .create_index( + IndexModel::builder() + .keys(doc! { "repo_id": 1, "started_at": -1 }) + .build(), + ) + .await?; + tracing::info!("Database indexes ensured"); Ok(()) } @@ -210,6 +224,17 @@ impl Database { self.inner.collection("dast_findings") } + // Embedding collections + pub fn code_embeddings(&self) -> Collection { + self.inner.collection("code_embeddings") + } + + pub fn embedding_builds( + &self, + ) -> Collection { + self.inner.collection("embedding_builds") + } + #[allow(dead_code)] pub fn raw_collection(&self, name: &str) -> Collection { self.inner.collection(name) diff --git a/compliance-agent/src/llm/client.rs b/compliance-agent/src/llm/client.rs index f9a1653..c7a571d 100644 --- a/compliance-agent/src/llm/client.rs +++ b/compliance-agent/src/llm/client.rs @@ -8,6 +8,7 @@ pub struct LlmClient { base_url: String, api_key: SecretString, model: String, + embed_model: String, http: reqwest::Client, } @@ -42,16 +43,46 @@ struct ChatResponseMessage { content: String, } +/// Request body for the embeddings API +#[derive(Serialize)] +struct EmbeddingRequest { + model: String, + input: Vec, +} + +/// Response from the embeddings API +#[derive(Deserialize)] +struct EmbeddingResponse { + data: Vec, +} + +/// A single embedding result +#[derive(Deserialize)] +struct EmbeddingData { + embedding: Vec, + index: usize, +} + impl LlmClient { - pub fn new(base_url: String, api_key: SecretString, model: String) -> Self { + pub fn new( + base_url: String, + api_key: SecretString, + model: String, + embed_model: String, + ) -> Self { Self { base_url, api_key, model, + embed_model, http: reqwest::Client::new(), } } + pub fn embed_model(&self) -> &str { + &self.embed_model + } + pub async fn chat( &self, system_prompt: &str, @@ -169,4 +200,49 @@ impl LlmClient { .map(|c| c.message.content.clone()) .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string())) } + + /// Generate embeddings for a batch of texts + pub async fn embed(&self, texts: Vec) -> Result>, AgentError> { + let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/')); + + let request_body = EmbeddingRequest { + model: self.embed_model.clone(), + input: texts, + }; + + let mut req = self + .http + .post(&url) + .header("content-type", "application/json") + .json(&request_body); + + let key = self.api_key.expose_secret(); + if !key.is_empty() { + req = req.header("Authorization", format!("Bearer {key}")); + } + + let resp = req + .send() + .await + .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(AgentError::Other(format!( + "Embedding API returned {status}: {body}" + ))); + } + + let body: EmbeddingResponse = resp + .json() + .await + .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?; + + // Sort by index to maintain input order + let mut data = body.data; + data.sort_by_key(|d| d.index); + + Ok(data.into_iter().map(|d| d.embedding).collect()) + } } diff --git a/compliance-agent/src/main.rs b/compliance-agent/src/main.rs index 98cf902..f8518bb 100644 --- a/compliance-agent/src/main.rs +++ b/compliance-agent/src/main.rs @@ -7,6 +7,7 @@ mod database; mod error; mod llm; mod pipeline; +mod rag; mod scheduler; #[allow(dead_code)] mod trackers; diff --git a/compliance-agent/src/rag/mod.rs b/compliance-agent/src/rag/mod.rs new file mode 100644 index 0000000..626c2e4 --- /dev/null +++ b/compliance-agent/src/rag/mod.rs @@ -0,0 +1 @@ +pub mod pipeline; diff --git a/compliance-agent/src/rag/pipeline.rs b/compliance-agent/src/rag/pipeline.rs new file mode 100644 index 0000000..19d5949 --- /dev/null +++ b/compliance-agent/src/rag/pipeline.rs @@ -0,0 +1,164 @@ +use std::path::Path; +use std::sync::Arc; + +use chrono::Utc; +use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; +use compliance_core::models::graph::CodeNode; +use compliance_graph::graph::chunking::extract_chunks; +use compliance_graph::graph::embedding_store::EmbeddingStore; +use tracing::{error, info}; + +use crate::error::AgentError; +use crate::llm::LlmClient; + +/// RAG pipeline for building embeddings and performing retrieval +pub struct RagPipeline { + llm: Arc, + embedding_store: EmbeddingStore, +} + +impl RagPipeline { + pub fn new(llm: Arc, db: &mongodb::Database) -> Self { + Self { + llm, + embedding_store: EmbeddingStore::new(db), + } + } + + pub fn store(&self) -> &EmbeddingStore { + &self.embedding_store + } + + /// Build embeddings for all code nodes in a repository + pub async fn build_embeddings( + &self, + repo_id: &str, + repo_path: &Path, + graph_build_id: &str, + nodes: &[CodeNode], + ) -> Result { + let embed_model = self.llm.embed_model().to_string(); + let mut build = + EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model); + + // Step 1: Extract chunks + let chunks = extract_chunks(repo_path, nodes, 2048); + build.total_chunks = chunks.len() as u32; + info!( + "[{repo_id}] Extracted {} chunks for embedding", + chunks.len() + ); + + // Store the initial build record + self.embedding_store + .store_build(&build) + .await + .map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?; + + if chunks.is_empty() { + build.status = EmbeddingBuildStatus::Completed; + build.completed_at = Some(Utc::now()); + self.embedding_store + .update_build( + repo_id, + graph_build_id, + EmbeddingBuildStatus::Completed, + 0, + None, + ) + .await + .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; + return Ok(build); + } + + // Step 2: Delete old embeddings for this repo + self.embedding_store + .delete_repo_embeddings(repo_id) + .await + .map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?; + + // Step 3: Batch embed (small batches to stay within model limits) + let batch_size = 20; + let mut all_embeddings = Vec::new(); + let mut embedded_count = 0u32; + + for batch_start in (0..chunks.len()).step_by(batch_size) { + let batch_end = (batch_start + batch_size).min(chunks.len()); + let batch_chunks = &chunks[batch_start..batch_end]; + + // Prepare texts: context_header + content + let texts: Vec = batch_chunks + .iter() + .map(|c| format!("{}\n{}", c.context_header, c.content)) + .collect(); + + match self.llm.embed(texts).await { + Ok(vectors) => { + for (chunk, embedding) in batch_chunks.iter().zip(vectors) { + all_embeddings.push(CodeEmbedding { + id: None, + repo_id: repo_id.to_string(), + graph_build_id: graph_build_id.to_string(), + qualified_name: chunk.qualified_name.clone(), + kind: chunk.kind.clone(), + file_path: chunk.file_path.clone(), + start_line: chunk.start_line, + end_line: chunk.end_line, + language: chunk.language.clone(), + content: chunk.content.clone(), + context_header: chunk.context_header.clone(), + embedding, + token_estimate: chunk.token_estimate, + created_at: Utc::now(), + }); + } + embedded_count += batch_chunks.len() as u32; + } + Err(e) => { + error!("[{repo_id}] Embedding batch failed: {e}"); + build.status = EmbeddingBuildStatus::Failed; + build.error_message = Some(e.to_string()); + build.completed_at = Some(Utc::now()); + let _ = self + .embedding_store + .update_build( + repo_id, + graph_build_id, + EmbeddingBuildStatus::Failed, + embedded_count, + Some(e.to_string()), + ) + .await; + return Ok(build); + } + } + } + + // Step 4: Store all embeddings + self.embedding_store + .store_embeddings(&all_embeddings) + .await + .map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?; + + // Step 5: Update build status + build.status = EmbeddingBuildStatus::Completed; + build.embedded_chunks = embedded_count; + build.completed_at = Some(Utc::now()); + self.embedding_store + .update_build( + repo_id, + graph_build_id, + EmbeddingBuildStatus::Completed, + embedded_count, + None, + ) + .await + .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; + + info!( + "[{repo_id}] Embedding build complete: {embedded_count}/{} chunks", + build.total_chunks + ); + Ok(build) + } +} diff --git a/compliance-core/src/config.rs b/compliance-core/src/config.rs index 1c2ffae..3f38740 100644 --- a/compliance-core/src/config.rs +++ b/compliance-core/src/config.rs @@ -8,6 +8,7 @@ pub struct AgentConfig { pub litellm_url: String, pub litellm_api_key: SecretString, pub litellm_model: String, + pub litellm_embed_model: String, pub github_token: Option, pub github_webhook_secret: Option, pub gitlab_url: Option, diff --git a/compliance-core/src/models/chat.rs b/compliance-core/src/models/chat.rs new file mode 100644 index 0000000..a243c92 --- /dev/null +++ b/compliance-core/src/models/chat.rs @@ -0,0 +1,35 @@ +use serde::{Deserialize, Serialize}; + +/// A message in the chat history +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +/// Request body for the chat endpoint +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatRequest { + pub message: String, + #[serde(default)] + pub history: Vec, +} + +/// A source reference from the RAG retrieval +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SourceReference { + pub file_path: String, + pub qualified_name: String, + pub start_line: u32, + pub end_line: u32, + pub language: String, + pub snippet: String, + pub score: f64, +} + +/// Response from the chat endpoint +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatResponse { + pub message: String, + pub sources: Vec, +} diff --git a/compliance-core/src/models/embedding.rs b/compliance-core/src/models/embedding.rs new file mode 100644 index 0000000..60f1f1a --- /dev/null +++ b/compliance-core/src/models/embedding.rs @@ -0,0 +1,100 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +/// Status of an embedding build operation +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum EmbeddingBuildStatus { + Running, + Completed, + Failed, +} + +/// A code embedding stored in MongoDB Atlas Vector Search +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CodeEmbedding { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub repo_id: String, + pub graph_build_id: String, + pub qualified_name: String, + pub kind: String, + pub file_path: String, + pub start_line: u32, + pub end_line: u32, + pub language: String, + pub content: String, + pub context_header: String, + pub embedding: Vec, + pub token_estimate: u32, + #[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")] + pub created_at: DateTime, +} + +/// Tracks an embedding build operation for a repository +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingBuildRun { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub repo_id: String, + pub graph_build_id: String, + pub status: EmbeddingBuildStatus, + pub total_chunks: u32, + pub embedded_chunks: u32, + pub embedding_model: String, + pub error_message: Option, + #[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")] + pub started_at: DateTime, + #[serde( + default, + skip_serializing_if = "Option::is_none", + with = "opt_chrono_as_bson" + )] + pub completed_at: Option>, +} + +impl EmbeddingBuildRun { + pub fn new(repo_id: String, graph_build_id: String, embedding_model: String) -> Self { + Self { + id: None, + repo_id, + graph_build_id, + status: EmbeddingBuildStatus::Running, + total_chunks: 0, + embedded_chunks: 0, + embedding_model, + error_message: None, + started_at: Utc::now(), + completed_at: None, + } + } +} + +/// Serde helper for Option> as BSON DateTime +mod opt_chrono_as_bson { + use chrono::{DateTime, Utc}; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + #[derive(Serialize, Deserialize)] + struct BsonDt( + #[serde(with = "bson::serde_helpers::chrono_datetime_as_bson_datetime")] DateTime, + ); + + pub fn serialize(value: &Option>, serializer: S) -> Result + where + S: Serializer, + { + match value { + Some(dt) => BsonDt(*dt).serialize(serializer), + None => serializer.serialize_none(), + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + where + D: Deserializer<'de>, + { + let opt: Option = Option::deserialize(deserializer)?; + Ok(opt.map(|d| d.0)) + } +} diff --git a/compliance-core/src/models/mod.rs b/compliance-core/src/models/mod.rs index 1a210a5..c695d78 100644 --- a/compliance-core/src/models/mod.rs +++ b/compliance-core/src/models/mod.rs @@ -1,5 +1,7 @@ +pub mod chat; pub mod cve; pub mod dast; +pub mod embedding; pub mod finding; pub mod graph; pub mod issue; @@ -7,15 +9,16 @@ pub mod repository; pub mod sbom; pub mod scan; +pub use chat::{ChatMessage, ChatRequest, ChatResponse, SourceReference}; pub use cve::{CveAlert, CveSource}; pub use dast::{ DastAuthConfig, DastEvidence, DastFinding, DastScanPhase, DastScanRun, DastScanStatus, DastTarget, DastTargetType, DastVulnType, }; +pub use embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; pub use finding::{Finding, FindingStatus, Severity}; pub use graph::{ - CodeEdge, CodeEdgeKind, CodeNode, CodeNodeKind, GraphBuildRun, GraphBuildStatus, - ImpactAnalysis, + CodeEdge, CodeEdgeKind, CodeNode, CodeNodeKind, GraphBuildRun, GraphBuildStatus, ImpactAnalysis, }; pub use issue::{IssueStatus, TrackerIssue, TrackerType}; pub use repository::{ScanTrigger, TrackedRepository}; diff --git a/compliance-dashboard/assets/main.css b/compliance-dashboard/assets/main.css index 37b1fbc..1cba4c1 100644 --- a/compliance-dashboard/assets/main.css +++ b/compliance-dashboard/assets/main.css @@ -1710,3 +1710,240 @@ tbody tr:last-child td { white-space: nowrap; margin-left: auto; } + +/* ── AI Chat ── */ + +.chat-embedding-banner { + display: flex; + align-items: center; + justify-content: space-between; + padding: 12px 20px; + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: var(--radius); + margin-bottom: 16px; + font-size: 13px; + color: var(--text-secondary); +} + +.chat-embedding-banner .btn-sm { + padding: 6px 14px; + font-size: 12px; + background: var(--accent-muted); + color: var(--accent); + border: 1px solid var(--border-accent); + border-radius: var(--radius-sm); + cursor: pointer; + transition: all 0.2s var(--ease-out); +} + +.chat-embedding-banner .btn-sm:hover:not(:disabled) { + background: var(--accent); + color: var(--bg-primary); +} + +.chat-embedding-banner .btn-sm:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.chat-container { + display: flex; + flex-direction: column; + height: calc(100vh - 240px); + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: var(--radius); + overflow: hidden; +} + +.chat-messages { + flex: 1; + overflow-y: auto; + padding: 20px; + display: flex; + flex-direction: column; + gap: 16px; +} + +.chat-empty { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + height: 100%; + color: var(--text-tertiary); + text-align: center; +} + +.chat-empty h3 { + font-family: var(--font-display); + font-size: 18px; + font-weight: 600; + color: var(--text-secondary); + margin-bottom: 8px; +} + +.chat-empty p { + font-size: 13px; + max-width: 400px; +} + +.chat-message { + max-width: 80%; + padding: 12px 16px; + border-radius: var(--radius); + font-size: 14px; + line-height: 1.6; +} + +.chat-message-user { + align-self: flex-end; + background: var(--accent-muted); + border: 1px solid var(--border-accent); + color: var(--text-primary); +} + +.chat-message-assistant { + align-self: flex-start; + background: var(--bg-elevated); + border: 1px solid var(--border); + color: var(--text-primary); +} + +.chat-message-role { + font-family: var(--font-display); + font-size: 11px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + color: var(--text-tertiary); + margin-bottom: 6px; +} + +.chat-message-content { + white-space: pre-wrap; + word-break: break-word; +} + +.chat-typing { + color: var(--text-tertiary); + font-style: italic; +} + +.chat-sources { + margin-top: 12px; + border-top: 1px solid var(--border); + padding-top: 10px; +} + +.chat-sources-label { + font-size: 11px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; + color: var(--text-tertiary); + display: block; + margin-bottom: 8px; +} + +.chat-source-card { + background: var(--bg-secondary); + border: 1px solid var(--border); + border-radius: var(--radius-sm); + padding: 10px 12px; + margin-bottom: 6px; +} + +.chat-source-header { + display: flex; + align-items: center; + justify-content: space-between; + margin-bottom: 6px; +} + +.chat-source-name { + font-family: var(--font-mono); + font-size: 12px; + font-weight: 500; + color: var(--accent); +} + +.chat-source-location { + font-family: var(--font-mono); + font-size: 11px; + color: var(--text-tertiary); +} + +.chat-source-snippet { + margin: 0; + padding: 8px; + background: var(--bg-primary); + border-radius: 4px; + overflow-x: auto; + max-height: 120px; +} + +.chat-source-snippet code { + font-family: var(--font-mono); + font-size: 11px; + color: var(--text-secondary); + white-space: pre; +} + +.chat-input-area { + display: flex; + gap: 10px; + padding: 16px 20px; + border-top: 1px solid var(--border); + background: var(--bg-secondary); +} + +.chat-input { + flex: 1; + background: var(--bg-primary); + border: 1px solid var(--border); + border-radius: var(--radius-sm); + color: var(--text-primary); + font-family: var(--font-body); + font-size: 14px; + padding: 10px 14px; + resize: none; + min-height: 42px; + max-height: 120px; + outline: none; + transition: border-color 0.2s var(--ease-out); +} + +.chat-input:focus { + border-color: var(--accent); +} + +.chat-input:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.chat-send-btn { + padding: 10px 20px; + background: var(--accent); + color: var(--bg-primary); + border: none; + border-radius: var(--radius-sm); + font-family: var(--font-display); + font-weight: 600; + font-size: 13px; + cursor: pointer; + transition: all 0.2s var(--ease-out); + align-self: flex-end; +} + +.chat-send-btn:hover:not(:disabled) { + background: var(--accent-hover); + box-shadow: var(--accent-glow); +} + +.chat-send-btn:disabled { + opacity: 0.5; + cursor: not-allowed; +} diff --git a/compliance-dashboard/src/app.rs b/compliance-dashboard/src/app.rs index 19bad87..08724cf 100644 --- a/compliance-dashboard/src/app.rs +++ b/compliance-dashboard/src/app.rs @@ -26,6 +26,10 @@ pub enum Route { GraphExplorerPage { repo_id: String }, #[route("/graph/:repo_id/impact/:finding_id")] ImpactAnalysisPage { repo_id: String, finding_id: String }, + #[route("/chat")] + ChatIndexPage {}, + #[route("/chat/:repo_id")] + ChatPage { repo_id: String }, #[route("/dast")] DastOverviewPage {}, #[route("/dast/targets")] diff --git a/compliance-dashboard/src/components/sidebar.rs b/compliance-dashboard/src/components/sidebar.rs index 8617190..bb9ab69 100644 --- a/compliance-dashboard/src/components/sidebar.rs +++ b/compliance-dashboard/src/components/sidebar.rs @@ -46,6 +46,11 @@ pub fn Sidebar() -> Element { route: Route::GraphIndexPage {}, icon: rsx! { Icon { icon: BsDiagram3, width: 18, height: 18 } }, }, + NavItem { + label: "AI Chat", + route: Route::ChatIndexPage {}, + icon: rsx! { Icon { icon: BsChatDots, width: 18, height: 18 } }, + }, NavItem { label: "DAST", route: Route::DastOverviewPage {}, @@ -58,7 +63,11 @@ pub fn Sidebar() -> Element { }, ]; - let sidebar_class = if collapsed() { "sidebar collapsed" } else { "sidebar" }; + let sidebar_class = if collapsed() { + "sidebar collapsed" + } else { + "sidebar" + }; rsx! { nav { class: "{sidebar_class}", @@ -76,6 +85,7 @@ pub fn Sidebar() -> Element { (Route::GraphIndexPage {}, Route::GraphIndexPage {}) => true, (Route::GraphExplorerPage { .. }, Route::GraphIndexPage {}) => true, (Route::ImpactAnalysisPage { .. }, Route::GraphIndexPage {}) => true, + (Route::ChatPage { .. }, Route::ChatIndexPage {}) => true, (Route::DastTargetsPage {}, Route::DastOverviewPage {}) => true, (Route::DastFindingsPage {}, Route::DastOverviewPage {}) => true, (Route::DastFindingDetailPage { .. }, Route::DastOverviewPage {}) => true, diff --git a/compliance-dashboard/src/infrastructure/chat.rs b/compliance-dashboard/src/infrastructure/chat.rs new file mode 100644 index 0000000..6dee347 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/chat.rs @@ -0,0 +1,126 @@ +use dioxus::prelude::*; +use serde::{Deserialize, Serialize}; + +// ── Response types ── + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatApiResponse { + pub data: ChatResponseData, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ChatResponseData { + pub message: String, + #[serde(default)] + pub sources: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct SourceRef { + pub file_path: String, + pub qualified_name: String, + pub start_line: u32, + pub end_line: u32, + pub language: String, + pub snippet: String, + pub score: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct EmbeddingStatusResponse { + pub data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct EmbeddingBuildData { + pub repo_id: String, + pub status: String, + pub total_chunks: u32, + pub embedded_chunks: u32, + pub embedding_model: String, + pub error_message: Option, + #[serde(default)] + pub started_at: Option, + #[serde(default)] + pub completed_at: Option, +} + +// ── Chat message history type ── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatHistoryMessage { + pub role: String, + pub content: String, +} + +// ── Server functions ── + +#[server] +pub async fn send_chat_message( + repo_id: String, + message: String, + history: Vec, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + + let url = format!("{}/api/v1/chat/{repo_id}", state.agent_api_url); + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .build() + .map_err(|e| ServerFnError::new(e.to_string()))?; + let resp = client + .post(&url) + .json(&serde_json::json!({ + "message": message, + "history": history, + })) + .send() + .await + .map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?; + + let text = resp + .text() + .await + .map_err(|e| ServerFnError::new(format!("Failed to read response: {e}")))?; + + let body: ChatApiResponse = serde_json::from_str(&text) + .map_err(|e| ServerFnError::new(format!("Failed to parse response: {e} — body: {text}")))?; + Ok(body) +} + +#[server] +pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnError> { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + + let url = format!( + "{}/api/v1/chat/{repo_id}/build-embeddings", + state.agent_api_url + ); + let client = reqwest::Client::new(); + client + .post(&url) + .send() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(()) +} + +#[server] +pub async fn fetch_embedding_status( + repo_id: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + + let url = format!("{}/api/v1/chat/{repo_id}/status", state.agent_api_url); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: EmbeddingStatusResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} diff --git a/compliance-dashboard/src/infrastructure/mod.rs b/compliance-dashboard/src/infrastructure/mod.rs index 0862ecf..9ee2706 100644 --- a/compliance-dashboard/src/infrastructure/mod.rs +++ b/compliance-dashboard/src/infrastructure/mod.rs @@ -1,5 +1,6 @@ // Server function modules (compiled for both web and server; // the #[server] macro generates client stubs for the web target) +pub mod chat; pub mod dast; pub mod findings; pub mod graph; diff --git a/compliance-dashboard/src/pages/chat.rs b/compliance-dashboard/src/pages/chat.rs new file mode 100644 index 0000000..882dffb --- /dev/null +++ b/compliance-dashboard/src/pages/chat.rs @@ -0,0 +1,232 @@ +use dioxus::prelude::*; + +use crate::components::page_header::PageHeader; +use crate::infrastructure::chat::{ + fetch_embedding_status, send_chat_message, trigger_embedding_build, ChatHistoryMessage, + SourceRef, +}; + +/// A UI-level chat message +#[derive(Clone, Debug)] +struct UiChatMessage { + role: String, + content: String, + sources: Vec, +} + +#[component] +pub fn ChatPage(repo_id: String) -> Element { + let mut messages: Signal> = use_signal(Vec::new); + let mut input_text = use_signal(String::new); + let mut loading = use_signal(|| false); + let mut building = use_signal(|| false); + + let repo_id_for_status = repo_id.clone(); + let mut embedding_status = use_resource(move || { + let rid = repo_id_for_status.clone(); + async move { fetch_embedding_status(rid).await.ok() } + }); + + let has_embeddings = { + let status = embedding_status.read(); + match &*status { + Some(Some(resp)) => resp + .data + .as_ref() + .map(|d| d.status == "completed") + .unwrap_or(false), + _ => false, + } + }; + + let embedding_status_text = { + let status = embedding_status.read(); + match &*status { + Some(Some(resp)) => match &resp.data { + Some(d) => match d.status.as_str() { + "completed" => format!( + "Embeddings ready: {}/{} chunks", + d.embedded_chunks, d.total_chunks + ), + "running" => format!( + "Building embeddings: {}/{}...", + d.embedded_chunks, d.total_chunks + ), + "failed" => format!( + "Embedding build failed: {}", + d.error_message.as_deref().unwrap_or("unknown error") + ), + s => format!("Status: {s}"), + }, + None => "No embeddings built yet".to_string(), + }, + Some(None) => "Failed to check embedding status".to_string(), + None => "Checking embedding status...".to_string(), + } + }; + + let repo_id_for_build = repo_id.clone(); + let on_build = move |_| { + let rid = repo_id_for_build.clone(); + building.set(true); + spawn(async move { + let _ = trigger_embedding_build(rid).await; + building.set(false); + embedding_status.restart(); + }); + }; + + let repo_id_for_send = repo_id.clone(); + let mut do_send = move || { + let text = input_text.read().trim().to_string(); + if text.is_empty() || *loading.read() { + return; + } + + let rid = repo_id_for_send.clone(); + let user_msg = text.clone(); + + // Add user message to UI + messages.write().push(UiChatMessage { + role: "user".to_string(), + content: user_msg.clone(), + sources: Vec::new(), + }); + input_text.set(String::new()); + loading.set(true); + + spawn(async move { + // Build history from existing messages + let history: Vec = messages + .read() + .iter() + .filter(|m| m.role == "user" || m.role == "assistant") + .rev() + .skip(1) // skip the message we just added + .take(10) // limit history + .collect::>() + .into_iter() + .rev() + .map(|m| ChatHistoryMessage { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + match send_chat_message(rid, user_msg, history).await { + Ok(resp) => { + messages.write().push(UiChatMessage { + role: "assistant".to_string(), + content: resp.data.message, + sources: resp.data.sources, + }); + } + Err(e) => { + messages.write().push(UiChatMessage { + role: "assistant".to_string(), + content: format!("Error: {e}"), + sources: Vec::new(), + }); + } + } + loading.set(false); + }); + }; + + let mut do_send_click = do_send.clone(); + + rsx! { + PageHeader { title: "AI Chat" } + + // Embedding status banner + div { class: "chat-embedding-banner", + span { "{embedding_status_text}" } + button { + class: "btn btn-sm", + disabled: *building.read(), + onclick: on_build, + if *building.read() { "Building..." } else { "Build Embeddings" } + } + } + + div { class: "chat-container", + // Message list + div { class: "chat-messages", + if messages.read().is_empty() && !*loading.read() { + div { class: "chat-empty", + h3 { "Ask anything about your codebase" } + p { "Build embeddings first, then ask questions about functions, architecture, patterns, and more." } + } + } + for (i, msg) in messages.read().iter().enumerate() { + { + let class = if msg.role == "user" { + "chat-message chat-message-user" + } else { + "chat-message chat-message-assistant" + }; + let content = msg.content.clone(); + let sources = msg.sources.clone(); + rsx! { + div { class: class, key: "{i}", + div { class: "chat-message-role", + if msg.role == "user" { "You" } else { "Assistant" } + } + div { class: "chat-message-content", "{content}" } + if !sources.is_empty() { + div { class: "chat-sources", + span { class: "chat-sources-label", "Sources:" } + for src in sources { + div { class: "chat-source-card", + div { class: "chat-source-header", + span { class: "chat-source-name", + "{src.qualified_name}" + } + span { class: "chat-source-location", + "{src.file_path}:{src.start_line}-{src.end_line}" + } + } + pre { class: "chat-source-snippet", + code { "{src.snippet}" } + } + } + } + } + } + } + } + } + } + if *loading.read() { + div { class: "chat-message chat-message-assistant", + div { class: "chat-message-role", "Assistant" } + div { class: "chat-message-content chat-typing", "Thinking..." } + } + } + } + + // Input area + div { class: "chat-input-area", + textarea { + class: "chat-input", + placeholder: "Ask about your codebase...", + value: "{input_text}", + disabled: !has_embeddings, + oninput: move |e| input_text.set(e.value()), + onkeydown: move |e: Event| { + if e.key() == Key::Enter && !e.modifiers().shift() { + e.prevent_default(); + do_send(); + } + }, + } + button { + class: "btn chat-send-btn", + disabled: *loading.read() || !has_embeddings, + onclick: move |_| do_send_click(), + "Send" + } + } + } + } +} diff --git a/compliance-dashboard/src/pages/chat_index.rs b/compliance-dashboard/src/pages/chat_index.rs new file mode 100644 index 0000000..73ff29a --- /dev/null +++ b/compliance-dashboard/src/pages/chat_index.rs @@ -0,0 +1,71 @@ +use dioxus::prelude::*; + +use crate::app::Route; +use crate::components::page_header::PageHeader; +use crate::infrastructure::chat::fetch_embedding_status; +use crate::infrastructure::repositories::fetch_repositories; + +#[component] +pub fn ChatIndexPage() -> Element { + let repos = use_resource(|| async { fetch_repositories(1).await.ok() }); + + rsx! { + PageHeader { + title: "AI Chat", + description: "Ask questions about your codebase using RAG-augmented AI", + } + + match &*repos.read() { + Some(Some(data)) => { + let repo_list = &data.data; + if repo_list.is_empty() { + rsx! { + div { class: "card", + p { "No repositories found. Add a repository first." } + } + } + } else { + rsx! { + div { class: "graph-index-grid", + for repo in repo_list { + { + let repo_id = repo.id.map(|id| id.to_hex()).unwrap_or_default(); + let name = repo.name.clone(); + let url = repo.git_url.clone(); + let branch = repo.default_branch.clone(); + rsx! { + Link { + to: Route::ChatPage { repo_id }, + class: "graph-repo-card", + div { class: "graph-repo-card-header", + div { class: "graph-repo-card-icon", "\u{1F4AC}" } + h3 { class: "graph-repo-card-name", "{name}" } + } + if !url.is_empty() { + p { class: "graph-repo-card-url", "{url}" } + } + div { class: "graph-repo-card-meta", + span { class: "graph-repo-card-tag", + "\u{E0A0} {branch}" + } + span { class: "graph-repo-card-tag", + "AI Chat" + } + } + } + } + } + } + } + } + } + }, + Some(None) => rsx! { + div { class: "card", p { "Failed to load repositories." } } + }, + None => rsx! { + div { class: "loading", "Loading repositories..." } + }, + } + } +} diff --git a/compliance-dashboard/src/pages/mod.rs b/compliance-dashboard/src/pages/mod.rs index 16b8803..5d14ed5 100644 --- a/compliance-dashboard/src/pages/mod.rs +++ b/compliance-dashboard/src/pages/mod.rs @@ -1,3 +1,5 @@ +pub mod chat; +pub mod chat_index; pub mod dast_finding_detail; pub mod dast_findings; pub mod dast_overview; @@ -13,6 +15,8 @@ pub mod repositories; pub mod sbom; pub mod settings; +pub use chat::ChatPage; +pub use chat_index::ChatIndexPage; pub use dast_finding_detail::DastFindingDetailPage; pub use dast_findings::DastFindingsPage; pub use dast_overview::DastOverviewPage; diff --git a/compliance-graph/src/graph/chunking.rs b/compliance-graph/src/graph/chunking.rs new file mode 100644 index 0000000..ebbc5a0 --- /dev/null +++ b/compliance-graph/src/graph/chunking.rs @@ -0,0 +1,96 @@ +use std::path::Path; + +use compliance_core::models::graph::CodeNode; + +/// A chunk of code extracted from a source file, ready for embedding +#[derive(Debug, Clone)] +pub struct CodeChunk { + pub qualified_name: String, + pub kind: String, + pub file_path: String, + pub start_line: u32, + pub end_line: u32, + pub language: String, + pub content: String, + pub context_header: String, + pub token_estimate: u32, +} + +/// Extract embeddable code chunks from parsed CodeNodes. +/// +/// For each node, reads the corresponding source lines from disk, +/// builds a context header, and estimates tokens. +pub fn extract_chunks( + repo_path: &Path, + nodes: &[CodeNode], + max_chunk_tokens: u32, +) -> Vec { + let mut chunks = Vec::new(); + + for node in nodes { + let file = repo_path.join(&node.file_path); + let source = match std::fs::read_to_string(&file) { + Ok(s) => s, + Err(_) => continue, + }; + + let lines: Vec<&str> = source.lines().collect(); + let start = node.start_line.saturating_sub(1) as usize; + let end = (node.end_line as usize).min(lines.len()); + if start >= end { + continue; + } + + let content: String = lines[start..end].join("\n"); + + // Skip tiny chunks + if content.len() < 50 { + continue; + } + + // Estimate tokens (~4 chars per token) + let mut token_estimate = (content.len() / 4) as u32; + + // Truncate if too large + let final_content = if token_estimate > max_chunk_tokens { + let max_chars = (max_chunk_tokens as usize) * 4; + token_estimate = max_chunk_tokens; + content.chars().take(max_chars).collect() + } else { + content + }; + + // Build context header: file path + containing scope hint + let context_header = build_context_header( + &node.file_path, + &node.qualified_name, + &node.kind.to_string(), + ); + + chunks.push(CodeChunk { + qualified_name: node.qualified_name.clone(), + kind: node.kind.to_string(), + file_path: node.file_path.clone(), + start_line: node.start_line, + end_line: node.end_line, + language: node.language.clone(), + content: final_content, + context_header, + token_estimate, + }); + } + + chunks +} + +fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> String { + // Extract containing module/class from qualified name + // e.g. "src/main.rs::MyStruct::my_method" → parent is "MyStruct" + let parts: Vec<&str> = qualified_name.split("::").collect(); + if parts.len() >= 2 { + let parent = parts[..parts.len() - 1].join("::"); + format!("// {file_path} | {kind} in {parent}") + } else { + format!("// {file_path} | {kind}") + } +} diff --git a/compliance-graph/src/graph/embedding_store.rs b/compliance-graph/src/graph/embedding_store.rs new file mode 100644 index 0000000..689d8af --- /dev/null +++ b/compliance-graph/src/graph/embedding_store.rs @@ -0,0 +1,238 @@ +use compliance_core::error::CoreError; +use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; +use futures_util::TryStreamExt; +use mongodb::bson::doc; +use mongodb::{Collection, Database, IndexModel}; +use tracing::info; + +/// MongoDB persistence layer for code embeddings and vector search +pub struct EmbeddingStore { + embeddings: Collection, + builds: Collection, +} + +impl EmbeddingStore { + pub fn new(db: &Database) -> Self { + Self { + embeddings: db.collection("code_embeddings"), + builds: db.collection("embedding_builds"), + } + } + + /// Create standard indexes. NOTE: The Atlas Vector Search index must be + /// created via the Atlas UI or CLI with the following definition: + /// ```json + /// { + /// "fields": [ + /// { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" }, + /// { "type": "filter", "path": "repo_id" } + /// ] + /// } + /// ``` + pub async fn ensure_indexes(&self) -> Result<(), CoreError> { + self.embeddings + .create_index( + IndexModel::builder() + .keys(doc! { "repo_id": 1, "graph_build_id": 1 }) + .build(), + ) + .await?; + + self.builds + .create_index( + IndexModel::builder() + .keys(doc! { "repo_id": 1, "started_at": -1 }) + .build(), + ) + .await?; + + Ok(()) + } + + /// Delete all embeddings for a repository + pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result { + let result = self + .embeddings + .delete_many(doc! { "repo_id": repo_id }) + .await?; + info!( + "Deleted {} embeddings for repo {repo_id}", + result.deleted_count + ); + Ok(result.deleted_count) + } + + /// Store embeddings in batches of 500 + pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result { + let mut total_inserted = 0u64; + for batch in embeddings.chunks(500) { + let result = self.embeddings.insert_many(batch).await?; + total_inserted += result.inserted_ids.len() as u64; + } + info!("Stored {total_inserted} embeddings"); + Ok(total_inserted) + } + + /// Store a new build run + pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> { + self.builds.insert_one(build).await?; + Ok(()) + } + + /// Update an existing build run + pub async fn update_build( + &self, + repo_id: &str, + graph_build_id: &str, + status: EmbeddingBuildStatus, + embedded_chunks: u32, + error_message: Option, + ) -> Result<(), CoreError> { + let mut update = doc! { + "$set": { + "status": mongodb::bson::to_bson(&status).unwrap_or_default(), + "embedded_chunks": embedded_chunks as i64, + } + }; + + if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed { + update + .get_document_mut("$set") + .unwrap() + .insert("completed_at", mongodb::bson::DateTime::now()); + } + + if let Some(msg) = error_message { + update + .get_document_mut("$set") + .unwrap() + .insert("error_message", msg); + } + + self.builds + .update_one( + doc! { "repo_id": repo_id, "graph_build_id": graph_build_id }, + update, + ) + .await?; + Ok(()) + } + + /// Get the latest embedding build for a repository + pub async fn get_latest_build( + &self, + repo_id: &str, + ) -> Result, CoreError> { + Ok(self + .builds + .find_one(doc! { "repo_id": repo_id }) + .sort(doc! { "started_at": -1 }) + .await?) + } + + /// Perform vector search. Tries Atlas $vectorSearch first, falls back to + /// brute-force cosine similarity for local MongoDB instances. + pub async fn vector_search( + &self, + repo_id: &str, + query_embedding: Vec, + limit: u32, + min_score: f64, + ) -> Result, CoreError> { + match self + .atlas_vector_search(repo_id, &query_embedding, limit, min_score) + .await + { + Ok(results) => Ok(results), + Err(e) => { + info!( + "Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity" + ); + self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score) + .await + } + } + } + + /// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index) + async fn atlas_vector_search( + &self, + repo_id: &str, + query_embedding: &[f64], + limit: u32, + min_score: f64, + ) -> Result, CoreError> { + use mongodb::bson::{Bson, Document}; + + let pipeline = vec![ + doc! { + "$vectorSearch": { + "index": "embedding_vector_index", + "path": "embedding", + "queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::>(), + "numCandidates": (limit * 10) as i64, + "limit": limit as i64, + "filter": { "repo_id": repo_id }, + } + }, + doc! { + "$addFields": { + "search_score": { "$meta": "vectorSearchScore" } + } + }, + doc! { + "$match": { + "search_score": { "$gte": min_score } + } + }, + ]; + + let mut cursor = self.embeddings.aggregate(pipeline).await?; + + let mut results = Vec::new(); + while let Some(doc) = cursor.try_next().await? { + let score = doc.get_f64("search_score").unwrap_or(0.0); + let mut clean_doc: Document = doc; + clean_doc.remove("search_score"); + if let Ok(embedding) = mongodb::bson::from_document::(clean_doc) { + results.push((embedding, score)); + } + } + + Ok(results) + } + + /// Brute-force cosine similarity fallback for local MongoDB without Atlas + async fn bruteforce_vector_search( + &self, + repo_id: &str, + query_embedding: &[f64], + limit: u32, + min_score: f64, + ) -> Result, CoreError> { + let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?; + + let query_norm = dot(query_embedding, query_embedding).sqrt(); + let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new(); + + while let Some(emb) = cursor.try_next().await? { + let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt(); + let score = if query_norm > 0.0 && doc_norm > 0.0 { + dot(query_embedding, &emb.embedding) / (query_norm * doc_norm) + } else { + 0.0 + }; + if score >= min_score { + scored.push((emb, score)); + } + } + + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored.truncate(limit as usize); + Ok(scored) + } +} + +fn dot(a: &[f64], b: &[f64]) -> f64 { + a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() +} diff --git a/compliance-graph/src/graph/mod.rs b/compliance-graph/src/graph/mod.rs index f66238d..caaac75 100644 --- a/compliance-graph/src/graph/mod.rs +++ b/compliance-graph/src/graph/mod.rs @@ -1,4 +1,6 @@ +pub mod chunking; pub mod community; +pub mod embedding_store; pub mod engine; pub mod impact; pub mod persistence;