use compliance_core::error::CoreError; use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus}; use futures_util::TryStreamExt; use mongodb::bson::doc; use mongodb::{Collection, Database, IndexModel}; use tracing::info; /// MongoDB persistence layer for code embeddings and vector search pub struct EmbeddingStore { embeddings: Collection, builds: Collection, } impl EmbeddingStore { pub fn new(db: &Database) -> Self { Self { embeddings: db.collection("code_embeddings"), builds: db.collection("embedding_builds"), } } /// Create standard indexes. NOTE: The Atlas Vector Search index must be /// created via the Atlas UI or CLI with the following definition: /// ```json /// { /// "fields": [ /// { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" }, /// { "type": "filter", "path": "repo_id" } /// ] /// } /// ``` pub async fn ensure_indexes(&self) -> Result<(), CoreError> { self.embeddings .create_index( IndexModel::builder() .keys(doc! { "repo_id": 1, "graph_build_id": 1 }) .build(), ) .await?; self.builds .create_index( IndexModel::builder() .keys(doc! { "repo_id": 1, "started_at": -1 }) .build(), ) .await?; Ok(()) } /// Delete all embeddings for a repository pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result { let result = self .embeddings .delete_many(doc! { "repo_id": repo_id }) .await?; info!( "Deleted {} embeddings for repo {repo_id}", result.deleted_count ); Ok(result.deleted_count) } /// Store embeddings in batches of 500 pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result { let mut total_inserted = 0u64; for batch in embeddings.chunks(500) { let result = self.embeddings.insert_many(batch).await?; total_inserted += result.inserted_ids.len() as u64; } info!("Stored {total_inserted} embeddings"); Ok(total_inserted) } /// Store a new build run pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> { self.builds.insert_one(build).await?; Ok(()) } /// Update an existing build run pub async fn update_build( &self, repo_id: &str, graph_build_id: &str, status: EmbeddingBuildStatus, embedded_chunks: u32, error_message: Option, ) -> Result<(), CoreError> { let mut update = doc! { "$set": { "status": mongodb::bson::to_bson(&status).unwrap_or_default(), "embedded_chunks": embedded_chunks as i64, } }; if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed { if let Ok(set_doc) = update.get_document_mut("$set") { set_doc.insert("completed_at", mongodb::bson::DateTime::now()); } } if let Some(msg) = error_message { if let Ok(set_doc) = update.get_document_mut("$set") { set_doc.insert("error_message", msg); } } self.builds .update_one( doc! { "repo_id": repo_id, "graph_build_id": graph_build_id }, update, ) .await?; Ok(()) } /// Get the latest embedding build for a repository pub async fn get_latest_build( &self, repo_id: &str, ) -> Result, CoreError> { Ok(self .builds .find_one(doc! { "repo_id": repo_id }) .sort(doc! { "started_at": -1 }) .await?) } /// Perform vector search. Tries Atlas $vectorSearch first, falls back to /// brute-force cosine similarity for local MongoDB instances. pub async fn vector_search( &self, repo_id: &str, query_embedding: Vec, limit: u32, min_score: f64, ) -> Result, CoreError> { match self .atlas_vector_search(repo_id, &query_embedding, limit, min_score) .await { Ok(results) => Ok(results), Err(e) => { info!( "Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity" ); self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score) .await } } } /// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index) async fn atlas_vector_search( &self, repo_id: &str, query_embedding: &[f64], limit: u32, min_score: f64, ) -> Result, CoreError> { use mongodb::bson::{Bson, Document}; let pipeline = vec![ doc! { "$vectorSearch": { "index": "embedding_vector_index", "path": "embedding", "queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::>(), "numCandidates": (limit * 10) as i64, "limit": limit as i64, "filter": { "repo_id": repo_id }, } }, doc! { "$addFields": { "search_score": { "$meta": "vectorSearchScore" } } }, doc! { "$match": { "search_score": { "$gte": min_score } } }, ]; let mut cursor = self.embeddings.aggregate(pipeline).await?; let mut results = Vec::new(); while let Some(doc) = cursor.try_next().await? { let score = doc.get_f64("search_score").unwrap_or(0.0); let mut clean_doc: Document = doc; clean_doc.remove("search_score"); if let Ok(embedding) = mongodb::bson::from_document::(clean_doc) { results.push((embedding, score)); } } Ok(results) } /// Brute-force cosine similarity fallback for local MongoDB without Atlas async fn bruteforce_vector_search( &self, repo_id: &str, query_embedding: &[f64], limit: u32, min_score: f64, ) -> Result, CoreError> { let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?; let query_norm = dot(query_embedding, query_embedding).sqrt(); let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new(); while let Some(emb) = cursor.try_next().await? { let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt(); let score = if query_norm > 0.0 && doc_norm > 0.0 { dot(query_embedding, &emb.embedding) / (query_norm * doc_norm) } else { 0.0 }; if score >= min_score { scored.push((emb, score)); } } scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); scored.truncate(limit as usize); Ok(scored) } } fn dot(a: &[f64], b: &[f64]) -> f64 { a.iter().zip(b.iter()).map(|(x, y)| x * y).sum() }