Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #1
165 lines
5.9 KiB
Rust
165 lines
5.9 KiB
Rust
use std::path::Path;
|
|
use std::sync::Arc;
|
|
|
|
use chrono::Utc;
|
|
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
|
use compliance_core::models::graph::CodeNode;
|
|
use compliance_graph::graph::chunking::extract_chunks;
|
|
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
|
use tracing::{error, info};
|
|
|
|
use crate::error::AgentError;
|
|
use crate::llm::LlmClient;
|
|
|
|
/// RAG pipeline for building embeddings and performing retrieval
|
|
pub struct RagPipeline {
|
|
llm: Arc<LlmClient>,
|
|
embedding_store: EmbeddingStore,
|
|
}
|
|
|
|
impl RagPipeline {
|
|
pub fn new(llm: Arc<LlmClient>, db: &mongodb::Database) -> Self {
|
|
Self {
|
|
llm,
|
|
embedding_store: EmbeddingStore::new(db),
|
|
}
|
|
}
|
|
|
|
pub fn store(&self) -> &EmbeddingStore {
|
|
&self.embedding_store
|
|
}
|
|
|
|
/// Build embeddings for all code nodes in a repository
|
|
pub async fn build_embeddings(
|
|
&self,
|
|
repo_id: &str,
|
|
repo_path: &Path,
|
|
graph_build_id: &str,
|
|
nodes: &[CodeNode],
|
|
) -> Result<EmbeddingBuildRun, AgentError> {
|
|
let embed_model = self.llm.embed_model().to_string();
|
|
let mut build =
|
|
EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model);
|
|
|
|
// Step 1: Extract chunks
|
|
let chunks = extract_chunks(repo_path, nodes, 2048);
|
|
build.total_chunks = chunks.len() as u32;
|
|
info!(
|
|
"[{repo_id}] Extracted {} chunks for embedding",
|
|
chunks.len()
|
|
);
|
|
|
|
// Store the initial build record
|
|
self.embedding_store
|
|
.store_build(&build)
|
|
.await
|
|
.map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?;
|
|
|
|
if chunks.is_empty() {
|
|
build.status = EmbeddingBuildStatus::Completed;
|
|
build.completed_at = Some(Utc::now());
|
|
self.embedding_store
|
|
.update_build(
|
|
repo_id,
|
|
graph_build_id,
|
|
EmbeddingBuildStatus::Completed,
|
|
0,
|
|
None,
|
|
)
|
|
.await
|
|
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
|
return Ok(build);
|
|
}
|
|
|
|
// Step 2: Delete old embeddings for this repo
|
|
self.embedding_store
|
|
.delete_repo_embeddings(repo_id)
|
|
.await
|
|
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
|
|
|
|
// Step 3: Batch embed (small batches to stay within model limits)
|
|
let batch_size = 20;
|
|
let mut all_embeddings = Vec::new();
|
|
let mut embedded_count = 0u32;
|
|
|
|
for batch_start in (0..chunks.len()).step_by(batch_size) {
|
|
let batch_end = (batch_start + batch_size).min(chunks.len());
|
|
let batch_chunks = &chunks[batch_start..batch_end];
|
|
|
|
// Prepare texts: context_header + content
|
|
let texts: Vec<String> = batch_chunks
|
|
.iter()
|
|
.map(|c| format!("{}\n{}", c.context_header, c.content))
|
|
.collect();
|
|
|
|
match self.llm.embed(texts).await {
|
|
Ok(vectors) => {
|
|
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
|
|
all_embeddings.push(CodeEmbedding {
|
|
id: None,
|
|
repo_id: repo_id.to_string(),
|
|
graph_build_id: graph_build_id.to_string(),
|
|
qualified_name: chunk.qualified_name.clone(),
|
|
kind: chunk.kind.clone(),
|
|
file_path: chunk.file_path.clone(),
|
|
start_line: chunk.start_line,
|
|
end_line: chunk.end_line,
|
|
language: chunk.language.clone(),
|
|
content: chunk.content.clone(),
|
|
context_header: chunk.context_header.clone(),
|
|
embedding,
|
|
token_estimate: chunk.token_estimate,
|
|
created_at: Utc::now(),
|
|
});
|
|
}
|
|
embedded_count += batch_chunks.len() as u32;
|
|
}
|
|
Err(e) => {
|
|
error!("[{repo_id}] Embedding batch failed: {e}");
|
|
build.status = EmbeddingBuildStatus::Failed;
|
|
build.error_message = Some(e.to_string());
|
|
build.completed_at = Some(Utc::now());
|
|
let _ = self
|
|
.embedding_store
|
|
.update_build(
|
|
repo_id,
|
|
graph_build_id,
|
|
EmbeddingBuildStatus::Failed,
|
|
embedded_count,
|
|
Some(e.to_string()),
|
|
)
|
|
.await;
|
|
return Ok(build);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Step 4: Store all embeddings
|
|
self.embedding_store
|
|
.store_embeddings(&all_embeddings)
|
|
.await
|
|
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
|
|
|
|
// Step 5: Update build status
|
|
build.status = EmbeddingBuildStatus::Completed;
|
|
build.embedded_chunks = embedded_count;
|
|
build.completed_at = Some(Utc::now());
|
|
self.embedding_store
|
|
.update_build(
|
|
repo_id,
|
|
graph_build_id,
|
|
EmbeddingBuildStatus::Completed,
|
|
embedded_count,
|
|
None,
|
|
)
|
|
.await
|
|
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
|
|
|
info!(
|
|
"[{repo_id}] Embedding build complete: {embedded_count}/{} chunks",
|
|
build.total_chunks
|
|
);
|
|
Ok(build)
|
|
}
|
|
}
|