feat: rag-embedding-ai-chat (#1)
All checks were successful
CI / Format (push) Successful in 2s
CI / Clippy (push) Successful in 2m56s
CI / Security Audit (push) Successful in 1m25s
CI / Tests (push) Successful in 3m57s

Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com>
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2026-03-06 21:54:15 +00:00
parent db454867f3
commit 42cabf0582
61 changed files with 3868 additions and 307 deletions

View File

@@ -20,6 +20,7 @@ impl ComplianceAgent {
config.litellm_url.clone(),
config.litellm_api_key.clone(),
config.litellm_model.clone(),
config.litellm_embed_model.clone(),
));
Self {
config,

View File

@@ -0,0 +1,238 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
use compliance_core::models::embedding::EmbeddingBuildRun;
use compliance_graph::graph::embedding_store::EmbeddingStore;
use crate::agent::ComplianceAgent;
use crate::rag::pipeline::RagPipeline;
use super::ApiResponse;
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// POST /api/v1/chat/:repo_id — Send a chat message with RAG context.
///
/// Steps:
/// 1. embed the user's message,
/// 2. run a vector search over the stored code chunks of `repo_id`,
/// 3. build a system prompt that contains the retrieved chunks,
/// 4. append the conversation history and the new message,
/// 5. call the LLM and return its answer together with the retrieved sources.
///
/// The history comes straight from the request body, so only entries with the
/// role `user` or `assistant` are forwarded. Anything else (for example a
/// client-supplied `system` message that would compete with the prompt built
/// here) is dropped and logged.
///
/// # Errors
/// Responds with `500 Internal Server Error` when embedding, vector search or
/// the LLM call fails; the underlying error is logged.
pub async fn chat(
    Extension(agent): AgentExt,
    Path(repo_id): Path<String>,
    Json(req): Json<ChatRequest>,
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
    let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());

    // Step 1: Embed the user's message
    let query_vectors = agent
        .llm
        .embed(vec![req.message.clone()])
        .await
        .map_err(|e| {
            tracing::error!("Failed to embed query: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    let query_embedding = query_vectors.into_iter().next().ok_or_else(|| {
        tracing::error!("Empty embedding response");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Step 2: Vector search — retrieve top 8 chunks
    // NOTE(review): 0.5 is presumably a minimum similarity score — confirm
    // against `EmbeddingStore::vector_search`.
    let search_results = pipeline
        .store()
        .vector_search(&repo_id, query_embedding, 8, 0.5)
        .await
        .map_err(|e| {
            tracing::error!("Vector search failed: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    // Step 3: Build system prompt with code context
    let mut context_parts = Vec::with_capacity(search_results.len());
    let mut sources = Vec::with_capacity(search_results.len());
    for (embedding, score) in &search_results {
        // The LLM sees the full chunk ...
        context_parts.push(format!(
            "--- {} ({}, {}:L{}-L{}) ---\n{}",
            embedding.qualified_name,
            embedding.kind,
            embedding.file_path,
            embedding.start_line,
            embedding.end_line,
            embedding.content,
        ));

        // ... while the API response only carries the first 10 lines of it.
        let snippet: String = embedding
            .content
            .lines()
            .take(10)
            .collect::<Vec<_>>()
            .join("\n");
        sources.push(SourceReference {
            file_path: embedding.file_path.clone(),
            qualified_name: embedding.qualified_name.clone(),
            start_line: embedding.start_line,
            end_line: embedding.end_line,
            language: embedding.language.clone(),
            snippet,
            score: *score,
        });
    }

    let code_context = if context_parts.is_empty() {
        "No relevant code context found.".to_string()
    } else {
        context_parts.join("\n\n")
    };
    let system_prompt = format!(
        "You are an expert code assistant for a software repository. \
         Answer the user's question based on the code context below. \
         Reference specific files and functions when relevant. \
         If the context doesn't contain enough information, say so.\n\n\
         ## Code Context\n\n{code_context}"
    );

    // Step 4: Build messages array with history
    let mut messages: Vec<(String, String)> = Vec::with_capacity(req.history.len() + 2);
    messages.push(("system".to_string(), system_prompt));
    for msg in &req.history {
        if matches!(msg.role.as_str(), "user" | "assistant") {
            messages.push((msg.role.clone(), msg.content.clone()));
        } else {
            tracing::warn!(
                "[{repo_id}] Dropping chat history entry with unsupported role {:?}",
                msg.role
            );
        }
    }
    messages.push(("user".to_string(), req.message));

    // Step 5: Call LLM
    let response_text = agent
        .llm
        .chat_with_messages(messages, Some(0.3))
        .await
        .map_err(|e| {
            tracing::error!("LLM chat failed: {e}");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    Ok(Json(ApiResponse {
        data: ChatResponse {
            message: response_text,
            sources,
        },
        total: None,
        page: None,
    }))
}
/// POST /api/v1/chat/:repo_id/build-embeddings — Trigger embedding build
pub async fn build_embeddings(
Extension(agent): AgentExt,
Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone();
tokio::spawn(async move {
let repo = match agent_clone
.db
.repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await
{
Ok(Some(r)) => r,
_ => {
tracing::error!("Repository {repo_id} not found for embedding build");
return;
}
};
// Get latest graph build
let build = match agent_clone
.db
.graph_builds()
.find_one(doc! { "repo_id": &repo_id })
.sort(doc! { "started_at": -1 })
.await
{
Ok(Some(b)) => b,
_ => {
tracing::error!("[{repo_id}] No graph build found — build graph first");
return;
}
};
let graph_build_id = build
.id
.map(|id| id.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Get nodes
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
.db
.graph_nodes()
.find(doc! { "repo_id": &repo_id })
.await
{
Ok(cursor) => {
use futures_util::StreamExt;
let mut items = Vec::new();
let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
}
items
}
Err(e) => {
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
return;
}
};
let git_ops = crate::pipeline::git::GitOps::new(&agent_clone.config.git_clone_base_path);
let repo_path = match git_ops.clone_or_fetch(&repo.git_url, &repo.name) {
Ok(p) => p,
Err(e) => {
tracing::error!("Failed to clone repo for embedding build: {e}");
return;
}
};
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
match pipeline
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
.await
{
Ok(run) => {
tracing::info!(
"[{repo_id}] Embedding build complete: {}/{} chunks",
run.embedded_chunks,
run.total_chunks
);
}
Err(e) => {
tracing::error!("[{repo_id}] Embedding build failed: {e}");
}
}
});
Ok(Json(
serde_json::json!({ "status": "embedding_build_triggered" }),
))
}
/// GET /api/v1/chat/:repo_id/status — Get latest embedding build status
///
/// Looks up the latest embedding build recorded for `repo_id` and returns it
/// as the response payload; the payload is `None` when the store has no build
/// for the repository. A store error is logged and mapped to
/// `500 Internal Server Error`.
pub async fn embedding_status(
    Extension(agent): AgentExt,
    Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
    let store = EmbeddingStore::new(agent.db.inner());

    let latest = match store.get_latest_build(&repo_id).await {
        Ok(found) => found,
        Err(e) => {
            tracing::error!("Failed to get embedding status: {e}");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let payload = ApiResponse {
        data: latest,
        total: None,
        page: None,
    };
    Ok(Json(payload))
}

View File

@@ -103,8 +103,7 @@ pub async fn trigger_scan(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let target = agent
.db
@@ -207,8 +206,7 @@ pub async fn get_finding(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let finding = agent
.db

View File

@@ -52,7 +52,7 @@ pub async fn get_graph(
// so there is only one set of nodes/edges per repo.
let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter.clone()).await {
let all_nodes: Vec<CodeNode> = match db.graph_nodes().find(filter.clone()).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
@@ -60,6 +60,17 @@ pub async fn get_graph(
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Remove disconnected nodes (no edges) to keep the graph clean
let connected: std::collections::HashSet<&str> = edges
.iter()
.flat_map(|e| [e.source.as_str(), e.target.as_str()])
.collect();
let nodes = all_nodes
.into_iter()
.filter(|n| connected.contains(n.qualified_name.as_str()))
.collect();
(nodes, edges)
} else {
(Vec::new(), Vec::new())
@@ -235,12 +246,7 @@ pub async fn get_file_content(
// Cap at 10,000 lines
let truncated: String = content.lines().take(10_000).collect::<Vec<_>>().join("\n");
let language = params
.path
.rsplit('.')
.next()
.unwrap_or("")
.to_string();
let language = params.path.rsplit('.').next().unwrap_or("").to_string();
Ok(Json(ApiResponse {
data: FileContent {

View File

@@ -1,3 +1,4 @@
pub mod chat;
pub mod dast;
pub mod graph;
@@ -5,7 +6,8 @@ use std::sync::Arc;
#[allow(unused_imports)]
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use serde::{Deserialize, Serialize};
@@ -89,6 +91,72 @@ pub struct UpdateStatusRequest {
pub status: String,
}
/// Query-string parameters accepted by the SBOM listing and license summary
/// endpoints. Every filter is optional; an absent filter adds no constraint.
#[derive(Deserialize)]
pub struct SbomFilter {
/// Restrict results to entries whose `repo_id` equals this value.
#[serde(default)]
pub repo_id: Option<String>,
/// Restrict results to entries whose `package_manager` equals this value.
#[serde(default)]
pub package_manager: Option<String>,
/// Free-text search on the package name. `list_sbom` passes it to MongoDB as
/// a case-insensitive `$regex`; an empty string is ignored.
#[serde(default)]
pub q: Option<String>,
/// `true` keeps only entries with a non-empty `known_vulnerabilities` list,
/// `false` keeps only entries whose list has size 0.
#[serde(default)]
pub has_vulns: Option<bool>,
/// Restrict results to entries whose `license` equals this value.
#[serde(default)]
pub license: Option<String>,
/// 1-based page number (`list_sbom` skips `(page - 1) * limit` entries).
/// Defaults to `default_page()`, which is defined outside this view.
#[serde(default = "default_page")]
pub page: u64,
/// Page size. Defaults to `default_limit()`, which is defined outside this view.
// NOTE(review): the type is signed and `list_sbom` casts it with `as u64`
// when computing the skip offset, so a negative value wraps — consider
// validating it.
#[serde(default = "default_limit")]
pub limit: i64,
}
/// Query-string parameters of the SBOM export endpoint.
#[derive(Deserialize)]
pub struct SbomExportParams {
/// Repository whose SBOM entries are exported (required).
pub repo_id: String,
/// Output format. `export_sbom` emits SPDX when this is exactly `"spdx"` and
/// CycloneDX for any other value; defaults to `"cyclonedx"` when omitted.
#[serde(default = "default_export_format")]
pub format: String,
}
/// Serde default for the `format` query parameter of the SBOM export:
/// CycloneDX is used when the caller does not ask for a specific format.
fn default_export_format() -> String {
    String::from("cyclonedx")
}
/// Query-string parameters of the SBOM diff endpoint: the two repositories
/// whose SBOM entries are compared.
#[derive(Deserialize)]
pub struct SbomDiffParams {
/// Repository id of side "A" of the comparison.
pub repo_a: String,
/// Repository id of side "B" of the comparison.
pub repo_b: String,
}
/// One row of the license summary: all packages that share a license.
#[derive(Serialize)]
pub struct LicenseSummary {
/// License string as stored on the SBOM entries (`"Unknown"` when an entry
/// has no license).
pub license: String,
/// Number of packages with this license (equals `packages.len()`).
pub count: u64,
/// Whether the license matched one of the `COPYLEFT_LICENSES` identifiers.
pub is_copyleft: bool,
/// Names of the packages using this license.
pub packages: Vec<String>,
}
/// Result of comparing the SBOMs of two repositories. Packages are identified
/// by the pair (name, package manager).
#[derive(Serialize)]
pub struct SbomDiffResult {
/// Packages present in repository A but not in repository B.
pub only_in_a: Vec<SbomDiffEntry>,
/// Packages present in repository B but not in repository A.
pub only_in_b: Vec<SbomDiffEntry>,
/// Packages present in both repositories with different versions.
pub version_changed: Vec<SbomVersionDiff>,
/// Number of packages present in both repositories with the same version.
pub common_count: u64,
}
/// A package that exists on only one side of an SBOM diff.
#[derive(Serialize)]
pub struct SbomDiffEntry {
/// Package name.
pub name: String,
/// Version found on the side that contains the package.
pub version: String,
/// Package manager the entry was recorded for.
pub package_manager: String,
}
/// A package that exists on both sides of an SBOM diff with different versions.
#[derive(Serialize)]
pub struct SbomVersionDiff {
/// Package name.
pub name: String,
/// Package manager the entry was recorded for.
pub package_manager: String,
/// Version recorded for repository A.
pub version_a: String,
/// Version recorded for repository B.
pub version_b: String,
}
type AgentExt = Extension<Arc<ComplianceAgent>>;
type ApiResult<T> = Result<Json<ApiResponse<T>>, StatusCode>;
@@ -235,6 +303,52 @@ pub async fn trigger_scan(
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
}
pub async fn delete_repository(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = &agent.db;
// Delete the repository
let result = db
.repositories()
.delete_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if result.deleted_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
// Cascade delete all related data
let _ = db.findings().delete_many(doc! { "repo_id": &id }).await;
let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await;
let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await;
let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.tracker_issues()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.impact_analyses()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.code_embeddings()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.embedding_builds()
.delete_many(doc! { "repo_id": &id })
.await;
Ok(Json(serde_json::json!({ "status": "deleted" })))
}
pub async fn list_findings(
Extension(agent): AgentExt,
Query(filter): Query<FindingsFilter>,
@@ -322,21 +436,46 @@ pub async fn update_finding_status(
pub async fn list_sbom(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(pm) = &filter.package_manager {
query.insert("package_manager", pm);
}
if let Some(q) = &filter.q {
if !q.is_empty() {
query.insert("name", doc! { "$regex": q, "$options": "i" });
}
}
if let Some(has_vulns) = filter.has_vulns {
if has_vulns {
query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
} else {
query.insert("known_vulnerabilities", doc! { "$size": 0 });
}
}
if let Some(license) = &filter.license {
query.insert("license", license);
}
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.sbom_entries()
.count_documents(doc! {})
.count_documents(query.clone())
.await
.unwrap_or(0);
let entries = match db
.sbom_entries()
.find(doc! {})
.find(query)
.sort(doc! { "name": 1 })
.skip(skip)
.limit(params.limit)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
@@ -346,7 +485,272 @@ pub async fn list_sbom(
Ok(Json(ApiResponse {
data: entries,
total: Some(total),
page: Some(params.page),
page: Some(filter.page),
}))
}
pub async fn export_sbom(
Extension(agent): AgentExt,
Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
let db = &agent.db;
let entries: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_id })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let body = if params.format == "spdx" {
// SPDX 2.3 format
let packages: Vec<serde_json::Value> = entries
.iter()
.enumerate()
.map(|(i, e)| {
serde_json::json!({
"SPDXID": format!("SPDXRef-Package-{i}"),
"name": e.name,
"versionInfo": e.version,
"downloadLocation": "NOASSERTION",
"licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
"externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
"referenceCategory": "PACKAGE-MANAGER",
"referenceType": "purl",
"referenceLocator": p,
})]).unwrap_or_default(),
})
})
.collect();
serde_json::json!({
"spdxVersion": "SPDX-2.3",
"dataLicense": "CC0-1.0",
"SPDXID": "SPDXRef-DOCUMENT",
"name": format!("sbom-{}", params.repo_id),
"documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
"packages": packages,
})
} else {
// CycloneDX 1.5 format
let components: Vec<serde_json::Value> = entries
.iter()
.map(|e| {
let mut comp = serde_json::json!({
"type": "library",
"name": e.name,
"version": e.version,
"group": e.package_manager,
});
if let Some(purl) = &e.purl {
comp["purl"] = serde_json::Value::String(purl.clone());
}
if let Some(license) = &e.license {
comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
}
if !e.known_vulnerabilities.is_empty() {
comp["vulnerabilities"] = serde_json::json!(
e.known_vulnerabilities.iter().map(|v| serde_json::json!({
"id": v.id,
"source": { "name": v.source },
"ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
})).collect::<Vec<_>>()
);
}
comp
})
.collect();
serde_json::json!({
"bomFormat": "CycloneDX",
"specVersion": "1.5",
"version": 1,
"metadata": {
"component": {
"type": "application",
"name": format!("repo-{}", params.repo_id),
}
},
"components": components,
})
};
let json_str =
serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let filename = if params.format == "spdx" {
format!("sbom-{}-spdx.json", params.repo_id)
} else {
format!("sbom-{}-cyclonedx.json", params.repo_id)
};
let disposition = format!("attachment; filename=\"{filename}\"");
Ok((
[
(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
),
(
header::CONTENT_DISPOSITION,
header::HeaderValue::from_str(&disposition)
.unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
),
],
json_str,
))
}
/// License identifiers that `license_summary` flags as copyleft.
///
/// `license_summary` compares case-insensitively and by substring: a license
/// is flagged when its upper-cased text contains the upper-cased form of any
/// entry. Consequences of that matching rule:
/// * the `-only` / `-or-later` variants are already covered by their base
///   identifier and are listed for readability;
/// * `GPL-3.0` also matches `AGPL-3.0` and `LGPL-3.0` strings.
const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0",
"GPL-2.0-only",
"GPL-2.0-or-later",
"GPL-3.0",
"GPL-3.0-only",
"GPL-3.0-or-later",
"AGPL-3.0",
"AGPL-3.0-only",
"AGPL-3.0-or-later",
"LGPL-2.1",
"LGPL-2.1-only",
"LGPL-2.1-or-later",
"LGPL-3.0",
"LGPL-3.0-only",
"LGPL-3.0-or-later",
"MPL-2.0",
];
pub async fn license_summary(
Extension(agent): AgentExt,
Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id);
}
let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let mut license_map: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for entry in &entries {
let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
license_map.entry(lic).or_default().push(entry.name.clone());
}
let mut summaries: Vec<LicenseSummary> = license_map
.into_iter()
.map(|(license, packages)| {
let is_copyleft = COPYLEFT_LICENSES
.iter()
.any(|c| license.to_uppercase().contains(&c.to_uppercase()));
LicenseSummary {
license,
count: packages.len() as u64,
is_copyleft,
packages,
}
})
.collect();
summaries.sort_by(|a, b| b.count.cmp(&a.count));
Ok(Json(ApiResponse {
data: summaries,
total: None,
page: None,
}))
}
pub async fn sbom_diff(
Extension(agent): AgentExt,
Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
let db = &agent.db;
let entries_a: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_a })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let entries_b: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_b })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build maps by (name, package_manager) -> version
let map_a: std::collections::HashMap<(String, String), String> = entries_a
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let map_b: std::collections::HashMap<(String, String), String> = entries_b
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let mut only_in_a = Vec::new();
let mut version_changed = Vec::new();
let mut common_count: u64 = 0;
for (key, ver_a) in &map_a {
match map_b.get(key) {
None => only_in_a.push(SbomDiffEntry {
name: key.0.clone(),
version: ver_a.clone(),
package_manager: key.1.clone(),
}),
Some(ver_b) if ver_a != ver_b => {
version_changed.push(SbomVersionDiff {
name: key.0.clone(),
package_manager: key.1.clone(),
version_a: ver_a.clone(),
version_b: ver_b.clone(),
});
}
Some(_) => common_count += 1,
}
}
let only_in_b: Vec<SbomDiffEntry> = map_b
.iter()
.filter(|(key, _)| !map_a.contains_key(key))
.map(|(key, ver)| SbomDiffEntry {
name: key.0.clone(),
version: ver.clone(),
package_manager: key.1.clone(),
})
.collect();
Ok(Json(ApiResponse {
data: SbomDiffResult {
only_in_a,
only_in_b,
version_changed,
common_count,
},
total: None,
page: None,
}))
}

View File

@@ -1,4 +1,4 @@
use axum::routing::{get, patch, post};
use axum::routing::{delete, get, patch, post};
use axum::Router;
use crate::api::handlers;
@@ -13,6 +13,10 @@ pub fn build_router() -> Router {
"/api/v1/repositories/{id}/scan",
post(handlers::trigger_scan),
)
.route(
"/api/v1/repositories/{id}",
delete(handlers::delete_repository),
)
.route("/api/v1/findings", get(handlers::list_findings))
.route("/api/v1/findings/{id}", get(handlers::get_finding))
.route(
@@ -20,13 +24,13 @@ pub fn build_router() -> Router {
patch(handlers::update_finding_status),
)
.route("/api/v1/sbom", get(handlers::list_sbom))
.route("/api/v1/sbom/export", get(handlers::export_sbom))
.route("/api/v1/sbom/licenses", get(handlers::license_summary))
.route("/api/v1/sbom/diff", get(handlers::sbom_diff))
.route("/api/v1/issues", get(handlers::list_issues))
.route("/api/v1/scan-runs", get(handlers::list_scan_runs))
// Graph API endpoints
.route(
"/api/v1/graph/{repo_id}",
get(handlers::graph::get_graph),
)
.route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph))
.route(
"/api/v1/graph/{repo_id}/nodes",
get(handlers::graph::get_nodes),
@@ -52,14 +56,8 @@ pub fn build_router() -> Router {
post(handlers::graph::trigger_build),
)
// DAST API endpoints
.route(
"/api/v1/dast/targets",
get(handlers::dast::list_targets),
)
.route(
"/api/v1/dast/targets",
post(handlers::dast::add_target),
)
.route("/api/v1/dast/targets", get(handlers::dast::list_targets))
.route("/api/v1/dast/targets", post(handlers::dast::add_target))
.route(
"/api/v1/dast/targets/{id}/scan",
post(handlers::dast::trigger_scan),
@@ -68,12 +66,19 @@ pub fn build_router() -> Router {
"/api/v1/dast/scan-runs",
get(handlers::dast::list_scan_runs),
)
.route(
"/api/v1/dast/findings",
get(handlers::dast::list_findings),
)
.route("/api/v1/dast/findings", get(handlers::dast::list_findings))
.route(
"/api/v1/dast/findings/{id}",
get(handlers::dast::get_finding),
)
// Chat / RAG API endpoints
.route("/api/v1/chat/{repo_id}", post(handlers::chat::chat))
.route(
"/api/v1/chat/{repo_id}/build-embeddings",
post(handlers::chat::build_embeddings),
)
.route(
"/api/v1/chat/{repo_id}/status",
get(handlers::chat::embedding_status),
)
}

View File

@@ -24,6 +24,8 @@ pub fn load_config() -> Result<AgentConfig, AgentError> {
.unwrap_or_else(|| "http://localhost:4000".to_string()),
litellm_api_key: SecretString::from(env_var_opt("LITELLM_API_KEY").unwrap_or_default()),
litellm_model: env_var_opt("LITELLM_MODEL").unwrap_or_else(|| "gpt-4o".to_string()),
litellm_embed_model: env_var_opt("LITELLM_EMBED_MODEL")
.unwrap_or_else(|| "text-embedding-3-small".to_string()),
github_token: env_secret_opt("GITHUB_TOKEN"),
github_webhook_secret: env_secret_opt("GITHUB_WEBHOOK_SECRET"),
gitlab_url: env_var_opt("GITLAB_URL"),

View File

@@ -127,11 +127,7 @@ impl Database {
// dast_targets: index on repo_id
self.dast_targets()
.create_index(
IndexModel::builder()
.keys(doc! { "repo_id": 1 })
.build(),
)
.create_index(IndexModel::builder().keys(doc! { "repo_id": 1 }).build())
.await?;
// dast_scan_runs: compound (target_id, started_at DESC)
@@ -152,6 +148,24 @@ impl Database {
)
.await?;
// code_embeddings: compound (repo_id, graph_build_id)
self.code_embeddings()
.create_index(
IndexModel::builder()
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
.build(),
)
.await?;
// embedding_builds: compound (repo_id, started_at DESC)
self.embedding_builds()
.create_index(
IndexModel::builder()
.keys(doc! { "repo_id": 1, "started_at": -1 })
.build(),
)
.await?;
tracing::info!("Database indexes ensured");
Ok(())
}
@@ -210,6 +224,17 @@ impl Database {
self.inner.collection("dast_findings")
}
// Embedding collections
/// Typed handle to the `code_embeddings` collection, whose documents are
/// deserialized as `CodeEmbedding`.
pub fn code_embeddings(&self) -> Collection<compliance_core::models::embedding::CodeEmbedding> {
self.inner.collection("code_embeddings")
}
/// Typed handle to the `embedding_builds` collection, whose documents are
/// deserialized as `EmbeddingBuildRun`.
pub fn embedding_builds(
&self,
) -> Collection<compliance_core::models::embedding::EmbeddingBuildRun> {
self.inner.collection("embedding_builds")
}
#[allow(dead_code)]
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
self.inner.collection(name)

View File

@@ -8,6 +8,7 @@ pub struct LlmClient {
base_url: String,
api_key: SecretString,
model: String,
embed_model: String,
http: reqwest::Client,
}
@@ -42,16 +43,46 @@ struct ChatResponseMessage {
content: String,
}
/// Request body for the embeddings API (`POST {base_url}/v1/embeddings`).
#[derive(Serialize)]
struct EmbeddingRequest {
/// Name of the embedding model to use.
model: String,
/// Texts to embed, sent as one batch.
input: Vec<String>,
}
/// Response from the embeddings API. Only the `data` array is read; any other
/// fields of the response are ignored during deserialization.
#[derive(Deserialize)]
struct EmbeddingResponse {
/// One item per embedded input text.
data: Vec<EmbeddingData>,
}
/// A single embedding result
#[derive(Deserialize)]
struct EmbeddingData {
/// The embedding vector.
embedding: Vec<f64>,
/// Index reported by the API for this item; `LlmClient::embed` sorts the
/// results by it to restore the order of the input texts.
index: usize,
}
impl LlmClient {
pub fn new(base_url: String, api_key: SecretString, model: String) -> Self {
pub fn new(
base_url: String,
api_key: SecretString,
model: String,
embed_model: String,
) -> Self {
Self {
base_url,
api_key,
model,
embed_model,
http: reqwest::Client::new(),
}
}
/// Name of the model that `embed` sends in its embedding requests.
pub fn embed_model(&self) -> &str {
&self.embed_model
}
pub async fn chat(
&self,
system_prompt: &str,
@@ -169,4 +200,49 @@ impl LlmClient {
.map(|c| c.message.content.clone())
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
}
/// Generate embeddings for a batch of texts.
///
/// Sends all `texts` in one request to `{base_url}/v1/embeddings` using the
/// configured embedding model. The returned vectors are in the same order as
/// `texts`, and there is exactly one vector per input text.
///
/// An empty `texts` returns an empty result without calling the API.
///
/// # Errors
/// Returns [`AgentError::Other`] when
/// * the HTTP request fails,
/// * the API answers with a non-success status (status and body are included),
/// * the response body cannot be parsed, or
/// * the API returns a different number of embeddings than texts were sent.
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
    if texts.is_empty() {
        return Ok(Vec::new());
    }
    // Remember the batch size before `texts` is moved into the request.
    let expected = texts.len();

    let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
    let request_body = EmbeddingRequest {
        model: self.embed_model.clone(),
        input: texts,
    };

    let mut req = self
        .http
        .post(&url)
        .header("content-type", "application/json")
        .json(&request_body);

    // Only send an Authorization header when a key is configured.
    let key = self.api_key.expose_secret();
    if !key.is_empty() {
        req = req.header("Authorization", format!("Bearer {key}"));
    }

    let resp = req
        .send()
        .await
        .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;

    if !resp.status().is_success() {
        let status = resp.status();
        let body = resp.text().await.unwrap_or_default();
        return Err(AgentError::Other(format!(
            "Embedding API returned {status}: {body}"
        )));
    }

    let body: EmbeddingResponse = resp
        .json()
        .await
        .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;

    // Callers pair the result with their inputs by position, so a short or
    // long response must be an error rather than a silent misalignment.
    if body.data.len() != expected {
        return Err(AgentError::Other(format!(
            "Embedding API returned {} embeddings for {expected} inputs",
            body.data.len()
        )));
    }

    // Sort by index to maintain input order
    let mut data = body.data;
    data.sort_by_key(|d| d.index);
    Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -7,6 +7,7 @@ mod database;
mod error;
mod llm;
mod pipeline;
mod rag;
mod scheduler;
#[allow(dead_code)]
mod trackers;

View File

@@ -185,7 +185,9 @@ impl PipelineOrchestrator {
// Stage 4.5: Graph Building
tracing::info!("[{repo_id}] Stage 4.5: Graph Building");
self.update_phase(scan_run_id, "graph_building").await;
let graph_context = match self.build_code_graph(&repo_path, &repo_id, &all_findings).await
let graph_context = match self
.build_code_graph(&repo_path, &repo_id, &all_findings)
.await
{
Ok(ctx) => Some(ctx),
Err(e) => {
@@ -296,9 +298,10 @@ impl PipelineOrchestrator {
let graph_build_id = uuid::Uuid::new_v4().to_string();
let engine = compliance_graph::GraphEngine::new(50_000);
let (mut code_graph, build_run) = engine
.build_graph(repo_path, repo_id, &graph_build_id)
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
let (mut code_graph, build_run) =
engine
.build_graph(repo_path, repo_id, &graph_build_id)
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
// Apply community detection
compliance_graph::graph::community::apply_communities(&mut code_graph);
@@ -348,15 +351,11 @@ impl PipelineOrchestrator {
use futures_util::TryStreamExt;
let filter = mongodb::bson::doc! { "repo_id": repo_id };
let targets: Vec<compliance_core::models::DastTarget> = match self
.db
.dast_targets()
.find(filter)
.await
{
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
Err(_) => return,
};
let targets: Vec<compliance_core::models::DastTarget> =
match self.db.dast_targets().find(filter).await {
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
Err(_) => return,
};
if targets.is_empty() {
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
@@ -379,10 +378,7 @@ impl PipelineOrchestrator {
tracing::error!("Failed to store DAST finding: {e}");
}
}
tracing::info!(
"DAST scan complete: {} findings",
findings.len()
);
tracing::info!("DAST scan complete: {} findings", findings.len());
}
Err(e) => {
tracing::error!("DAST scan failed: {e}");

View File

@@ -0,0 +1 @@
pub mod pipeline;

View File

@@ -0,0 +1,164 @@
use std::path::Path;
use std::sync::Arc;
use chrono::Utc;
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
use compliance_core::models::graph::CodeNode;
use compliance_graph::graph::chunking::extract_chunks;
use compliance_graph::graph::embedding_store::EmbeddingStore;
use tracing::{error, info};
use crate::error::AgentError;
use crate::llm::LlmClient;
/// RAG pipeline for building embeddings and performing retrieval
pub struct RagPipeline {
/// Client used to turn chunk texts into embedding vectors.
llm: Arc<LlmClient>,
/// Store that holds the embeddings and the embedding build records.
embedding_store: EmbeddingStore,
}
impl RagPipeline {
pub fn new(llm: Arc<LlmClient>, db: &mongodb::Database) -> Self {
Self {
llm,
embedding_store: EmbeddingStore::new(db),
}
}
pub fn store(&self) -> &EmbeddingStore {
&self.embedding_store
}
/// Build embeddings for all code nodes in a repository
pub async fn build_embeddings(
&self,
repo_id: &str,
repo_path: &Path,
graph_build_id: &str,
nodes: &[CodeNode],
) -> Result<EmbeddingBuildRun, AgentError> {
let embed_model = self.llm.embed_model().to_string();
let mut build =
EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model);
// Step 1: Extract chunks
let chunks = extract_chunks(repo_path, nodes, 2048);
build.total_chunks = chunks.len() as u32;
info!(
"[{repo_id}] Extracted {} chunks for embedding",
chunks.len()
);
// Store the initial build record
self.embedding_store
.store_build(&build)
.await
.map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?;
if chunks.is_empty() {
build.status = EmbeddingBuildStatus::Completed;
build.completed_at = Some(Utc::now());
self.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Completed,
0,
None,
)
.await
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
return Ok(build);
}
// Step 2: Delete old embeddings for this repo
self.embedding_store
.delete_repo_embeddings(repo_id)
.await
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
// Step 3: Batch embed (small batches to stay within model limits)
let batch_size = 20;
let mut all_embeddings = Vec::new();
let mut embedded_count = 0u32;
for batch_start in (0..chunks.len()).step_by(batch_size) {
let batch_end = (batch_start + batch_size).min(chunks.len());
let batch_chunks = &chunks[batch_start..batch_end];
// Prepare texts: context_header + content
let texts: Vec<String> = batch_chunks
.iter()
.map(|c| format!("{}\n{}", c.context_header, c.content))
.collect();
match self.llm.embed(texts).await {
Ok(vectors) => {
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
all_embeddings.push(CodeEmbedding {
id: None,
repo_id: repo_id.to_string(),
graph_build_id: graph_build_id.to_string(),
qualified_name: chunk.qualified_name.clone(),
kind: chunk.kind.clone(),
file_path: chunk.file_path.clone(),
start_line: chunk.start_line,
end_line: chunk.end_line,
language: chunk.language.clone(),
content: chunk.content.clone(),
context_header: chunk.context_header.clone(),
embedding,
token_estimate: chunk.token_estimate,
created_at: Utc::now(),
});
}
embedded_count += batch_chunks.len() as u32;
}
Err(e) => {
error!("[{repo_id}] Embedding batch failed: {e}");
build.status = EmbeddingBuildStatus::Failed;
build.error_message = Some(e.to_string());
build.completed_at = Some(Utc::now());
let _ = self
.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Failed,
embedded_count,
Some(e.to_string()),
)
.await;
return Ok(build);
}
}
}
// Step 4: Store all embeddings
self.embedding_store
.store_embeddings(&all_embeddings)
.await
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
// Step 5: Update build status
build.status = EmbeddingBuildStatus::Completed;
build.embedded_chunks = embedded_count;
build.completed_at = Some(Utc::now());
self.embedding_store
.update_build(
repo_id,
graph_build_id,
EmbeddingBuildStatus::Completed,
embedded_count,
None,
)
.await
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
info!(
"[{repo_id}] Embedding build complete: {embedded_count}/{} chunks",
build.total_chunks
);
Ok(build)
}
}