Files
compliance-scanner-agent/compliance-agent/src/rag/pipeline.rs
T
sharang 927fbc8ecb
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 5s
CI / Deploy Agent (push) Successful in 7m59s
CI / Deploy Dashboard (push) Has been skipped
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Has been skipped
fix: live progress + concurrency for embedding builds (#80)
2026-05-13 10:01:05 +00:00

233 lines
8.8 KiB
Rust

use std::path::Path;
use std::sync::Arc;
use chrono::Utc;
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
use compliance_core::models::graph::CodeNode;
use compliance_graph::graph::chunking::extract_chunks;
use compliance_graph::graph::embedding_store::EmbeddingStore;
use futures_util::stream::{FuturesUnordered, StreamExt};
use tracing::{error, info};
use crate::error::AgentError;
use crate::llm::LlmClient;
/// Number of chunks grouped into a single embedding batch (one `embed_batch` call).
const EMBED_BATCH_SIZE: usize = 20;
/// Number of embedding batches started up front; each completed batch queues
/// the next one, so at most this many are in flight at a time.
const EMBED_CONCURRENCY: usize = 4;
/// Once at least this many embeddings are pending, they are flushed to the
/// embedding store and the pending buffer is cleared.
const EMBED_FLUSH_EVERY: usize = 200;
/// RAG pipeline for building embeddings and performing retrieval
pub struct RagPipeline {
/// Shared LLM client; the pipeline reads its embedding model name and calls
/// its `embed` method to turn chunk text into vectors.
llm: Arc<LlmClient>,
/// Store used to persist embeddings and embedding build records; created
/// from the MongoDB database handle passed to `new`.
embedding_store: EmbeddingStore,
}
impl RagPipeline {
pub fn new(llm: Arc<LlmClient>, db: &mongodb::Database) -> Self {
Self {
llm,
embedding_store: EmbeddingStore::new(db),
}
}
pub fn store(&self) -> &EmbeddingStore {
&self.embedding_store
}
/// Build embeddings for all code nodes in a repository
pub async fn build_embeddings(
&self,
repo_id: &str,
repo_path: &Path,
graph_build_id: &str,
nodes: &[CodeNode],
) -> Result<EmbeddingBuildRun, AgentError> {
let embed_model = self.llm.embed_model().to_string();
let mut build =
EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model);
// Step 1: Extract chunks
let chunks = extract_chunks(repo_path, nodes, 2048);
build.total_chunks = chunks.len() as u32;
info!(
"[{repo_id}] Extracted {} chunks for embedding",
chunks.len()
);
// Store the initial build record
self.embedding_store
.store_build(&build)
.await
.map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?;
if chunks.is_empty() {
build.status = EmbeddingBuildStatus::Completed;
build.completed_at = Some(Utc::now());
self.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Completed,
0,
None,
)
.await
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
return Ok(build);
}
// Step 2: Delete old embeddings for this repo
self.embedding_store
.delete_repo_embeddings(repo_id)
.await
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
// Step 3: Batch embed with bounded concurrency. Flush to Mongo and
// update progress periodically so the dashboard can show live status.
let mut pending = Vec::with_capacity(EMBED_FLUSH_EVERY);
let mut embedded_count = 0u32;
// Build the list of batch indices to process.
let batches: Vec<(usize, usize)> = (0..chunks.len())
.step_by(EMBED_BATCH_SIZE)
.map(|start| (start, (start + EMBED_BATCH_SIZE).min(chunks.len())))
.collect();
let mut batch_iter = batches.into_iter();
let mut in_flight = FuturesUnordered::new();
// Prime up to EMBED_CONCURRENCY batches.
for _ in 0..EMBED_CONCURRENCY {
if let Some((start, end)) = batch_iter.next() {
in_flight.push(self.embed_batch(&chunks[start..end], start, end));
}
}
while let Some(result) = in_flight.next().await {
match result {
Ok((start, end, vectors)) => {
let batch_chunks = &chunks[start..end];
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
pending.push(CodeEmbedding {
id: None,
repo_id: repo_id.to_string(),
graph_build_id: graph_build_id.to_string(),
qualified_name: chunk.qualified_name.clone(),
kind: chunk.kind.clone(),
file_path: chunk.file_path.clone(),
start_line: chunk.start_line,
end_line: chunk.end_line,
language: chunk.language.clone(),
content: chunk.content.clone(),
context_header: chunk.context_header.clone(),
embedding,
token_estimate: chunk.token_estimate,
created_at: Utc::now(),
});
}
embedded_count += batch_chunks.len() as u32;
// Flush pending embeddings to Mongo periodically and update progress.
if pending.len() >= EMBED_FLUSH_EVERY {
self.embedding_store
.store_embeddings(&pending)
.await
.map_err(|e| {
AgentError::Other(format!("Failed to store embeddings: {e}"))
})?;
pending.clear();
}
// Always update the progress counter on the build doc — even if
// we haven't flushed embeddings yet — so the UI shows movement.
if let Err(e) = self
.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Running,
embedded_count,
None,
)
.await
{
error!("[{repo_id}] Failed to update build progress: {e}");
}
// Queue the next batch to keep concurrency saturated.
if let Some((s, e)) = batch_iter.next() {
in_flight.push(self.embed_batch(&chunks[s..e], s, e));
}
}
Err(e) => {
error!("[{repo_id}] Embedding batch failed: {e}");
// Flush whatever we have so partial progress isn't lost.
if !pending.is_empty() {
let _ = self.embedding_store.store_embeddings(&pending).await;
}
build.status = EmbeddingBuildStatus::Failed;
build.error_message = Some(e.to_string());
build.completed_at = Some(Utc::now());
let _ = self
.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Failed,
embedded_count,
Some(e.to_string()),
)
.await;
return Ok(build);
}
}
}
// Step 4: Flush any remaining embeddings
if !pending.is_empty() {
self.embedding_store
.store_embeddings(&pending)
.await
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
}
// Step 5: Update build status
build.status = EmbeddingBuildStatus::Completed;
build.embedded_chunks = embedded_count;
build.completed_at = Some(Utc::now());
self.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Completed,
embedded_count,
None,
)
.await
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
info!(
"[{repo_id}] Embedding build complete: {embedded_count}/{} chunks",
build.total_chunks
);
Ok(build)
}
/// Embed one batch of chunks. Returns the (start, end, vectors) tuple so
/// out-of-order completion from `FuturesUnordered` can still be reconciled
/// against the original chunk slice.
async fn embed_batch(
&self,
batch_chunks: &[compliance_graph::graph::chunking::CodeChunk],
start: usize,
end: usize,
) -> Result<(usize, usize, Vec<Vec<f64>>), AgentError> {
let texts: Vec<String> = batch_chunks
.iter()
.map(|c| format!("{}\n{}", c.context_header, c.content))
.collect();
let vectors = self.llm.embed(texts).await?;
Ok((start, end, vectors))
}
}