feat: rag-embedding-ai-chat (#1)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
@@ -0,0 +1,96 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::graph::CodeNode;
|
||||
|
||||
/// A chunk of code extracted from a source file, ready for embedding
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodeChunk {
|
||||
pub qualified_name: String,
|
||||
pub kind: String,
|
||||
pub file_path: String,
|
||||
pub start_line: u32,
|
||||
pub end_line: u32,
|
||||
pub language: String,
|
||||
pub content: String,
|
||||
pub context_header: String,
|
||||
pub token_estimate: u32,
|
||||
}
|
||||
|
||||
/// Extract embeddable code chunks from parsed CodeNodes.
|
||||
///
|
||||
/// For each node, reads the corresponding source lines from disk,
|
||||
/// builds a context header, and estimates tokens.
|
||||
pub fn extract_chunks(
|
||||
repo_path: &Path,
|
||||
nodes: &[CodeNode],
|
||||
max_chunk_tokens: u32,
|
||||
) -> Vec<CodeChunk> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
for node in nodes {
|
||||
let file = repo_path.join(&node.file_path);
|
||||
let source = match std::fs::read_to_string(&file) {
|
||||
Ok(s) => s,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let lines: Vec<&str> = source.lines().collect();
|
||||
let start = node.start_line.saturating_sub(1) as usize;
|
||||
let end = (node.end_line as usize).min(lines.len());
|
||||
if start >= end {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content: String = lines[start..end].join("\n");
|
||||
|
||||
// Skip tiny chunks
|
||||
if content.len() < 50 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Estimate tokens (~4 chars per token)
|
||||
let mut token_estimate = (content.len() / 4) as u32;
|
||||
|
||||
// Truncate if too large
|
||||
let final_content = if token_estimate > max_chunk_tokens {
|
||||
let max_chars = (max_chunk_tokens as usize) * 4;
|
||||
token_estimate = max_chunk_tokens;
|
||||
content.chars().take(max_chars).collect()
|
||||
} else {
|
||||
content
|
||||
};
|
||||
|
||||
// Build context header: file path + containing scope hint
|
||||
let context_header = build_context_header(
|
||||
&node.file_path,
|
||||
&node.qualified_name,
|
||||
&node.kind.to_string(),
|
||||
);
|
||||
|
||||
chunks.push(CodeChunk {
|
||||
qualified_name: node.qualified_name.clone(),
|
||||
kind: node.kind.to_string(),
|
||||
file_path: node.file_path.clone(),
|
||||
start_line: node.start_line,
|
||||
end_line: node.end_line,
|
||||
language: node.language.clone(),
|
||||
content: final_content,
|
||||
context_header,
|
||||
token_estimate,
|
||||
});
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> String {
|
||||
// Extract containing module/class from qualified name
|
||||
// e.g. "src/main.rs::MyStruct::my_method" → parent is "MyStruct"
|
||||
let parts: Vec<&str> = qualified_name.split("::").collect();
|
||||
if parts.len() >= 2 {
|
||||
let parent = parts[..parts.len() - 1].join("::");
|
||||
format!("// {file_path} | {kind} in {parent}")
|
||||
} else {
|
||||
format!("// {file_path} | {kind}")
|
||||
}
|
||||
}
|
||||
@@ -109,8 +109,8 @@ pub fn detect_communities(code_graph: &CodeGraph) -> u32 {
|
||||
let mut comm_remap: HashMap<u32, u32> = HashMap::new();
|
||||
let mut next_id: u32 = 0;
|
||||
for &c in community.values() {
|
||||
if !comm_remap.contains_key(&c) {
|
||||
comm_remap.insert(c, next_id);
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = comm_remap.entry(c) {
|
||||
e.insert(next_id);
|
||||
next_id += 1;
|
||||
}
|
||||
}
|
||||
@@ -137,8 +137,7 @@ pub fn detect_communities(code_graph: &CodeGraph) -> u32 {
|
||||
|
||||
/// Apply community assignments back to code nodes
|
||||
pub fn apply_communities(code_graph: &mut CodeGraph) -> u32 {
|
||||
let count = detect_communities_with_assignment(code_graph);
|
||||
count
|
||||
detect_communities_with_assignment(code_graph)
|
||||
}
|
||||
|
||||
/// Detect communities and write assignments into the nodes
|
||||
@@ -235,8 +234,8 @@ fn detect_communities_with_assignment(code_graph: &mut CodeGraph) -> u32 {
|
||||
let mut comm_remap: HashMap<u32, u32> = HashMap::new();
|
||||
let mut next_id: u32 = 0;
|
||||
for &c in community.values() {
|
||||
if !comm_remap.contains_key(&c) {
|
||||
comm_remap.insert(c, next_id);
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = comm_remap.entry(c) {
|
||||
e.insert(next_id);
|
||||
next_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
||||
use futures_util::TryStreamExt;
|
||||
use mongodb::bson::doc;
|
||||
use mongodb::{Collection, Database, IndexModel};
|
||||
use tracing::info;
|
||||
|
||||
/// MongoDB persistence layer for code embeddings and vector search
|
||||
pub struct EmbeddingStore {
|
||||
embeddings: Collection<CodeEmbedding>,
|
||||
builds: Collection<EmbeddingBuildRun>,
|
||||
}
|
||||
|
||||
impl EmbeddingStore {
|
||||
pub fn new(db: &Database) -> Self {
|
||||
Self {
|
||||
embeddings: db.collection("code_embeddings"),
|
||||
builds: db.collection("embedding_builds"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create standard indexes. NOTE: The Atlas Vector Search index must be
|
||||
/// created via the Atlas UI or CLI with the following definition:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "fields": [
|
||||
/// { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" },
|
||||
/// { "type": "filter", "path": "repo_id" }
|
||||
/// ]
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn ensure_indexes(&self) -> Result<(), CoreError> {
|
||||
self.embeddings
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.builds
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "started_at": -1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Delete all embeddings for a repository
|
||||
pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result<u64, CoreError> {
|
||||
let result = self
|
||||
.embeddings
|
||||
.delete_many(doc! { "repo_id": repo_id })
|
||||
.await?;
|
||||
info!(
|
||||
"Deleted {} embeddings for repo {repo_id}",
|
||||
result.deleted_count
|
||||
);
|
||||
Ok(result.deleted_count)
|
||||
}
|
||||
|
||||
/// Store embeddings in batches of 500
|
||||
pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result<u64, CoreError> {
|
||||
let mut total_inserted = 0u64;
|
||||
for batch in embeddings.chunks(500) {
|
||||
let result = self.embeddings.insert_many(batch).await?;
|
||||
total_inserted += result.inserted_ids.len() as u64;
|
||||
}
|
||||
info!("Stored {total_inserted} embeddings");
|
||||
Ok(total_inserted)
|
||||
}
|
||||
|
||||
/// Store a new build run
|
||||
pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> {
|
||||
self.builds.insert_one(build).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update an existing build run
|
||||
pub async fn update_build(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
graph_build_id: &str,
|
||||
status: EmbeddingBuildStatus,
|
||||
embedded_chunks: u32,
|
||||
error_message: Option<String>,
|
||||
) -> Result<(), CoreError> {
|
||||
let mut update = doc! {
|
||||
"$set": {
|
||||
"status": mongodb::bson::to_bson(&status).unwrap_or_default(),
|
||||
"embedded_chunks": embedded_chunks as i64,
|
||||
}
|
||||
};
|
||||
|
||||
if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed {
|
||||
if let Ok(set_doc) = update.get_document_mut("$set") {
|
||||
set_doc.insert("completed_at", mongodb::bson::DateTime::now());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(msg) = error_message {
|
||||
if let Ok(set_doc) = update.get_document_mut("$set") {
|
||||
set_doc.insert("error_message", msg);
|
||||
}
|
||||
}
|
||||
|
||||
self.builds
|
||||
.update_one(
|
||||
doc! { "repo_id": repo_id, "graph_build_id": graph_build_id },
|
||||
update,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the latest embedding build for a repository
|
||||
pub async fn get_latest_build(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
) -> Result<Option<EmbeddingBuildRun>, CoreError> {
|
||||
Ok(self
|
||||
.builds
|
||||
.find_one(doc! { "repo_id": repo_id })
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Perform vector search. Tries Atlas $vectorSearch first, falls back to
|
||||
/// brute-force cosine similarity for local MongoDB instances.
|
||||
pub async fn vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: Vec<f64>,
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
match self
|
||||
.atlas_vector_search(repo_id, &query_embedding, limit, min_score)
|
||||
.await
|
||||
{
|
||||
Ok(results) => Ok(results),
|
||||
Err(e) => {
|
||||
info!(
|
||||
"Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity"
|
||||
);
|
||||
self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index)
|
||||
async fn atlas_vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: &[f64],
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
use mongodb::bson::{Bson, Document};
|
||||
|
||||
let pipeline = vec![
|
||||
doc! {
|
||||
"$vectorSearch": {
|
||||
"index": "embedding_vector_index",
|
||||
"path": "embedding",
|
||||
"queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::<Vec<_>>(),
|
||||
"numCandidates": (limit * 10) as i64,
|
||||
"limit": limit as i64,
|
||||
"filter": { "repo_id": repo_id },
|
||||
}
|
||||
},
|
||||
doc! {
|
||||
"$addFields": {
|
||||
"search_score": { "$meta": "vectorSearchScore" }
|
||||
}
|
||||
},
|
||||
doc! {
|
||||
"$match": {
|
||||
"search_score": { "$gte": min_score }
|
||||
}
|
||||
},
|
||||
];
|
||||
|
||||
let mut cursor = self.embeddings.aggregate(pipeline).await?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(doc) = cursor.try_next().await? {
|
||||
let score = doc.get_f64("search_score").unwrap_or(0.0);
|
||||
let mut clean_doc: Document = doc;
|
||||
clean_doc.remove("search_score");
|
||||
if let Ok(embedding) = mongodb::bson::from_document::<CodeEmbedding>(clean_doc) {
|
||||
results.push((embedding, score));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Brute-force cosine similarity fallback for local MongoDB without Atlas
|
||||
async fn bruteforce_vector_search(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
query_embedding: &[f64],
|
||||
limit: u32,
|
||||
min_score: f64,
|
||||
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
|
||||
let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?;
|
||||
|
||||
let query_norm = dot(query_embedding, query_embedding).sqrt();
|
||||
let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new();
|
||||
|
||||
while let Some(emb) = cursor.try_next().await? {
|
||||
let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt();
|
||||
let score = if query_norm > 0.0 && doc_norm > 0.0 {
|
||||
dot(query_embedding, &emb.embedding) / (query_norm * doc_norm)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
if score >= min_score {
|
||||
scored.push((emb, score));
|
||||
}
|
||||
}
|
||||
|
||||
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
scored.truncate(limit as usize);
|
||||
Ok(scored)
|
||||
}
|
||||
}
|
||||
|
||||
fn dot(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
|
||||
}
|
||||
@@ -133,10 +133,10 @@ impl GraphEngine {
|
||||
}
|
||||
|
||||
/// Try to resolve an edge target to a known node
|
||||
fn resolve_edge_target<'a>(
|
||||
fn resolve_edge_target(
|
||||
&self,
|
||||
target: &str,
|
||||
node_map: &'a HashMap<String, NodeIndex>,
|
||||
node_map: &HashMap<String, NodeIndex>,
|
||||
) -> Option<NodeIndex> {
|
||||
// Direct match
|
||||
if let Some(idx) = node_map.get(target) {
|
||||
|
||||
@@ -26,8 +26,11 @@ impl<'a> ImpactAnalyzer<'a> {
|
||||
file_path: &str,
|
||||
line_number: Option<u32>,
|
||||
) -> ImpactAnalysis {
|
||||
let mut analysis =
|
||||
ImpactAnalysis::new(repo_id.to_string(), finding_id.to_string(), graph_build_id.to_string());
|
||||
let mut analysis = ImpactAnalysis::new(
|
||||
repo_id.to_string(),
|
||||
finding_id.to_string(),
|
||||
graph_build_id.to_string(),
|
||||
);
|
||||
|
||||
// Find the node containing the finding
|
||||
let target_node = self.find_node_at_location(file_path, line_number);
|
||||
@@ -97,7 +100,11 @@ impl<'a> ImpactAnalyzer<'a> {
|
||||
}
|
||||
|
||||
/// Find the graph node at a given file/line location
|
||||
fn find_node_at_location(&self, file_path: &str, line_number: Option<u32>) -> Option<NodeIndex> {
|
||||
fn find_node_at_location(
|
||||
&self,
|
||||
file_path: &str,
|
||||
line_number: Option<u32>,
|
||||
) -> Option<NodeIndex> {
|
||||
let mut best: Option<(NodeIndex, u32)> = None; // (index, line_span)
|
||||
|
||||
for node in &self.code_graph.nodes {
|
||||
@@ -166,12 +173,7 @@ impl<'a> ImpactAnalyzer<'a> {
|
||||
}
|
||||
|
||||
/// Find a path from source to target (BFS, limited depth)
|
||||
fn find_path(
|
||||
&self,
|
||||
from: NodeIndex,
|
||||
to: NodeIndex,
|
||||
max_depth: usize,
|
||||
) -> Option<Vec<String>> {
|
||||
fn find_path(&self, from: NodeIndex, to: NodeIndex, max_depth: usize) -> Option<Vec<String>> {
|
||||
let mut visited = HashSet::new();
|
||||
let mut queue: VecDeque<(NodeIndex, Vec<NodeIndex>)> = VecDeque::new();
|
||||
queue.push_back((from, vec![from]));
|
||||
@@ -209,7 +211,10 @@ impl<'a> ImpactAnalyzer<'a> {
|
||||
None
|
||||
}
|
||||
|
||||
fn get_node_by_index(&self, idx: NodeIndex) -> Option<&compliance_core::models::graph::CodeNode> {
|
||||
fn get_node_by_index(
|
||||
&self,
|
||||
idx: NodeIndex,
|
||||
) -> Option<&compliance_core::models::graph::CodeNode> {
|
||||
let target_gi = idx.index() as u32;
|
||||
self.code_graph
|
||||
.nodes
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod chunking;
|
||||
pub mod community;
|
||||
pub mod embedding_store;
|
||||
pub mod engine;
|
||||
pub mod impact;
|
||||
pub mod persistence;
|
||||
|
||||
@@ -211,8 +211,6 @@ impl GraphStore {
|
||||
repo_id: &str,
|
||||
graph_build_id: &str,
|
||||
) -> Result<Vec<CommunityInfo>, CoreError> {
|
||||
|
||||
|
||||
let filter = doc! {
|
||||
"repo_id": repo_id,
|
||||
"graph_build_id": graph_build_id,
|
||||
|
||||
Reference in New Issue
Block a user