feat: rag-embedding-ai-chat (#1)
All checks were successful
CI / Format (push) Successful in 2s
CI / Clippy (push) Successful in 2m56s
CI / Security Audit (push) Successful in 1m25s
CI / Tests (push) Successful in 3m57s

Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2026-03-06 21:54:15 +00:00
parent db454867f3
commit 42cabf0582
61 changed files with 3868 additions and 307 deletions

View File

@@ -0,0 +1,96 @@
use std::path::Path;
use compliance_core::models::graph::CodeNode;
/// A chunk of code extracted from a source file, ready for embedding
#[derive(Debug, Clone)]
pub struct CodeChunk {
pub qualified_name: String,
pub kind: String,
pub file_path: String,
pub start_line: u32,
pub end_line: u32,
pub language: String,
pub content: String,
pub context_header: String,
pub token_estimate: u32,
}
/// Extract embeddable code chunks from parsed CodeNodes.
///
/// For each node, reads the corresponding source lines from disk,
/// builds a context header, and estimates tokens.
pub fn extract_chunks(
repo_path: &Path,
nodes: &[CodeNode],
max_chunk_tokens: u32,
) -> Vec<CodeChunk> {
let mut chunks = Vec::new();
for node in nodes {
let file = repo_path.join(&node.file_path);
let source = match std::fs::read_to_string(&file) {
Ok(s) => s,
Err(_) => continue,
};
let lines: Vec<&str> = source.lines().collect();
let start = node.start_line.saturating_sub(1) as usize;
let end = (node.end_line as usize).min(lines.len());
if start >= end {
continue;
}
let content: String = lines[start..end].join("\n");
// Skip tiny chunks
if content.len() < 50 {
continue;
}
// Estimate tokens (~4 chars per token)
let mut token_estimate = (content.len() / 4) as u32;
// Truncate if too large
let final_content = if token_estimate > max_chunk_tokens {
let max_chars = (max_chunk_tokens as usize) * 4;
token_estimate = max_chunk_tokens;
content.chars().take(max_chars).collect()
} else {
content
};
// Build context header: file path + containing scope hint
let context_header = build_context_header(
&node.file_path,
&node.qualified_name,
&node.kind.to_string(),
);
chunks.push(CodeChunk {
qualified_name: node.qualified_name.clone(),
kind: node.kind.to_string(),
file_path: node.file_path.clone(),
start_line: node.start_line,
end_line: node.end_line,
language: node.language.clone(),
content: final_content,
context_header,
token_estimate,
});
}
chunks
}
fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> String {
// Extract containing module/class from qualified name
// e.g. "src/main.rs::MyStruct::my_method" → parent is "MyStruct"
let parts: Vec<&str> = qualified_name.split("::").collect();
if parts.len() >= 2 {
let parent = parts[..parts.len() - 1].join("::");
format!("// {file_path} | {kind} in {parent}")
} else {
format!("// {file_path} | {kind}")
}
}

View File

@@ -109,8 +109,8 @@ pub fn detect_communities(code_graph: &CodeGraph) -> u32 {
let mut comm_remap: HashMap<u32, u32> = HashMap::new();
let mut next_id: u32 = 0;
for &c in community.values() {
if !comm_remap.contains_key(&c) {
comm_remap.insert(c, next_id);
if let std::collections::hash_map::Entry::Vacant(e) = comm_remap.entry(c) {
e.insert(next_id);
next_id += 1;
}
}
@@ -137,8 +137,7 @@ pub fn detect_communities(code_graph: &CodeGraph) -> u32 {
/// Apply community assignments back to code nodes
pub fn apply_communities(code_graph: &mut CodeGraph) -> u32 {
let count = detect_communities_with_assignment(code_graph);
count
detect_communities_with_assignment(code_graph)
}
/// Detect communities and write assignments into the nodes
@@ -235,8 +234,8 @@ fn detect_communities_with_assignment(code_graph: &mut CodeGraph) -> u32 {
let mut comm_remap: HashMap<u32, u32> = HashMap::new();
let mut next_id: u32 = 0;
for &c in community.values() {
if !comm_remap.contains_key(&c) {
comm_remap.insert(c, next_id);
if let std::collections::hash_map::Entry::Vacant(e) = comm_remap.entry(c) {
e.insert(next_id);
next_id += 1;
}
}

View File

@@ -0,0 +1,236 @@
use compliance_core::error::CoreError;
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
use futures_util::TryStreamExt;
use mongodb::bson::doc;
use mongodb::{Collection, Database, IndexModel};
use tracing::info;
/// MongoDB persistence layer for code embeddings and vector search
pub struct EmbeddingStore {
embeddings: Collection<CodeEmbedding>,
builds: Collection<EmbeddingBuildRun>,
}
impl EmbeddingStore {
pub fn new(db: &Database) -> Self {
Self {
embeddings: db.collection("code_embeddings"),
builds: db.collection("embedding_builds"),
}
}
/// Create standard indexes. NOTE: The Atlas Vector Search index must be
/// created via the Atlas UI or CLI with the following definition:
/// ```json
/// {
/// "fields": [
/// { "type": "vector", "path": "embedding", "numDimensions": 1536, "similarity": "cosine" },
/// { "type": "filter", "path": "repo_id" }
/// ]
/// }
/// ```
pub async fn ensure_indexes(&self) -> Result<(), CoreError> {
self.embeddings
.create_index(
IndexModel::builder()
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
.build(),
)
.await?;
self.builds
.create_index(
IndexModel::builder()
.keys(doc! { "repo_id": 1, "started_at": -1 })
.build(),
)
.await?;
Ok(())
}
/// Delete all embeddings for a repository
pub async fn delete_repo_embeddings(&self, repo_id: &str) -> Result<u64, CoreError> {
let result = self
.embeddings
.delete_many(doc! { "repo_id": repo_id })
.await?;
info!(
"Deleted {} embeddings for repo {repo_id}",
result.deleted_count
);
Ok(result.deleted_count)
}
/// Store embeddings in batches of 500
pub async fn store_embeddings(&self, embeddings: &[CodeEmbedding]) -> Result<u64, CoreError> {
let mut total_inserted = 0u64;
for batch in embeddings.chunks(500) {
let result = self.embeddings.insert_many(batch).await?;
total_inserted += result.inserted_ids.len() as u64;
}
info!("Stored {total_inserted} embeddings");
Ok(total_inserted)
}
/// Store a new build run
pub async fn store_build(&self, build: &EmbeddingBuildRun) -> Result<(), CoreError> {
self.builds.insert_one(build).await?;
Ok(())
}
/// Update an existing build run
pub async fn update_build(
&self,
repo_id: &str,
graph_build_id: &str,
status: EmbeddingBuildStatus,
embedded_chunks: u32,
error_message: Option<String>,
) -> Result<(), CoreError> {
let mut update = doc! {
"$set": {
"status": mongodb::bson::to_bson(&status).unwrap_or_default(),
"embedded_chunks": embedded_chunks as i64,
}
};
if status == EmbeddingBuildStatus::Completed || status == EmbeddingBuildStatus::Failed {
if let Ok(set_doc) = update.get_document_mut("$set") {
set_doc.insert("completed_at", mongodb::bson::DateTime::now());
}
}
if let Some(msg) = error_message {
if let Ok(set_doc) = update.get_document_mut("$set") {
set_doc.insert("error_message", msg);
}
}
self.builds
.update_one(
doc! { "repo_id": repo_id, "graph_build_id": graph_build_id },
update,
)
.await?;
Ok(())
}
/// Get the latest embedding build for a repository
pub async fn get_latest_build(
&self,
repo_id: &str,
) -> Result<Option<EmbeddingBuildRun>, CoreError> {
Ok(self
.builds
.find_one(doc! { "repo_id": repo_id })
.sort(doc! { "started_at": -1 })
.await?)
}
/// Perform vector search. Tries Atlas $vectorSearch first, falls back to
/// brute-force cosine similarity for local MongoDB instances.
pub async fn vector_search(
&self,
repo_id: &str,
query_embedding: Vec<f64>,
limit: u32,
min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
match self
.atlas_vector_search(repo_id, &query_embedding, limit, min_score)
.await
{
Ok(results) => Ok(results),
Err(e) => {
info!(
"Atlas $vectorSearch unavailable ({e}), falling back to brute-force cosine similarity"
);
self.bruteforce_vector_search(repo_id, &query_embedding, limit, min_score)
.await
}
}
}
/// Atlas $vectorSearch aggregation stage (requires Atlas Vector Search index)
async fn atlas_vector_search(
&self,
repo_id: &str,
query_embedding: &[f64],
limit: u32,
min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
use mongodb::bson::{Bson, Document};
let pipeline = vec![
doc! {
"$vectorSearch": {
"index": "embedding_vector_index",
"path": "embedding",
"queryVector": query_embedding.iter().map(|&v| Bson::Double(v)).collect::<Vec<_>>(),
"numCandidates": (limit * 10) as i64,
"limit": limit as i64,
"filter": { "repo_id": repo_id },
}
},
doc! {
"$addFields": {
"search_score": { "$meta": "vectorSearchScore" }
}
},
doc! {
"$match": {
"search_score": { "$gte": min_score }
}
},
];
let mut cursor = self.embeddings.aggregate(pipeline).await?;
let mut results = Vec::new();
while let Some(doc) = cursor.try_next().await? {
let score = doc.get_f64("search_score").unwrap_or(0.0);
let mut clean_doc: Document = doc;
clean_doc.remove("search_score");
if let Ok(embedding) = mongodb::bson::from_document::<CodeEmbedding>(clean_doc) {
results.push((embedding, score));
}
}
Ok(results)
}
/// Brute-force cosine similarity fallback for local MongoDB without Atlas
async fn bruteforce_vector_search(
&self,
repo_id: &str,
query_embedding: &[f64],
limit: u32,
min_score: f64,
) -> Result<Vec<(CodeEmbedding, f64)>, CoreError> {
let mut cursor = self.embeddings.find(doc! { "repo_id": repo_id }).await?;
let query_norm = dot(query_embedding, query_embedding).sqrt();
let mut scored: Vec<(CodeEmbedding, f64)> = Vec::new();
while let Some(emb) = cursor.try_next().await? {
let doc_norm = dot(&emb.embedding, &emb.embedding).sqrt();
let score = if query_norm > 0.0 && doc_norm > 0.0 {
dot(query_embedding, &emb.embedding) / (query_norm * doc_norm)
} else {
0.0
};
if score >= min_score {
scored.push((emb, score));
}
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit as usize);
Ok(scored)
}
}
fn dot(a: &[f64], b: &[f64]) -> f64 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

View File

@@ -133,10 +133,10 @@ impl GraphEngine {
}
/// Try to resolve an edge target to a known node
fn resolve_edge_target<'a>(
fn resolve_edge_target(
&self,
target: &str,
node_map: &'a HashMap<String, NodeIndex>,
node_map: &HashMap<String, NodeIndex>,
) -> Option<NodeIndex> {
// Direct match
if let Some(idx) = node_map.get(target) {

View File

@@ -26,8 +26,11 @@ impl<'a> ImpactAnalyzer<'a> {
file_path: &str,
line_number: Option<u32>,
) -> ImpactAnalysis {
let mut analysis =
ImpactAnalysis::new(repo_id.to_string(), finding_id.to_string(), graph_build_id.to_string());
let mut analysis = ImpactAnalysis::new(
repo_id.to_string(),
finding_id.to_string(),
graph_build_id.to_string(),
);
// Find the node containing the finding
let target_node = self.find_node_at_location(file_path, line_number);
@@ -97,7 +100,11 @@ impl<'a> ImpactAnalyzer<'a> {
}
/// Find the graph node at a given file/line location
fn find_node_at_location(&self, file_path: &str, line_number: Option<u32>) -> Option<NodeIndex> {
fn find_node_at_location(
&self,
file_path: &str,
line_number: Option<u32>,
) -> Option<NodeIndex> {
let mut best: Option<(NodeIndex, u32)> = None; // (index, line_span)
for node in &self.code_graph.nodes {
@@ -166,12 +173,7 @@ impl<'a> ImpactAnalyzer<'a> {
}
/// Find a path from source to target (BFS, limited depth)
fn find_path(
&self,
from: NodeIndex,
to: NodeIndex,
max_depth: usize,
) -> Option<Vec<String>> {
fn find_path(&self, from: NodeIndex, to: NodeIndex, max_depth: usize) -> Option<Vec<String>> {
let mut visited = HashSet::new();
let mut queue: VecDeque<(NodeIndex, Vec<NodeIndex>)> = VecDeque::new();
queue.push_back((from, vec![from]));
@@ -209,7 +211,10 @@ impl<'a> ImpactAnalyzer<'a> {
None
}
fn get_node_by_index(&self, idx: NodeIndex) -> Option<&compliance_core::models::graph::CodeNode> {
fn get_node_by_index(
&self,
idx: NodeIndex,
) -> Option<&compliance_core::models::graph::CodeNode> {
let target_gi = idx.index() as u32;
self.code_graph
.nodes

View File

@@ -1,4 +1,6 @@
pub mod chunking;
pub mod community;
pub mod embedding_store;
pub mod engine;
pub mod impact;
pub mod persistence;

View File

@@ -211,8 +211,6 @@ impl GraphStore {
repo_id: &str,
graph_build_id: &str,
) -> Result<Vec<CommunityInfo>, CoreError> {
let filter = doc! {
"repo_id": repo_id,
"graph_build_id": graph_build_id,

View File

@@ -1,3 +1,6 @@
#![allow(clippy::only_used_in_recursion)]
#![allow(clippy::too_many_arguments)]
pub mod graph;
pub mod parsers;
pub mod search;

View File

@@ -7,6 +7,12 @@ use tree_sitter::{Node, Parser};
pub struct JavaScriptParser;
impl Default for JavaScriptParser {
fn default() -> Self {
Self::new()
}
}
impl JavaScriptParser {
pub fn new() -> Self {
Self
@@ -51,7 +57,13 @@ impl JavaScriptParser {
if let Some(body) = node.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified, output,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}
}
@@ -97,7 +109,12 @@ impl JavaScriptParser {
if let Some(body) = node.child_by_field_name("body") {
self.walk_children(
body, source, file_path, repo_id, graph_build_id, Some(&qualified),
body,
source,
file_path,
repo_id,
graph_build_id,
Some(&qualified),
output,
);
}
@@ -130,7 +147,13 @@ impl JavaScriptParser {
if let Some(body) = node.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified, output,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}
}
@@ -138,7 +161,13 @@ impl JavaScriptParser {
// Arrow functions assigned to variables: const foo = () => {}
"lexical_declaration" | "variable_declaration" => {
self.extract_arrow_functions(
node, source, file_path, repo_id, graph_build_id, parent_qualified, output,
node,
source,
file_path,
repo_id,
graph_build_id,
parent_qualified,
output,
);
}
"import_statement" => {
@@ -183,7 +212,13 @@ impl JavaScriptParser {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.walk_tree(
child, source, file_path, repo_id, graph_build_id, parent_qualified, output,
child,
source,
file_path,
repo_id,
graph_build_id,
parent_qualified,
output,
);
}
}
@@ -217,7 +252,13 @@ impl JavaScriptParser {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.extract_calls(
child, source, file_path, repo_id, graph_build_id, caller_qualified, output,
child,
source,
file_path,
repo_id,
graph_build_id,
caller_qualified,
output,
);
}
}
@@ -263,7 +304,12 @@ impl JavaScriptParser {
if let Some(body) = value_n.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}

View File

@@ -7,6 +7,12 @@ use tree_sitter::{Node, Parser};
pub struct PythonParser;
impl Default for PythonParser {
fn default() -> Self {
Self::new()
}
}
impl PythonParser {
pub fn new() -> Self {
Self

View File

@@ -57,10 +57,7 @@ impl ParserRegistry {
repo_id: &str,
graph_build_id: &str,
) -> Result<Option<ParseOutput>, CoreError> {
let ext = file_path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
let ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
let parser_idx = match self.extension_map.get(ext) {
Some(idx) => *idx,
@@ -89,7 +86,15 @@ impl ParserRegistry {
let mut combined = ParseOutput::default();
let mut node_count: u32 = 0;
self.walk_directory(dir, dir, repo_id, graph_build_id, max_nodes, &mut node_count, &mut combined)?;
self.walk_directory(
dir,
dir,
repo_id,
graph_build_id,
max_nodes,
&mut node_count,
&mut combined,
)?;
info!(
nodes = combined.nodes.len(),
@@ -162,8 +167,7 @@ impl ParserRegistry {
Err(_) => continue, // Skip binary/unreadable files
};
if let Some(output) = self.parse_file(rel_path, &source, repo_id, graph_build_id)?
{
if let Some(output) = self.parse_file(rel_path, &source, repo_id, graph_build_id)? {
*node_count += output.nodes.len() as u32;
combined.nodes.extend(output.nodes);
combined.edges.extend(output.edges);

View File

@@ -7,6 +7,12 @@ use tree_sitter::{Node, Parser};
pub struct RustParser;
impl Default for RustParser {
fn default() -> Self {
Self::new()
}
}
impl RustParser {
pub fn new() -> Self {
Self
@@ -196,9 +202,7 @@ impl RustParser {
id: None,
repo_id: repo_id.to_string(),
graph_build_id: graph_build_id.to_string(),
source: parent_qualified
.unwrap_or(file_path)
.to_string(),
source: parent_qualified.unwrap_or(file_path).to_string(),
target: path,
kind: CodeEdgeKind::Imports,
file_path: file_path.to_string(),
@@ -354,10 +358,7 @@ impl RustParser {
fn extract_use_path(&self, use_text: &str) -> Option<String> {
// "use foo::bar::baz;" -> "foo::bar::baz"
let trimmed = use_text
.strip_prefix("use ")?
.trim_end_matches(';')
.trim();
let trimmed = use_text.strip_prefix("use ")?.trim_end_matches(';').trim();
Some(trimmed.to_string())
}
}

View File

@@ -7,6 +7,12 @@ use tree_sitter::{Node, Parser};
pub struct TypeScriptParser;
impl Default for TypeScriptParser {
fn default() -> Self {
Self::new()
}
}
impl TypeScriptParser {
pub fn new() -> Self {
Self
@@ -49,7 +55,13 @@ impl TypeScriptParser {
if let Some(body) = node.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified, output,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}
}
@@ -80,12 +92,23 @@ impl TypeScriptParser {
// Heritage clause (extends/implements)
self.extract_heritage(
&node, source, file_path, repo_id, graph_build_id, &qualified, output,
&node,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
if let Some(body) = node.child_by_field_name("body") {
self.walk_children(
body, source, file_path, repo_id, graph_build_id, Some(&qualified),
body,
source,
file_path,
repo_id,
graph_build_id,
Some(&qualified),
output,
);
}
@@ -143,14 +166,26 @@ impl TypeScriptParser {
if let Some(body) = node.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified, output,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}
}
}
"lexical_declaration" | "variable_declaration" => {
self.extract_arrow_functions(
node, source, file_path, repo_id, graph_build_id, parent_qualified, output,
node,
source,
file_path,
repo_id,
graph_build_id,
parent_qualified,
output,
);
}
"import_statement" => {
@@ -172,7 +207,13 @@ impl TypeScriptParser {
}
self.walk_children(
node, source, file_path, repo_id, graph_build_id, parent_qualified, output,
node,
source,
file_path,
repo_id,
graph_build_id,
parent_qualified,
output,
);
}
@@ -189,7 +230,13 @@ impl TypeScriptParser {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.walk_tree(
child, source, file_path, repo_id, graph_build_id, parent_qualified, output,
child,
source,
file_path,
repo_id,
graph_build_id,
parent_qualified,
output,
);
}
}
@@ -223,7 +270,13 @@ impl TypeScriptParser {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.extract_calls(
child, source, file_path, repo_id, graph_build_id, caller_qualified, output,
child,
source,
file_path,
repo_id,
graph_build_id,
caller_qualified,
output,
);
}
}
@@ -269,7 +322,12 @@ impl TypeScriptParser {
if let Some(body) = value_n.child_by_field_name("body") {
self.extract_calls(
body, source, file_path, repo_id, graph_build_id, &qualified,
body,
source,
file_path,
repo_id,
graph_build_id,
&qualified,
output,
);
}

View File

@@ -89,8 +89,10 @@ impl SymbolIndex {
.map_err(|e| CoreError::Graph(format!("Failed to create reader: {e}")))?;
let searcher = reader.searcher();
let query_parser =
QueryParser::for_index(&self.index, vec![self.name_field, self.qualified_name_field]);
let query_parser = QueryParser::for_index(
&self.index,
vec![self.name_field, self.qualified_name_field],
);
let query = query_parser
.parse_query(query_str)