From 927fbc8ecb1ee476a813642007442b6304e07caa Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Wed, 13 May 2026 10:01:05 +0000 Subject: [PATCH] fix: live progress + concurrency for embedding builds (#80) --- compliance-agent/src/llm/client.rs | 7 +- compliance-agent/src/rag/pipeline.rs | 106 ++++++++++++++++++++++----- 2 files changed, 93 insertions(+), 20 deletions(-) diff --git a/compliance-agent/src/llm/client.rs b/compliance-agent/src/llm/client.rs index d3bfae6..7f9e785 100644 --- a/compliance-agent/src/llm/client.rs +++ b/compliance-agent/src/llm/client.rs @@ -19,12 +19,17 @@ impl LlmClient { model: String, embed_model: String, ) -> Self { + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(300)) + .connect_timeout(std::time::Duration::from_secs(10)) + .build() + .unwrap_or_default(); Self { base_url, api_key, model, embed_model, - http: reqwest::Client::new(), + http, } } diff --git a/compliance-agent/src/rag/pipeline.rs b/compliance-agent/src/rag/pipeline.rs index 19d5949..ff0c5f6 100644 --- a/compliance-agent/src/rag/pipeline.rs +++ b/compliance-agent/src/rag/pipeline.rs @@ -6,11 +6,16 @@ use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, Embed use compliance_core::models::graph::CodeNode; use compliance_graph::graph::chunking::extract_chunks; use compliance_graph::graph::embedding_store::EmbeddingStore; +use futures_util::stream::{FuturesUnordered, StreamExt}; use tracing::{error, info}; use crate::error::AgentError; use crate::llm::LlmClient; +const EMBED_BATCH_SIZE: usize = 20; +const EMBED_CONCURRENCY: usize = 4; +const EMBED_FLUSH_EVERY: usize = 200; + /// RAG pipeline for building embeddings and performing retrieval pub struct RagPipeline { llm: Arc, @@ -77,25 +82,33 @@ impl RagPipeline { .await .map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?; - // Step 3: Batch embed (small batches to stay within model limits) - let batch_size = 20; - let mut all_embeddings 
= Vec::new(); + // Step 3: Batch embed with bounded concurrency. Flush to Mongo and + // update progress periodically so the dashboard can show live status. + let mut pending = Vec::with_capacity(EMBED_FLUSH_EVERY); let mut embedded_count = 0u32; - for batch_start in (0..chunks.len()).step_by(batch_size) { - let batch_end = (batch_start + batch_size).min(chunks.len()); - let batch_chunks = &chunks[batch_start..batch_end]; + // Build the list of batch indices to process. + let batches: Vec<(usize, usize)> = (0..chunks.len()) + .step_by(EMBED_BATCH_SIZE) + .map(|start| (start, (start + EMBED_BATCH_SIZE).min(chunks.len()))) + .collect(); - // Prepare texts: context_header + content - let texts: Vec = batch_chunks - .iter() - .map(|c| format!("{}\n{}", c.context_header, c.content)) - .collect(); + let mut batch_iter = batches.into_iter(); + let mut in_flight = FuturesUnordered::new(); - match self.llm.embed(texts).await { - Ok(vectors) => { + // Prime up to EMBED_CONCURRENCY batches. + for _ in 0..EMBED_CONCURRENCY { + if let Some((start, end)) = batch_iter.next() { + in_flight.push(self.embed_batch(&chunks[start..end], start, end)); + } + } + + while let Some(result) = in_flight.next().await { + match result { + Ok((start, end, vectors)) => { + let batch_chunks = &chunks[start..end]; for (chunk, embedding) in batch_chunks.iter().zip(vectors) { - all_embeddings.push(CodeEmbedding { + pending.push(CodeEmbedding { id: None, repo_id: repo_id.to_string(), graph_build_id: graph_build_id.to_string(), @@ -113,9 +126,45 @@ impl RagPipeline { }); } embedded_count += batch_chunks.len() as u32; + + // Flush pending embeddings to Mongo periodically and update progress. 
+ if pending.len() >= EMBED_FLUSH_EVERY { + self.embedding_store + .store_embeddings(&pending) + .await + .map_err(|e| { + AgentError::Other(format!("Failed to store embeddings: {e}")) + })?; + pending.clear(); + } + + // Always update the progress counter on the build doc — even if + // we haven't flushed embeddings yet — so the UI shows movement. + if let Err(e) = self + .embedding_store + .update_build( + repo_id, + graph_build_id, + EmbeddingBuildStatus::Running, + embedded_count, + None, + ) + .await + { + error!("[{repo_id}] Failed to update build progress: {e}"); + } + + // Queue the next batch to keep concurrency saturated. + if let Some((s, e)) = batch_iter.next() { + in_flight.push(self.embed_batch(&chunks[s..e], s, e)); + } } Err(e) => { error!("[{repo_id}] Embedding batch failed: {e}"); + // Flush whatever we have so partial progress isn't lost. + if !pending.is_empty() { + let _ = self.embedding_store.store_embeddings(&pending).await; + } build.status = EmbeddingBuildStatus::Failed; build.error_message = Some(e.to_string()); build.completed_at = Some(Utc::now()); @@ -134,11 +183,13 @@ impl RagPipeline { } } - // Step 4: Store all embeddings - self.embedding_store - .store_embeddings(&all_embeddings) - .await - .map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?; + // Step 4: Flush any remaining embeddings + if !pending.is_empty() { + self.embedding_store + .store_embeddings(&pending) + .await + .map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?; + } // Step 5: Update build status build.status = EmbeddingBuildStatus::Completed; @@ -161,4 +212,21 @@ impl RagPipeline { ); Ok(build) } + + /// Embed one batch of chunks. Returns the (start, end, vectors) tuple so + /// out-of-order completion from `FuturesUnordered` can still be reconciled + /// against the original chunk slice. 
+    async fn embed_batch(
+        &self,
+        batch_chunks: &[compliance_graph::graph::chunking::CodeChunk], // call sites in this patch pass `&chunks[start..end]`
+        start: usize, // not used for slicing here; echoed back in the result
+        end: usize, // not used for slicing here; echoed back in the result
+    ) -> Result<(usize, usize, Vec>), AgentError> { // NOTE(review): `Vec>` looks like stripped generic arguments — restore from the upstream commit
+        let texts: Vec = batch_chunks // NOTE(review): annotation also looks stripped, presumably `Vec<String>` — confirm
+            .iter()
+            .map(|c| format!("{}\n{}", c.context_header, c.content)) // one text per chunk: context header, newline, content
+            .collect();
+        let vectors = self.llm.embed(texts).await?; // an embed error returns early via `?`
+        Ok((start, end, vectors)) // range travels with the vectors so the caller can re-slice `chunks`
+    }
 }