feat: rag-embedding-ai-chat (#1)
Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
@@ -20,6 +20,7 @@ impl ComplianceAgent {
|
||||
config.litellm_url.clone(),
|
||||
config.litellm_api_key.clone(),
|
||||
config.litellm_model.clone(),
|
||||
config.litellm_embed_model.clone(),
|
||||
));
|
||||
Self {
|
||||
config,
|
||||
|
||||
238
compliance-agent/src/api/handlers/chat.rs
Normal file
238
compliance-agent/src/api/handlers/chat.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
|
||||
use compliance_core::models::embedding::EmbeddingBuildRun;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::rag::pipeline::RagPipeline;
|
||||
|
||||
use super::ApiResponse;
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
/// POST /api/v1/chat/:repo_id — Send a chat message with RAG context
|
||||
pub async fn chat(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
|
||||
let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
|
||||
|
||||
// Step 1: Embed the user's message
|
||||
let query_vectors = agent
|
||||
.llm
|
||||
.embed(vec![req.message.clone()])
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to embed query: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let query_embedding = query_vectors.into_iter().next().ok_or_else(|| {
|
||||
tracing::error!("Empty embedding response");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Step 2: Vector search — retrieve top 8 chunks
|
||||
let search_results = pipeline
|
||||
.store()
|
||||
.vector_search(&repo_id, query_embedding, 8, 0.5)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Vector search failed: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Step 3: Build system prompt with code context
|
||||
let mut context_parts = Vec::new();
|
||||
let mut sources = Vec::new();
|
||||
|
||||
for (embedding, score) in &search_results {
|
||||
context_parts.push(format!(
|
||||
"--- {} ({}, {}:L{}-L{}) ---\n{}",
|
||||
embedding.qualified_name,
|
||||
embedding.kind,
|
||||
embedding.file_path,
|
||||
embedding.start_line,
|
||||
embedding.end_line,
|
||||
embedding.content,
|
||||
));
|
||||
|
||||
// Truncate snippet for the response
|
||||
let snippet: String = embedding
|
||||
.content
|
||||
.lines()
|
||||
.take(10)
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
sources.push(SourceReference {
|
||||
file_path: embedding.file_path.clone(),
|
||||
qualified_name: embedding.qualified_name.clone(),
|
||||
start_line: embedding.start_line,
|
||||
end_line: embedding.end_line,
|
||||
language: embedding.language.clone(),
|
||||
snippet,
|
||||
score: *score,
|
||||
});
|
||||
}
|
||||
|
||||
let code_context = if context_parts.is_empty() {
|
||||
"No relevant code context found.".to_string()
|
||||
} else {
|
||||
context_parts.join("\n\n")
|
||||
};
|
||||
|
||||
let system_prompt = format!(
|
||||
"You are an expert code assistant for a software repository. \
|
||||
Answer the user's question based on the code context below. \
|
||||
Reference specific files and functions when relevant. \
|
||||
If the context doesn't contain enough information, say so.\n\n\
|
||||
## Code Context\n\n{code_context}"
|
||||
);
|
||||
|
||||
// Step 4: Build messages array with history
|
||||
let mut messages: Vec<(String, String)> = Vec::new();
|
||||
messages.push(("system".to_string(), system_prompt));
|
||||
|
||||
for msg in &req.history {
|
||||
messages.push((msg.role.clone(), msg.content.clone()));
|
||||
}
|
||||
messages.push(("user".to_string(), req.message));
|
||||
|
||||
// Step 5: Call LLM
|
||||
let response_text = agent
|
||||
.llm
|
||||
.chat_with_messages(messages, Some(0.3))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("LLM chat failed: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: ChatResponse {
|
||||
message: response_text,
|
||||
sources,
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/v1/chat/:repo_id/build-embeddings — Trigger embedding build
|
||||
pub async fn build_embeddings(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let agent_clone = (*agent).clone();
|
||||
tokio::spawn(async move {
|
||||
let repo = match agent_clone
|
||||
.db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
|
||||
.await
|
||||
{
|
||||
Ok(Some(r)) => r,
|
||||
_ => {
|
||||
tracing::error!("Repository {repo_id} not found for embedding build");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Get latest graph build
|
||||
let build = match agent_clone
|
||||
.db
|
||||
.graph_builds()
|
||||
.find_one(doc! { "repo_id": &repo_id })
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.await
|
||||
{
|
||||
Ok(Some(b)) => b,
|
||||
_ => {
|
||||
tracing::error!("[{repo_id}] No graph build found — build graph first");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let graph_build_id = build
|
||||
.id
|
||||
.map(|id| id.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Get nodes
|
||||
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! { "repo_id": &repo_id })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => {
|
||||
use futures_util::StreamExt;
|
||||
let mut items = Vec::new();
|
||||
let mut cursor = cursor;
|
||||
while let Some(Ok(item)) = cursor.next().await {
|
||||
items.push(item);
|
||||
}
|
||||
items
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let git_ops = crate::pipeline::git::GitOps::new(&agent_clone.config.git_clone_base_path);
|
||||
let repo_path = match git_ops.clone_or_fetch(&repo.git_url, &repo.name) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to clone repo for embedding build: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
|
||||
match pipeline
|
||||
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
|
||||
.await
|
||||
{
|
||||
Ok(run) => {
|
||||
tracing::info!(
|
||||
"[{repo_id}] Embedding build complete: {}/{} chunks",
|
||||
run.embedded_chunks,
|
||||
run.total_chunks
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Embedding build failed: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(
|
||||
serde_json::json!({ "status": "embedding_build_triggered" }),
|
||||
))
|
||||
}
|
||||
|
||||
/// GET /api/v1/chat/:repo_id/status — Get latest embedding build status
|
||||
pub async fn embedding_status(
|
||||
Extension(agent): AgentExt,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
|
||||
let store = EmbeddingStore::new(agent.db.inner());
|
||||
let build = store.get_latest_build(&repo_id).await.map_err(|e| {
|
||||
tracing::error!("Failed to get embedding status: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: build,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
@@ -103,8 +103,7 @@ pub async fn trigger_scan(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
@@ -207,8 +206,7 @@ pub async fn get_finding(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let finding = agent
|
||||
.db
|
||||
|
||||
@@ -52,7 +52,7 @@ pub async fn get_graph(
|
||||
// so there is only one set of nodes/edges per repo.
|
||||
let filter = doc! { "repo_id": &repo_id };
|
||||
|
||||
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter.clone()).await {
|
||||
let all_nodes: Vec<CodeNode> = match db.graph_nodes().find(filter.clone()).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
@@ -60,6 +60,17 @@ pub async fn get_graph(
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Remove disconnected nodes (no edges) to keep the graph clean
|
||||
let connected: std::collections::HashSet<&str> = edges
|
||||
.iter()
|
||||
.flat_map(|e| [e.source.as_str(), e.target.as_str()])
|
||||
.collect();
|
||||
let nodes = all_nodes
|
||||
.into_iter()
|
||||
.filter(|n| connected.contains(n.qualified_name.as_str()))
|
||||
.collect();
|
||||
|
||||
(nodes, edges)
|
||||
} else {
|
||||
(Vec::new(), Vec::new())
|
||||
@@ -235,12 +246,7 @@ pub async fn get_file_content(
|
||||
// Cap at 10,000 lines
|
||||
let truncated: String = content.lines().take(10_000).collect::<Vec<_>>().join("\n");
|
||||
|
||||
let language = params
|
||||
.path
|
||||
.rsplit('.')
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let language = params.path.rsplit('.').next().unwrap_or("").to_string();
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: FileContent {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod chat;
|
||||
pub mod dast;
|
||||
pub mod graph;
|
||||
|
||||
@@ -5,7 +6,8 @@ use std::sync::Arc;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use axum::extract::{Extension, Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -89,6 +91,72 @@ pub struct UpdateStatusRequest {
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomFilter {
|
||||
#[serde(default)]
|
||||
pub repo_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub package_manager: Option<String>,
|
||||
#[serde(default)]
|
||||
pub q: Option<String>,
|
||||
#[serde(default)]
|
||||
pub has_vulns: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub license: Option<String>,
|
||||
#[serde(default = "default_page")]
|
||||
pub page: u64,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomExportParams {
|
||||
pub repo_id: String,
|
||||
#[serde(default = "default_export_format")]
|
||||
pub format: String,
|
||||
}
|
||||
|
||||
fn default_export_format() -> String {
|
||||
"cyclonedx".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomDiffParams {
|
||||
pub repo_a: String,
|
||||
pub repo_b: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LicenseSummary {
|
||||
pub license: String,
|
||||
pub count: u64,
|
||||
pub is_copyleft: bool,
|
||||
pub packages: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomDiffResult {
|
||||
pub only_in_a: Vec<SbomDiffEntry>,
|
||||
pub only_in_b: Vec<SbomDiffEntry>,
|
||||
pub version_changed: Vec<SbomVersionDiff>,
|
||||
pub common_count: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomDiffEntry {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub package_manager: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomVersionDiff {
|
||||
pub name: String,
|
||||
pub package_manager: String,
|
||||
pub version_a: String,
|
||||
pub version_b: String,
|
||||
}
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
type ApiResult<T> = Result<Json<ApiResponse<T>>, StatusCode>;
|
||||
|
||||
@@ -235,6 +303,52 @@ pub async fn trigger_scan(
|
||||
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
|
||||
}
|
||||
|
||||
pub async fn delete_repository(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = &agent.db;
|
||||
|
||||
// Delete the repository
|
||||
let result = db
|
||||
.repositories()
|
||||
.delete_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if result.deleted_count == 0 {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
// Cascade delete all related data
|
||||
let _ = db.findings().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db
|
||||
.tracker_issues()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db
|
||||
.impact_analyses()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db
|
||||
.code_embeddings()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db
|
||||
.embedding_builds()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "deleted" })))
|
||||
}
|
||||
|
||||
pub async fn list_findings(
|
||||
Extension(agent): AgentExt,
|
||||
Query(filter): Query<FindingsFilter>,
|
||||
@@ -322,21 +436,46 @@ pub async fn update_finding_status(
|
||||
|
||||
pub async fn list_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<PaginationParams>,
|
||||
Query(filter): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<SbomEntry>> {
|
||||
let db = &agent.db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let mut query = doc! {};
|
||||
|
||||
if let Some(repo_id) = &filter.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
}
|
||||
if let Some(pm) = &filter.package_manager {
|
||||
query.insert("package_manager", pm);
|
||||
}
|
||||
if let Some(q) = &filter.q {
|
||||
if !q.is_empty() {
|
||||
query.insert("name", doc! { "$regex": q, "$options": "i" });
|
||||
}
|
||||
}
|
||||
if let Some(has_vulns) = filter.has_vulns {
|
||||
if has_vulns {
|
||||
query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
|
||||
} else {
|
||||
query.insert("known_vulnerabilities", doc! { "$size": 0 });
|
||||
}
|
||||
}
|
||||
if let Some(license) = &filter.license {
|
||||
query.insert("license", license);
|
||||
}
|
||||
|
||||
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
|
||||
let total = db
|
||||
.sbom_entries()
|
||||
.count_documents(doc! {})
|
||||
.count_documents(query.clone())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let entries = match db
|
||||
.sbom_entries()
|
||||
.find(doc! {})
|
||||
.find(query)
|
||||
.sort(doc! { "name": 1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.limit(filter.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
@@ -346,7 +485,272 @@ pub async fn list_sbom(
|
||||
Ok(Json(ApiResponse {
|
||||
data: entries,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
page: Some(filter.page),
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn export_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomExportParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let entries: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_id })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let body = if params.format == "spdx" {
|
||||
// SPDX 2.3 format
|
||||
let packages: Vec<serde_json::Value> = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| {
|
||||
serde_json::json!({
|
||||
"SPDXID": format!("SPDXRef-Package-{i}"),
|
||||
"name": e.name,
|
||||
"versionInfo": e.version,
|
||||
"downloadLocation": "NOASSERTION",
|
||||
"licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
|
||||
"externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
|
||||
"referenceCategory": "PACKAGE-MANAGER",
|
||||
"referenceType": "purl",
|
||||
"referenceLocator": p,
|
||||
})]).unwrap_or_default(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"spdxVersion": "SPDX-2.3",
|
||||
"dataLicense": "CC0-1.0",
|
||||
"SPDXID": "SPDXRef-DOCUMENT",
|
||||
"name": format!("sbom-{}", params.repo_id),
|
||||
"documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
|
||||
"packages": packages,
|
||||
})
|
||||
} else {
|
||||
// CycloneDX 1.5 format
|
||||
let components: Vec<serde_json::Value> = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let mut comp = serde_json::json!({
|
||||
"type": "library",
|
||||
"name": e.name,
|
||||
"version": e.version,
|
||||
"group": e.package_manager,
|
||||
});
|
||||
if let Some(purl) = &e.purl {
|
||||
comp["purl"] = serde_json::Value::String(purl.clone());
|
||||
}
|
||||
if let Some(license) = &e.license {
|
||||
comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
|
||||
}
|
||||
if !e.known_vulnerabilities.is_empty() {
|
||||
comp["vulnerabilities"] = serde_json::json!(
|
||||
e.known_vulnerabilities.iter().map(|v| serde_json::json!({
|
||||
"id": v.id,
|
||||
"source": { "name": v.source },
|
||||
"ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
|
||||
})).collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
comp
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"bomFormat": "CycloneDX",
|
||||
"specVersion": "1.5",
|
||||
"version": 1,
|
||||
"metadata": {
|
||||
"component": {
|
||||
"type": "application",
|
||||
"name": format!("repo-{}", params.repo_id),
|
||||
}
|
||||
},
|
||||
"components": components,
|
||||
})
|
||||
};
|
||||
|
||||
let json_str =
|
||||
serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let filename = if params.format == "spdx" {
|
||||
format!("sbom-{}-spdx.json", params.repo_id)
|
||||
} else {
|
||||
format!("sbom-{}-cyclonedx.json", params.repo_id)
|
||||
};
|
||||
|
||||
let disposition = format!("attachment; filename=\"{filename}\"");
|
||||
Ok((
|
||||
[
|
||||
(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
header::HeaderValue::from_str(&disposition)
|
||||
.unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
|
||||
),
|
||||
],
|
||||
json_str,
|
||||
))
|
||||
}
|
||||
|
||||
const COPYLEFT_LICENSES: &[&str] = &[
|
||||
"GPL-2.0",
|
||||
"GPL-2.0-only",
|
||||
"GPL-2.0-or-later",
|
||||
"GPL-3.0",
|
||||
"GPL-3.0-only",
|
||||
"GPL-3.0-or-later",
|
||||
"AGPL-3.0",
|
||||
"AGPL-3.0-only",
|
||||
"AGPL-3.0-or-later",
|
||||
"LGPL-2.1",
|
||||
"LGPL-2.1-only",
|
||||
"LGPL-2.1-or-later",
|
||||
"LGPL-3.0",
|
||||
"LGPL-3.0-only",
|
||||
"LGPL-3.0-or-later",
|
||||
"MPL-2.0",
|
||||
];
|
||||
|
||||
pub async fn license_summary(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<LicenseSummary>> {
|
||||
let db = &agent.db;
|
||||
let mut query = doc! {};
|
||||
if let Some(repo_id) = ¶ms.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
}
|
||||
|
||||
let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let mut license_map: std::collections::HashMap<String, Vec<String>> =
|
||||
std::collections::HashMap::new();
|
||||
for entry in &entries {
|
||||
let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
|
||||
license_map.entry(lic).or_default().push(entry.name.clone());
|
||||
}
|
||||
|
||||
let mut summaries: Vec<LicenseSummary> = license_map
|
||||
.into_iter()
|
||||
.map(|(license, packages)| {
|
||||
let is_copyleft = COPYLEFT_LICENSES
|
||||
.iter()
|
||||
.any(|c| license.to_uppercase().contains(&c.to_uppercase()));
|
||||
LicenseSummary {
|
||||
license,
|
||||
count: packages.len() as u64,
|
||||
is_copyleft,
|
||||
packages,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
summaries.sort_by(|a, b| b.count.cmp(&a.count));
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: summaries,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn sbom_diff(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomDiffParams>,
|
||||
) -> ApiResult<SbomDiffResult> {
|
||||
let db = &agent.db;
|
||||
|
||||
let entries_a: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_a })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let entries_b: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_b })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Build maps by (name, package_manager) -> version
|
||||
let map_a: std::collections::HashMap<(String, String), String> = entries_a
|
||||
.iter()
|
||||
.map(|e| {
|
||||
(
|
||||
(e.name.clone(), e.package_manager.clone()),
|
||||
e.version.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let map_b: std::collections::HashMap<(String, String), String> = entries_b
|
||||
.iter()
|
||||
.map(|e| {
|
||||
(
|
||||
(e.name.clone(), e.package_manager.clone()),
|
||||
e.version.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut only_in_a = Vec::new();
|
||||
let mut version_changed = Vec::new();
|
||||
let mut common_count: u64 = 0;
|
||||
|
||||
for (key, ver_a) in &map_a {
|
||||
match map_b.get(key) {
|
||||
None => only_in_a.push(SbomDiffEntry {
|
||||
name: key.0.clone(),
|
||||
version: ver_a.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
}),
|
||||
Some(ver_b) if ver_a != ver_b => {
|
||||
version_changed.push(SbomVersionDiff {
|
||||
name: key.0.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
version_a: ver_a.clone(),
|
||||
version_b: ver_b.clone(),
|
||||
});
|
||||
}
|
||||
Some(_) => common_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let only_in_b: Vec<SbomDiffEntry> = map_b
|
||||
.iter()
|
||||
.filter(|(key, _)| !map_a.contains_key(key))
|
||||
.map(|(key, ver)| SbomDiffEntry {
|
||||
name: key.0.clone(),
|
||||
version: ver.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: SbomDiffResult {
|
||||
only_in_a,
|
||||
only_in_b,
|
||||
version_changed,
|
||||
common_count,
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::routing::{get, patch, post};
|
||||
use axum::routing::{delete, get, patch, post};
|
||||
use axum::Router;
|
||||
|
||||
use crate::api::handlers;
|
||||
@@ -13,6 +13,10 @@ pub fn build_router() -> Router {
|
||||
"/api/v1/repositories/{id}/scan",
|
||||
post(handlers::trigger_scan),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/repositories/{id}",
|
||||
delete(handlers::delete_repository),
|
||||
)
|
||||
.route("/api/v1/findings", get(handlers::list_findings))
|
||||
.route("/api/v1/findings/{id}", get(handlers::get_finding))
|
||||
.route(
|
||||
@@ -20,13 +24,13 @@ pub fn build_router() -> Router {
|
||||
patch(handlers::update_finding_status),
|
||||
)
|
||||
.route("/api/v1/sbom", get(handlers::list_sbom))
|
||||
.route("/api/v1/sbom/export", get(handlers::export_sbom))
|
||||
.route("/api/v1/sbom/licenses", get(handlers::license_summary))
|
||||
.route("/api/v1/sbom/diff", get(handlers::sbom_diff))
|
||||
.route("/api/v1/issues", get(handlers::list_issues))
|
||||
.route("/api/v1/scan-runs", get(handlers::list_scan_runs))
|
||||
// Graph API endpoints
|
||||
.route(
|
||||
"/api/v1/graph/{repo_id}",
|
||||
get(handlers::graph::get_graph),
|
||||
)
|
||||
.route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph))
|
||||
.route(
|
||||
"/api/v1/graph/{repo_id}/nodes",
|
||||
get(handlers::graph::get_nodes),
|
||||
@@ -52,14 +56,8 @@ pub fn build_router() -> Router {
|
||||
post(handlers::graph::trigger_build),
|
||||
)
|
||||
// DAST API endpoints
|
||||
.route(
|
||||
"/api/v1/dast/targets",
|
||||
get(handlers::dast::list_targets),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/dast/targets",
|
||||
post(handlers::dast::add_target),
|
||||
)
|
||||
.route("/api/v1/dast/targets", get(handlers::dast::list_targets))
|
||||
.route("/api/v1/dast/targets", post(handlers::dast::add_target))
|
||||
.route(
|
||||
"/api/v1/dast/targets/{id}/scan",
|
||||
post(handlers::dast::trigger_scan),
|
||||
@@ -68,12 +66,19 @@ pub fn build_router() -> Router {
|
||||
"/api/v1/dast/scan-runs",
|
||||
get(handlers::dast::list_scan_runs),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/dast/findings",
|
||||
get(handlers::dast::list_findings),
|
||||
)
|
||||
.route("/api/v1/dast/findings", get(handlers::dast::list_findings))
|
||||
.route(
|
||||
"/api/v1/dast/findings/{id}",
|
||||
get(handlers::dast::get_finding),
|
||||
)
|
||||
// Chat / RAG API endpoints
|
||||
.route("/api/v1/chat/{repo_id}", post(handlers::chat::chat))
|
||||
.route(
|
||||
"/api/v1/chat/{repo_id}/build-embeddings",
|
||||
post(handlers::chat::build_embeddings),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/chat/{repo_id}/status",
|
||||
get(handlers::chat::embedding_status),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ pub fn load_config() -> Result<AgentConfig, AgentError> {
|
||||
.unwrap_or_else(|| "http://localhost:4000".to_string()),
|
||||
litellm_api_key: SecretString::from(env_var_opt("LITELLM_API_KEY").unwrap_or_default()),
|
||||
litellm_model: env_var_opt("LITELLM_MODEL").unwrap_or_else(|| "gpt-4o".to_string()),
|
||||
litellm_embed_model: env_var_opt("LITELLM_EMBED_MODEL")
|
||||
.unwrap_or_else(|| "text-embedding-3-small".to_string()),
|
||||
github_token: env_secret_opt("GITHUB_TOKEN"),
|
||||
github_webhook_secret: env_secret_opt("GITHUB_WEBHOOK_SECRET"),
|
||||
gitlab_url: env_var_opt("GITLAB_URL"),
|
||||
|
||||
@@ -127,11 +127,7 @@ impl Database {
|
||||
|
||||
// dast_targets: index on repo_id
|
||||
self.dast_targets()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.create_index(IndexModel::builder().keys(doc! { "repo_id": 1 }).build())
|
||||
.await?;
|
||||
|
||||
// dast_scan_runs: compound (target_id, started_at DESC)
|
||||
@@ -152,6 +148,24 @@ impl Database {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// code_embeddings: compound (repo_id, graph_build_id)
|
||||
self.code_embeddings()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "graph_build_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// embedding_builds: compound (repo_id, started_at DESC)
|
||||
self.embedding_builds()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "repo_id": 1, "started_at": -1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Database indexes ensured");
|
||||
Ok(())
|
||||
}
|
||||
@@ -210,6 +224,17 @@ impl Database {
|
||||
self.inner.collection("dast_findings")
|
||||
}
|
||||
|
||||
// Embedding collections
|
||||
pub fn code_embeddings(&self) -> Collection<compliance_core::models::embedding::CodeEmbedding> {
|
||||
self.inner.collection("code_embeddings")
|
||||
}
|
||||
|
||||
pub fn embedding_builds(
|
||||
&self,
|
||||
) -> Collection<compliance_core::models::embedding::EmbeddingBuildRun> {
|
||||
self.inner.collection("embedding_builds")
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
|
||||
self.inner.collection(name)
|
||||
|
||||
@@ -8,6 +8,7 @@ pub struct LlmClient {
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
@@ -42,16 +43,46 @@ struct ChatResponseMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
/// Request body for the embeddings API
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
/// Response from the embeddings API
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
/// A single embedding result
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(base_url: String, api_key: SecretString, model: String) -> Self {
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url,
|
||||
api_key,
|
||||
model,
|
||||
embed_model,
|
||||
http: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
pub async fn chat(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
@@ -169,4 +200,49 @@ impl LlmClient {
|
||||
.map(|c| c.message.content.clone())
|
||||
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
let key = self.api_key.expose_secret();
|
||||
if !key.is_empty() {
|
||||
req = req.header("Authorization", format!("Bearer {key}"));
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
// Sort by index to maintain input order
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ mod database;
|
||||
mod error;
|
||||
mod llm;
|
||||
mod pipeline;
|
||||
mod rag;
|
||||
mod scheduler;
|
||||
#[allow(dead_code)]
|
||||
mod trackers;
|
||||
|
||||
@@ -185,7 +185,9 @@ impl PipelineOrchestrator {
|
||||
// Stage 4.5: Graph Building
|
||||
tracing::info!("[{repo_id}] Stage 4.5: Graph Building");
|
||||
self.update_phase(scan_run_id, "graph_building").await;
|
||||
let graph_context = match self.build_code_graph(&repo_path, &repo_id, &all_findings).await
|
||||
let graph_context = match self
|
||||
.build_code_graph(&repo_path, &repo_id, &all_findings)
|
||||
.await
|
||||
{
|
||||
Ok(ctx) => Some(ctx),
|
||||
Err(e) => {
|
||||
@@ -296,9 +298,10 @@ impl PipelineOrchestrator {
|
||||
let graph_build_id = uuid::Uuid::new_v4().to_string();
|
||||
let engine = compliance_graph::GraphEngine::new(50_000);
|
||||
|
||||
let (mut code_graph, build_run) = engine
|
||||
.build_graph(repo_path, repo_id, &graph_build_id)
|
||||
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
|
||||
let (mut code_graph, build_run) =
|
||||
engine
|
||||
.build_graph(repo_path, repo_id, &graph_build_id)
|
||||
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
|
||||
|
||||
// Apply community detection
|
||||
compliance_graph::graph::community::apply_communities(&mut code_graph);
|
||||
@@ -348,15 +351,11 @@ impl PipelineOrchestrator {
|
||||
use futures_util::TryStreamExt;
|
||||
|
||||
let filter = mongodb::bson::doc! { "repo_id": repo_id };
|
||||
let targets: Vec<compliance_core::models::DastTarget> = match self
|
||||
.db
|
||||
.dast_targets()
|
||||
.find(filter)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
|
||||
Err(_) => return,
|
||||
};
|
||||
let targets: Vec<compliance_core::models::DastTarget> =
|
||||
match self.db.dast_targets().find(filter).await {
|
||||
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if targets.is_empty() {
|
||||
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
|
||||
@@ -379,10 +378,7 @@ impl PipelineOrchestrator {
|
||||
tracing::error!("Failed to store DAST finding: {e}");
|
||||
}
|
||||
}
|
||||
tracing::info!(
|
||||
"DAST scan complete: {} findings",
|
||||
findings.len()
|
||||
);
|
||||
tracing::info!("DAST scan complete: {} findings", findings.len());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("DAST scan failed: {e}");
|
||||
|
||||
1
compliance-agent/src/rag/mod.rs
Normal file
1
compliance-agent/src/rag/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod pipeline;
|
||||
164
compliance-agent/src/rag/pipeline.rs
Normal file
164
compliance-agent/src/rag/pipeline.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use compliance_core::models::embedding::{CodeEmbedding, EmbeddingBuildRun, EmbeddingBuildStatus};
|
||||
use compliance_core::models::graph::CodeNode;
|
||||
use compliance_graph::graph::chunking::extract_chunks;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::error::AgentError;
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// RAG pipeline for building embeddings and performing retrieval
|
||||
pub struct RagPipeline {
|
||||
llm: Arc<LlmClient>,
|
||||
embedding_store: EmbeddingStore,
|
||||
}
|
||||
|
||||
impl RagPipeline {
|
||||
pub fn new(llm: Arc<LlmClient>, db: &mongodb::Database) -> Self {
|
||||
Self {
|
||||
llm,
|
||||
embedding_store: EmbeddingStore::new(db),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn store(&self) -> &EmbeddingStore {
|
||||
&self.embedding_store
|
||||
}
|
||||
|
||||
/// Build embeddings for all code nodes in a repository
|
||||
pub async fn build_embeddings(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
repo_path: &Path,
|
||||
graph_build_id: &str,
|
||||
nodes: &[CodeNode],
|
||||
) -> Result<EmbeddingBuildRun, AgentError> {
|
||||
let embed_model = self.llm.embed_model().to_string();
|
||||
let mut build =
|
||||
EmbeddingBuildRun::new(repo_id.to_string(), graph_build_id.to_string(), embed_model);
|
||||
|
||||
// Step 1: Extract chunks
|
||||
let chunks = extract_chunks(repo_path, nodes, 2048);
|
||||
build.total_chunks = chunks.len() as u32;
|
||||
info!(
|
||||
"[{repo_id}] Extracted {} chunks for embedding",
|
||||
chunks.len()
|
||||
);
|
||||
|
||||
// Store the initial build record
|
||||
self.embedding_store
|
||||
.store_build(&build)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to store build: {e}")))?;
|
||||
|
||||
if chunks.is_empty() {
|
||||
build.status = EmbeddingBuildStatus::Completed;
|
||||
build.completed_at = Some(Utc::now());
|
||||
self.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Completed,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
||||
return Ok(build);
|
||||
}
|
||||
|
||||
// Step 2: Delete old embeddings for this repo
|
||||
self.embedding_store
|
||||
.delete_repo_embeddings(repo_id)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to delete old embeddings: {e}")))?;
|
||||
|
||||
// Step 3: Batch embed (small batches to stay within model limits)
|
||||
let batch_size = 20;
|
||||
let mut all_embeddings = Vec::new();
|
||||
let mut embedded_count = 0u32;
|
||||
|
||||
for batch_start in (0..chunks.len()).step_by(batch_size) {
|
||||
let batch_end = (batch_start + batch_size).min(chunks.len());
|
||||
let batch_chunks = &chunks[batch_start..batch_end];
|
||||
|
||||
// Prepare texts: context_header + content
|
||||
let texts: Vec<String> = batch_chunks
|
||||
.iter()
|
||||
.map(|c| format!("{}\n{}", c.context_header, c.content))
|
||||
.collect();
|
||||
|
||||
match self.llm.embed(texts).await {
|
||||
Ok(vectors) => {
|
||||
for (chunk, embedding) in batch_chunks.iter().zip(vectors) {
|
||||
all_embeddings.push(CodeEmbedding {
|
||||
id: None,
|
||||
repo_id: repo_id.to_string(),
|
||||
graph_build_id: graph_build_id.to_string(),
|
||||
qualified_name: chunk.qualified_name.clone(),
|
||||
kind: chunk.kind.clone(),
|
||||
file_path: chunk.file_path.clone(),
|
||||
start_line: chunk.start_line,
|
||||
end_line: chunk.end_line,
|
||||
language: chunk.language.clone(),
|
||||
content: chunk.content.clone(),
|
||||
context_header: chunk.context_header.clone(),
|
||||
embedding,
|
||||
token_estimate: chunk.token_estimate,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
}
|
||||
embedded_count += batch_chunks.len() as u32;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("[{repo_id}] Embedding batch failed: {e}");
|
||||
build.status = EmbeddingBuildStatus::Failed;
|
||||
build.error_message = Some(e.to_string());
|
||||
build.completed_at = Some(Utc::now());
|
||||
let _ = self
|
||||
.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Failed,
|
||||
embedded_count,
|
||||
Some(e.to_string()),
|
||||
)
|
||||
.await;
|
||||
return Ok(build);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Store all embeddings
|
||||
self.embedding_store
|
||||
.store_embeddings(&all_embeddings)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to store embeddings: {e}")))?;
|
||||
|
||||
// Step 5: Update build status
|
||||
build.status = EmbeddingBuildStatus::Completed;
|
||||
build.embedded_chunks = embedded_count;
|
||||
build.completed_at = Some(Utc::now());
|
||||
self.embedding_store
|
||||
.update_build(
|
||||
repo_id,
|
||||
graph_build_id,
|
||||
EmbeddingBuildStatus::Completed,
|
||||
embedded_count,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to update build: {e}")))?;
|
||||
|
||||
info!(
|
||||
"[{repo_id}] Embedding build complete: {embedded_count}/{} chunks",
|
||||
build.total_chunks
|
||||
);
|
||||
Ok(build)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user