use std::path::Path; use std::sync::Arc; use chrono::Utc; use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; use compliance_core::models::graph::CodeNode; use compliance_graph::graph::chunking::extract_chunks; use compliance_graph::graph::embedding_store::EmbeddingStore; use tracing::{error, info}; use crate::error::AgentError; use crate::llm::LlmClient; /// RAG pipeline for building embeddings and performing retrieval pub struct RagPipeline { llm: Arc, embedding_store: EmbeddingStore, } impl RagPipeline { pub fn new(llm: Arc, db: &mongodb::Database) -> Self { Self { llm, embedding_store: EmbeddingStore::new(db), } } pub fn store(&self) -> &EmbeddingStore { &self.embedding_store } /// Build embeddings for all code nodes in a repository pub async fn build_embeddings( &self, repo_id: &str, repo_path: &Path, graph_build_id: &str, nodes: &[CodeNode], ) -> Result { let embed_model = self.llm.embed_model().to_string(); let mut build = EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model); // Step 1: Extract chunks let chunks = extract_chunks(repo_path, nodes, 2048); build.total_chunks = chunks.len() as u32; info!( "[{repo_id}] Extracted {} chunks for embedding", chunks.len() ); // Store the initial build record self.embedding_store .store_build(&build) .await .map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?; if chunks.is_empty() { build.status = EmbeddingBuildStatus::Completed; build.completed_at = Some(Utc::now()); self.embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Completed, 0, None, ) .await .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; return Ok(build); } // Step 2: Delete old embeddings for this repo self.embedding_store .delete_repo_embeddings(repo_id) .await .map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?; // Step 3: Batch embed (small batches to stay within model limits) let batch_size = 20; let mut all_embeddings = Vec::new(); let mut embedded_count = 0u32; for batch_start in (0..chunks.len()).step_by(batch_size) { let batch_end = (batch_start + batch_size).min(chunks.len()); let batch_chunks = &chunks[batch_start..batch_end]; // Prepare texts: context_header + content let texts: Vec = batch_chunks .iter() .map(|c| format!("{}\n{}", c.context_header, c.content)) .collect(); match self.llm.embed(texts).await { Ok(vectors) => { for (chunk, embedding) in batch_chunks.iter().zip(vectors) { all_embeddings.push(CodeEmbedding { id: None, repo_id: repo_id.to_string(), graph_build_id: graph_build_id.to_string(), qualified_name: chunk.qualified_name.clone(), kind: chunk.kind.clone(), file_path: chunk.file_path.clone(), start_line: chunk.start_line, end_line: chunk.end_line, language: chunk.language.clone(), content: chunk.content.clone(), context_header: chunk.context_header.clone(), embedding, token_estimate: chunk.token_estimate, created_at: Utc::now(), }); } embedded_count += batch_chunks.len() as u32; } Err(e) => { error!("[{repo_id}] Embedding batch failed: {e}"); build.status = EmbeddingBuildStatus::Failed; build.error_message = Some(e.to_string()); build.completed_at = Some(Utc::now()); let _ = self .embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Failed, embedded_count, Some(e.to_string()), ) .await; return Ok(build); } } } // Step 4: Store all embeddings self.embedding_store .store_embeddings(&all_embeddings) .await .map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?; // Step 5: Update build status build.status = EmbeddingBuildStatus::Completed; build.embedded_chunks = embedded_count; build.completed_at = Some(Utc::now()); self.embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Completed, embedded_count, None, ) .await .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; info!( "[{repo_id}] Embedding build complete: {embedded_count}/{} chunks", build.total_chunks ); Ok(build) } }