Files
compliance-scanner-agent/compliance-graph/src/graph/embedding_store.rs
Sharang Parnerkar 42cabf0582
All checks were successful
CI / Format (push) Successful in 2s
CI / Clippy (push) Successful in 2m56s
CI / Security Audit (push) Successful in 1m25s
CI / Tests (push) Successful in 3m57s
feat: rag-embedding-ai-chat (#1)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #1
2026-03-06 21:54:15 +00:00

237 lines
7.6 KiB
Rust

use compliance_core::error::CoreError;
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
use futures_util::TryStreamExt;
use mongodb::bson::doc;
use mongodb::{Collection, Database, IndexModel};
use tracing::info;
/// MongoDB persistence layer for code embeddings and vector search
///
/// Holds typed handles to the two collections this store reads and writes.
pub struct EmbeddingStore {
/// Handle to the collection of [`CodeEmbedding`] documents
/// (bound to `code_embeddings` by `new`).
embeddings: Collection<CodeEmbedding>,
/// Handle to the collection of [`EmbeddingBuildRun`] documents
/// (bound to `embedding_builds` by `new`).
builds: Collection<EmbeddingBuildRun>,
}
impl EmbeddingStore {
/// Build a store on top of `db`, using the `code_embeddings` and
/// `embedding_builds` collections.
pub fn new(db: &Database) -> Self {
    let embeddings = db.collection("code_embeddings");
    let builds = db.collection("embedding_builds");
    Self { embeddings, builds }
}
/// Create the standard (non-vector) indexes used by this store:
///
/// * embeddings: `{ repo_id: 1, graph_build_id: 1 }`
/// * builds: `{ repo_id: 1, started_at: -1 }`
///
/// NOTE: The Atlas Vector Search index must be created via the Atlas UI or
/// CLI. The `$vectorSearch` stage in this file queries it under the name
/// `embedding_vector_index`, and it needs the following definition:
/// ```json
/// {
///   "fields": [
///     { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" },
///     { "type": "filter", "path": "repo_id" }
///   ]
/// }
/// ```
///
/// # Errors
/// Returns the converted driver error if either index creation fails.
pub async fn ensure_indexes(&self) -> Result<(), CoreError> {
    let by_repo_and_graph_build = IndexModel::builder()
        .keys(doc! { "repo_id": 1, "graph_build_id": 1 })
        .build();
    self.embeddings
        .create_index(by_repo_and_graph_build)
        .await?;

    let by_repo_newest_first = IndexModel::builder()
        .keys(doc! { "repo_id": 1, "started_at": -1 })
        .build();
    self.builds.create_index(by_repo_newest_first).await?;

    Ok(())
}
/// Remove every embedding document whose `repo_id` matches.
///
/// Returns the number of documents the server reports as deleted.
///
/// # Errors
/// Returns the converted driver error if the delete fails.
pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result<u64, CoreError> {
    let filter = doc! { "repo_id": repo_id };
    let deleted = self.embeddings.delete_many(filter).await?.deleted_count;
    info!("Deleted {} embeddings for repo {repo_id}", deleted);
    Ok(deleted)
}
/// Insert `embeddings` in batches of 500 documents.
///
/// Returns the total number of inserted documents. An empty slice performs
/// no inserts and returns `0`.
///
/// # Errors
/// Stops at the first failing batch and returns its error; batches inserted
/// before the failure are not rolled back.
pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result<u64, CoreError> {
    // Number of documents sent per `insert_many` call.
    const BATCH_SIZE: usize = 500;

    let mut stored: u64 = 0;
    for batch in embeddings.chunks(BATCH_SIZE) {
        let outcome = self.embeddings.insert_many(batch).await?;
        stored += outcome.inserted_ids.len() as u64;
    }

    info!("Stored {stored} embeddings", stored = stored);
    Ok(stored)
}
/// Persist a new build-run record.
///
/// # Errors
/// Returns the converted driver error if the insert fails.
pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> {
    let _inserted = self.builds.insert_one(build).await?;
    Ok(())
}
/// Update an existing build run
pub async fn update_build(
&self,
repo_id: &str,
graph_build_id: &str,
status: EmbeddingBuildStatus,
embedded_chunks: u32,
error_message: Option<String>,
) -> Result<(), CoreError> {
let mut update = doc! {
"$set": {
"status": mongodb::bson::to_bson(&status).unwrap_or_default(),
"embedded_chunks": embedded_chunks as i64,
}
};
if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed {
if let Ok(set_doc) = update.get_document_mut("$set") {
set_doc.insert("completed_at", mongodb::bson::DateTime::now());
}
}
if let Some(msg) = error_message {
if let Ok(set_doc) = update.get_document_mut("$set") {
set_doc.insert("error_message", msg);
}
}
self.builds
.update_one(
doc! { "repo_id": repo_id, "graph_build_id": graph_build_id },
update,
)
.await?;
Ok(())
}
/// Fetch the build run with the greatest `started_at` for a repository.
///
/// Returns `None` when the repository has no build runs.
///
/// # Errors
/// Returns the converted driver error if the query fails.
pub async fn get_latest_build(
    &self,
    repo_id: &str,
) -> Result<Option<EmbeddingBuildRun>, CoreError> {
    let filter = doc! { "repo_id": repo_id };
    let newest_first = doc! { "started_at": -1 };
    let latest = self.builds.find_one(filter).sort(newest_first).await?;
    Ok(latest)
}
/// Find the embeddings of `repo_id` most similar to `query_embedding`.
///
/// Atlas `$vectorSearch` is attempted first. If that attempt returns any
/// error, the error is logged and the search is retried with a brute-force
/// cosine-similarity scan (for local MongoDB instances without Atlas).
///
/// Returns `(embedding, score)` pairs from whichever strategy succeeded.
///
/// # Errors
/// Returns an error only if the brute-force fallback fails as well.
pub async fn vector_search(
    &self,
    repo_id: &str,
    query_embedding: Vec<f64>,
    limit: u32,
    min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
    let atlas_attempt = self
        .atlas_vector_search(repo_id, &query_embedding, limit, min_score)
        .await;

    // Happy path: Atlas answered, so hand its hits straight back.
    let e = match atlas_attempt {
        Ok(hits) => return Ok(hits),
        Err(e) => e,
    };

    info!(
        "Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity"
    );
    self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score)
        .await
}
/// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index)
async fn atlas_vector_search(
&self,
repo_id: &str,
query_embedding: &[f64],
limit: u32,
min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
use mongodb::bson::{Bson, Document};
let pipeline = vec![
doc! {
"$vectorSearch": {
"index": "embedding_vector_index",
"path": "embedding",
"queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::<Vec<_>>(),
"numCandidates": (limit * 10) as i64,
"limit": limit as i64,
"filter": { "repo_id": repo_id },
}
},
doc! {
"$addFields": {
"search_score": { "$meta": "vectorSearchScore" }
}
},
doc! {
"$match": {
"search_score": { "$gte": min_score }
}
},
];
let mut cursor = self.embeddings.aggregate(pipeline).await?;
let mut results = Vec::new();
while let Some(doc) = cursor.try_next().await? {
let score = doc.get_f64("search_score").unwrap_or(0.0);
let mut clean_doc: Document = doc;
clean_doc.remove("search_score");
if let Ok(embedding) = mongodb::bson::from_document::<CodeEmbedding>(clean_doc) {
results.push((embedding, score));
}
}
Ok(results)
}
/// Brute-force cosine similarity fallback for local MongoDB without Atlas
async fn bruteforce_vector_search(
&self,
repo_id: &str,
query_embedding: &[f64],
limit: u32,
min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?;
let query_norm = dot(query_embedding, query_embedding).sqrt();
let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new();
while let Some(emb) = cursor.try_next().await? {
let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt();
let score = if query_norm > 0.0 && doc_norm > 0.0 {
dot(query_embedding, &emb.embedding) / (query_norm * doc_norm)
} else {
0.0
};
if score >= min_score {
scored.push((emb, score));
}
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit as usize);
Ok(scored)
}
}
/// Dot product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the
/// surplus elements of the longer one are ignored.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs).sum()
}