use std::path::Path; use std::sync::Arc; use chrono::Utc; use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; use compliance_core::models::graph::CodeNode; use compliance_graph::graph::chunking::extract_chunks; use compliance_graph::graph::embedding_store::EmbeddingStore; use futures_util::stream::{FuturesUnordered, StreamExt}; use tracing::{error, info}; use crate::error::AgentError; use crate::llm::LlmClient; const EMBED_BATCH_SIZE: usize = 20; const EMBED_CONCURRENCY: usize = 4; const EMBED_FLUSH_EVERY: usize = 200; /// RAG pipeline for building embeddings and performing retrieval pub struct RagPipeline { llm: Arc, embedding_store: EmbeddingStore, } impl RagPipeline { pub fn new(llm: Arc, db: &mongodb::Database) -> Self { Self { llm, embedding_store: EmbeddingStore::new(db), } } pub fn store(&self) -> &EmbeddingStore { &self.embedding_store } /// Build embeddings for all code nodes in a repository pub async fn build_embeddings( &self, repo_id: &str, repo_path: &Path, graph_build_id: &str, nodes: &[CodeNode], ) -> Result { let embed_model = self.llm.embed_model().to_string(); let mut build = EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model); // Step 1: Extract chunks let chunks = extract_chunks(repo_path, nodes, 2048); build.total_chunks = chunks.len() as u32; info!( "[{repo_id}] Extracted {} chunks for embedding", chunks.len() ); // Store the initial build record self.embedding_store .store_build(&build) .await .map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?; if chunks.is_empty() { build.status = EmbeddingBuildStatus::Completed; build.completed_at = Some(Utc::now()); self.embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Completed, 0, None, ) .await .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; return Ok(build); } // Step 2: Delete old embeddings for this repo self.embedding_store .delete_repo_embeddings(repo_id) .await .map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?; // Step 3: Batch embed with bounded concurrency. Flush to Mongo and // update progress periodically so the dashboard can show live status. let mut pending = Vec::with_capacity(EMBED_FLUSH_EVERY); let mut embedded_count = 0u32; // Build the list of batch indices to process. let batches: Vec<(usize, usize)> = (0..chunks.len()) .step_by(EMBED_BATCH_SIZE) .map(|start| (start, (start + EMBED_BATCH_SIZE).min(chunks.len()))) .collect(); let mut batch_iter = batches.into_iter(); let mut in_flight = FuturesUnordered::new(); // Prime up to EMBED_CONCURRENCY batches. for _ in 0..EMBED_CONCURRENCY { if let Some((start, end)) = batch_iter.next() { in_flight.push(self.embed_batch(&chunks[start..end], start, end)); } } while let Some(result) = in_flight.next().await { match result { Ok((start, end, vectors)) => { let batch_chunks = &chunks[start..end]; for (chunk, embedding) in batch_chunks.iter().zip(vectors) { pending.push(CodeEmbedding { id: None, repo_id: repo_id.to_string(), graph_build_id: graph_build_id.to_string(), qualified_name: chunk.qualified_name.clone(), kind: chunk.kind.clone(), file_path: chunk.file_path.clone(), start_line: chunk.start_line, end_line: chunk.end_line, language: chunk.language.clone(), content: chunk.content.clone(), context_header: chunk.context_header.clone(), embedding, token_estimate: chunk.token_estimate, created_at: Utc::now(), }); } embedded_count += batch_chunks.len() as u32; // Flush pending embeddings to Mongo periodically and update progress. if pending.len() >= EMBED_FLUSH_EVERY { self.embedding_store .store_embeddings(&pending) .await .map_err(|e| { AgentError::Other(format!("Failed to store embeddings: {e}")) })?; pending.clear(); } // Always update the progress counter on the build doc — even if // we haven't flushed embeddings yet — so the UI shows movement. if let Err(e) = self .embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Running, embedded_count, None, ) .await { error!("[{repo_id}] Failed to update build progress: {e}"); } // Queue the next batch to keep concurrency saturated. if let Some((s, e)) = batch_iter.next() { in_flight.push(self.embed_batch(&chunks[s..e], s, e)); } } Err(e) => { error!("[{repo_id}] Embedding batch failed: {e}"); // Flush whatever we have so partial progress isn't lost. if !pending.is_empty() { let _ = self.embedding_store.store_embeddings(&pending).await; } build.status = EmbeddingBuildStatus::Failed; build.error_message = Some(e.to_string()); build.completed_at = Some(Utc::now()); let _ = self .embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Failed, embedded_count, Some(e.to_string()), ) .await; return Ok(build); } } } // Step 4: Flush any remaining embeddings if !pending.is_empty() { self.embedding_store .store_embeddings(&pending) .await .map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?; } // Step 5: Update build status build.status = EmbeddingBuildStatus::Completed; build.embedded_chunks = embedded_count; build.completed_at = Some(Utc::now()); self.embedding_store .update_build( repo_id, graph_build_id, EmbeddingBuildStatus::Completed, embedded_count, None, ) .await .map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?; info!( "[{repo_id}] Embedding build complete: {embedded_count}/{} chunks", build.total_chunks ); Ok(build) } /// Embed one batch of chunks. Returns the (start, end, vectors) tuple so /// out-of-order completion from `FuturesUnordered` can still be reconciled /// against the original chunk slice. async fn embed_batch( &self, batch_chunks: &[compliance_graph::graph::chunking::CodeChunk], start: usize, end: usize, ) -> Result<(usize, usize, Vec>), AgentError> { let texts: Vec = batch_chunks .iter() .map(|c| format!("{}\n{}", c.context_header, c.content)) .collect(); let vectors = self.llm.embed(texts).await?; Ok((start, end, vectors)) } }