Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cdfbb62f9d |
@@ -7,11 +7,13 @@ use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
|
||||
use compliance_core::models::embedding::EmbeddingBuildRun;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
use compliance_graph::graph::embedding_store::EmbeddingStore;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::rag::pipeline::RagPipeline;
|
||||
|
||||
use super::dto::tenant_db;
|
||||
use super::ApiResponse;
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
@@ -20,10 +22,12 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn chat(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
|
||||
let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let pipeline = RagPipeline::new(agent.llm.clone(), db.inner());
|
||||
|
||||
// Step 1: Embed the user's message
|
||||
let query_vectors = agent
|
||||
@@ -133,12 +137,15 @@ pub async fn chat(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn build_embeddings(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
// Resolve the tenant DB up front so we can move it into the spawn;
|
||||
// the JWT/dev context isn't available inside detached tasks.
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let agent_clone = (*agent).clone();
|
||||
tokio::spawn(async move {
|
||||
let repo = match agent_clone
|
||||
.db
|
||||
let repo = match db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
|
||||
.await
|
||||
@@ -151,8 +158,7 @@ pub async fn build_embeddings(
|
||||
};
|
||||
|
||||
// Get latest graph build
|
||||
let build = match agent_clone
|
||||
.db
|
||||
let build = match db
|
||||
.graph_builds()
|
||||
.find_one(doc! { "repo_id": &repo_id })
|
||||
.sort(doc! { "started_at": -1 })
|
||||
@@ -171,26 +177,22 @@ pub async fn build_embeddings(
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Get nodes
|
||||
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! { "repo_id": &repo_id })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => {
|
||||
use futures_util::StreamExt;
|
||||
let mut items = Vec::new();
|
||||
let mut cursor = cursor;
|
||||
while let Some(Ok(item)) = cursor.next().await {
|
||||
items.push(item);
|
||||
let nodes: Vec<compliance_core::models::graph::CodeNode> =
|
||||
match db.graph_nodes().find(doc! { "repo_id": &repo_id }).await {
|
||||
Ok(cursor) => {
|
||||
use futures_util::StreamExt;
|
||||
let mut items = Vec::new();
|
||||
let mut cursor = cursor;
|
||||
while let Some(Ok(item)) = cursor.next().await {
|
||||
items.push(item);
|
||||
}
|
||||
items
|
||||
}
|
||||
items
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
Err(e) => {
|
||||
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let creds = crate::pipeline::git::RepoCredentials {
|
||||
ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()),
|
||||
@@ -207,7 +209,7 @@ pub async fn build_embeddings(
|
||||
}
|
||||
};
|
||||
|
||||
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
|
||||
let pipeline = RagPipeline::new(agent_clone.llm.clone(), db.inner());
|
||||
match pipeline
|
||||
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
|
||||
.await
|
||||
@@ -234,9 +236,11 @@ pub async fn build_embeddings(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn embedding_status(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
|
||||
let store = EmbeddingStore::new(agent.db.inner());
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let store = EmbeddingStore::new(db.inner());
|
||||
let build = store.get_latest_build(&repo_id).await.map_err(|e| {
|
||||
tracing::error!("Failed to get embedding status: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
|
||||
@@ -7,9 +7,11 @@ use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType};
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::dto::tenant_db;
|
||||
use super::{collect_cursor_async, ApiResponse, PaginationParams};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
@@ -45,9 +47,11 @@ fn default_rate_limit() -> u32 {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_targets(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<DastTarget>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.dast_targets()
|
||||
@@ -80,6 +84,7 @@ pub async fn list_targets(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn add_target(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Json(req): Json<AddTargetRequest>,
|
||||
) -> Result<Json<ApiResponse<DastTarget>>, StatusCode> {
|
||||
let mut target = DastTarget::new(req.name, req.base_url, req.target_type);
|
||||
@@ -89,9 +94,8 @@ pub async fn add_target(
|
||||
target.rate_limit = req.rate_limit;
|
||||
target.allow_destructive = req.allow_destructive;
|
||||
|
||||
agent
|
||||
.db
|
||||
.dast_targets()
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
db.dast_targets()
|
||||
.insert_one(&target)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
@@ -107,19 +111,19 @@ pub async fn add_target(
|
||||
#[tracing::instrument(skip_all, fields(target_id = %id))]
|
||||
pub async fn trigger_scan(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
let target = db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
let db = agent.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = compliance_dast::DastOrchestrator::new(100);
|
||||
match orchestrator.run_scan(&target, Vec::new()).await {
|
||||
@@ -147,9 +151,11 @@ pub async fn trigger_scan(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_scan_runs(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<DastScanRun>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.dast_scan_runs()
|
||||
@@ -183,9 +189,11 @@ pub async fn list_scan_runs(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_findings(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.dast_findings()
|
||||
@@ -219,12 +227,13 @@ pub async fn list_findings(
|
||||
#[tracing::instrument(skip_all, fields(finding_id = %id))]
|
||||
pub async fn get_finding(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let finding = agent
|
||||
.db
|
||||
let finding = db
|
||||
.dast_findings()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
|
||||
@@ -180,6 +180,27 @@ pub struct SbomVersionDiff {
|
||||
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
|
||||
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
|
||||
|
||||
/// Resolve a tenant-scoped [`Database`] from the request's
|
||||
/// [`TenantContext`] (inserted by the M7.1 JWT middleware, or by the
|
||||
/// dev fallback in unsecured environments). The pool ensures the
|
||||
/// tenant's indexes idempotently.
|
||||
///
|
||||
/// Returns 500 on the rare path where Mongo refuses the database
|
||||
/// handle — the M7.1 auth/status middleware already rejects every
|
||||
/// other failure mode with 4xx before we get here.
|
||||
pub(crate) async fn tenant_db(
|
||||
agent: &crate::agent::ComplianceAgent,
|
||||
tenant: &compliance_core::tenant_ctx::TenantCtx,
|
||||
) -> Result<crate::database::Database, axum::http::StatusCode> {
|
||||
agent.db_pool.for_tenant(&tenant.0).await.map_err(|e| {
|
||||
tracing::error!(
|
||||
tenant_id = %tenant.0.tenant_id,
|
||||
"Failed to acquire tenant database: {e}"
|
||||
);
|
||||
axum::http::StatusCode::INTERNAL_SERVER_ERROR
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
|
||||
mut cursor: mongodb::Cursor<T>,
|
||||
) -> Vec<T> {
|
||||
|
||||
@@ -5,13 +5,16 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::Finding;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
|
||||
pub async fn list_findings(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(filter): Query<FindingsFilter>,
|
||||
) -> ApiResult<Vec<Finding>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let mut query = doc! {};
|
||||
if let Some(repo_id) = &filter.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
@@ -81,11 +84,12 @@ pub async fn list_findings(
|
||||
#[tracing::instrument(skip_all, fields(finding_id = %id))]
|
||||
pub async fn get_finding(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let finding = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let finding = db
|
||||
.findings()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -102,14 +106,14 @@ pub async fn get_finding(
|
||||
#[tracing::instrument(skip_all, fields(finding_id = %id))]
|
||||
pub async fn update_finding_status(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateStatusRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
agent
|
||||
.db
|
||||
.findings()
|
||||
db.findings()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
|
||||
@@ -123,6 +127,7 @@ pub async fn update_finding_status(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn bulk_update_finding_status(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Json(req): Json<BulkUpdateStatusRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oids: Vec<mongodb::bson::oid::ObjectId> = req
|
||||
@@ -135,8 +140,8 @@ pub async fn bulk_update_finding_status(
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let result = db
|
||||
.findings()
|
||||
.update_many(
|
||||
doc! { "_id": { "$in": oids } },
|
||||
@@ -153,14 +158,14 @@ pub async fn bulk_update_finding_status(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn update_finding_feedback(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateFeedbackRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
agent
|
||||
.db
|
||||
.findings()
|
||||
db.findings()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
|
||||
|
||||
@@ -7,9 +7,11 @@ use mongodb::bson::doc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis};
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::dto::tenant_db;
|
||||
use super::{collect_cursor_async, ApiResponse};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
@@ -36,9 +38,11 @@ fn default_search_limit() -> usize {
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn get_graph(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<GraphData>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
// Get latest build
|
||||
let build: Option<GraphBuildRun> = db
|
||||
@@ -98,9 +102,11 @@ pub async fn get_graph(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn get_nodes(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let filter = doc! { "repo_id": &repo_id };
|
||||
|
||||
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
|
||||
@@ -123,9 +129,11 @@ pub async fn get_nodes(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn get_communities(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Vec<CommunityInfo>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let filter = doc! { "repo_id": &repo_id };
|
||||
|
||||
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
|
||||
@@ -176,9 +184,11 @@ pub struct CommunityInfo {
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))]
|
||||
pub async fn get_impact(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path((repo_id, finding_id)): Path<(String, String)>,
|
||||
) -> Result<Json<ApiResponse<Option<ImpactAnalysis>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id };
|
||||
|
||||
let impact = db
|
||||
@@ -198,10 +208,12 @@ pub async fn get_impact(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))]
|
||||
pub async fn search_symbols(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
Query(params): Query<SearchParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
// Simple text search on qualified_name and name fields
|
||||
let filter = doc! {
|
||||
@@ -234,10 +246,12 @@ pub async fn search_symbols(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn get_file_content(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
Query(params): Query<FileContentParams>,
|
||||
) -> Result<Json<ApiResponse<FileContent>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
// Look up the repository to get repo name
|
||||
let repo = db
|
||||
@@ -296,12 +310,13 @@ pub struct FileContent {
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub async fn trigger_build(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(repo_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let agent_clone = (*agent).clone();
|
||||
tokio::spawn(async move {
|
||||
let repo = match agent_clone
|
||||
.db
|
||||
let repo = match db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
|
||||
.await
|
||||
@@ -333,8 +348,7 @@ pub async fn trigger_build(
|
||||
|
||||
match engine.build_graph(&repo_path, &repo_id, &graph_build_id) {
|
||||
Ok((code_graph, build_run)) => {
|
||||
let store =
|
||||
compliance_graph::graph::persistence::GraphStore::new(agent_clone.db.inner());
|
||||
let store = compliance_graph::graph::persistence::GraphStore::new(db.inner());
|
||||
let _ = store.delete_repo_graph(&repo_id).await;
|
||||
let _ = store
|
||||
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
|
||||
|
||||
@@ -3,6 +3,7 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::ScanRun;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn health() -> Json<serde_json::Value> {
|
||||
@@ -10,8 +11,12 @@ pub async fn health() -> Json<serde_json::Value> {
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
|
||||
let db = &agent.db;
|
||||
pub async fn stats_overview(
|
||||
axum::extract::Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
) -> ApiResult<OverviewStats> {
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
let total_repositories = db
|
||||
.repositories()
|
||||
|
||||
@@ -4,13 +4,16 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::TrackerIssue;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_issues(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<TrackerIssue>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.tracker_issues()
|
||||
|
||||
@@ -5,15 +5,18 @@ use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::notification::CveNotification;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use super::dto::{AgentExt, ApiResponse};
|
||||
use super::dto::{tenant_db, AgentExt, ApiResponse};
|
||||
|
||||
/// GET /api/v1/notifications — List CVE notifications (newest first)
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_notifications(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
axum::extract::Query(params): axum::extract::Query<NotificationFilter>,
|
||||
) -> Result<Json<ApiResponse<Vec<CveNotification>>>, StatusCode> {
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let mut filter = doc! {};
|
||||
|
||||
// Filter by status (default: show new + read, exclude dismissed)
|
||||
@@ -41,15 +44,13 @@ pub async fn list_notifications(
|
||||
let limit = params.limit.unwrap_or(50).min(200);
|
||||
let skip = (page - 1) * limit as u64;
|
||||
|
||||
let total = agent
|
||||
.db
|
||||
let total = db
|
||||
.cve_notifications()
|
||||
.count_documents(filter.clone())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let notifications: Vec<CveNotification> = match agent
|
||||
.db
|
||||
let notifications: Vec<CveNotification> = match db
|
||||
.cve_notifications()
|
||||
.find(filter)
|
||||
.sort(doc! { "created_at": -1 })
|
||||
@@ -83,9 +84,10 @@ pub async fn list_notifications(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn notification_count(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let count = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let count = db
|
||||
.cve_notifications()
|
||||
.count_documents(doc! { "status": "new" })
|
||||
.await
|
||||
@@ -98,12 +100,13 @@ pub async fn notification_count(
|
||||
#[tracing::instrument(skip_all, fields(id = %id))]
|
||||
pub async fn mark_read(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
axum::extract::Path(id): axum::extract::Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
let result = db
|
||||
.cve_notifications()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
@@ -125,12 +128,13 @@ pub async fn mark_read(
|
||||
#[tracing::instrument(skip_all, fields(id = %id))]
|
||||
pub async fn dismiss_notification(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
axum::extract::Path(id): axum::extract::Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
let result = db
|
||||
.cve_notifications()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
@@ -149,9 +153,10 @@ pub async fn dismiss_notification(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn mark_all_read(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let result = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let result = db
|
||||
.cve_notifications()
|
||||
.update_many(
|
||||
doc! { "status": "new" },
|
||||
|
||||
@@ -13,10 +13,11 @@ use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::finding::Finding;
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::models::sbom::SbomEntry;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::collect_cursor_async;
|
||||
use super::super::dto::{collect_cursor_async, tenant_db};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
@@ -35,11 +36,15 @@ pub struct ExportBody {
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn export_session_report(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<ExportBody>,
|
||||
) -> Result<axum::response::Response, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
if body.password.len() < 8 {
|
||||
return Err((
|
||||
@@ -49,8 +54,7 @@ pub async fn export_session_report(
|
||||
}
|
||||
|
||||
// Fetch session
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -64,9 +68,7 @@ pub async fn export_session_report(
|
||||
|
||||
// Resolve target name
|
||||
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
|
||||
agent
|
||||
.db
|
||||
.dast_targets()
|
||||
db.dast_targets()
|
||||
.find_one(doc! { "_id": tid })
|
||||
.await
|
||||
.ok()
|
||||
@@ -84,8 +86,7 @@ pub async fn export_session_report(
|
||||
.unwrap_or_default();
|
||||
|
||||
// Fetch attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
let nodes: Vec<AttackChainNode> = match db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
@@ -96,8 +97,7 @@ pub async fn export_session_report(
|
||||
};
|
||||
|
||||
// Fetch DAST findings for this session, then deduplicate
|
||||
let raw_findings: Vec<DastFinding> = match agent
|
||||
.db
|
||||
let raw_findings: Vec<DastFinding> = match db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "severity": -1, "created_at": -1 })
|
||||
@@ -122,8 +122,7 @@ pub async fn export_session_report(
|
||||
.or_else(|| target.as_ref().and_then(|t| t.repo_id.clone()));
|
||||
|
||||
let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id {
|
||||
let sast: Vec<Finding> = match agent
|
||||
.db
|
||||
let sast: Vec<Finding> = match db
|
||||
.findings()
|
||||
.find(doc! {
|
||||
"repo_id": rid,
|
||||
@@ -143,8 +142,7 @@ pub async fn export_session_report(
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let sbom: Vec<SbomEntry> = match agent
|
||||
.db
|
||||
let sbom: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! {
|
||||
"repo_id": rid,
|
||||
@@ -164,8 +162,7 @@ pub async fn export_session_report(
|
||||
};
|
||||
|
||||
// Build code context from graph nodes
|
||||
let code_ctx: Vec<CodeContextHint> = match agent
|
||||
.db
|
||||
let code_ctx: Vec<CodeContextHint> = match db
|
||||
.graph_nodes()
|
||||
.find(doc! { "repo_id": rid, "is_entry_point": true })
|
||||
.limit(50)
|
||||
|
||||
@@ -7,11 +7,12 @@ use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::pentest::PentestOrchestrator;
|
||||
|
||||
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
|
||||
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse, PaginationParams};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
@@ -43,6 +44,7 @@ pub struct LookupRepoQuery {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn create_session(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Json(req): Json<CreateSessionRequest>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
|
||||
// Try to acquire a concurrency permit
|
||||
@@ -57,6 +59,10 @@ pub async fn create_session(
|
||||
)
|
||||
})?;
|
||||
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
if let Some(ref config) = req.config {
|
||||
// ── Wizard path ──────────────────────────────────────────────
|
||||
if !config.disclaimer_accepted {
|
||||
@@ -67,8 +73,7 @@ pub async fn create_session(
|
||||
}
|
||||
|
||||
// Look up or auto-create DastTarget by app_url
|
||||
let target = match agent
|
||||
.db
|
||||
let target = match db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "base_url": &config.app_url })
|
||||
.await
|
||||
@@ -87,7 +92,7 @@ pub async fn create_session(
|
||||
}
|
||||
t.allow_destructive = config.allow_destructive;
|
||||
t.excluded_paths = config.scope_exclusions.clone();
|
||||
let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| {
|
||||
let res = db.dast_targets().insert_one(&t).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to create target: {e}"),
|
||||
@@ -110,8 +115,7 @@ pub async fn create_session(
|
||||
|
||||
// Resolve repo_id from git_repo_url if provided
|
||||
if let Some(ref git_url) = config.git_repo_url {
|
||||
if let Ok(Some(repo)) = agent
|
||||
.db
|
||||
if let Ok(Some(repo)) = db
|
||||
.repositories()
|
||||
.find_one(doc! { "git_url": git_url })
|
||||
.await
|
||||
@@ -120,8 +124,7 @@ pub async fn create_session(
|
||||
}
|
||||
}
|
||||
|
||||
let insert_result = agent
|
||||
.db
|
||||
let insert_result = db
|
||||
.pentest_sessions()
|
||||
.insert_one(&session)
|
||||
.await
|
||||
@@ -212,8 +215,7 @@ pub async fn create_session(
|
||||
// Persist encrypted credentials to DB
|
||||
if session_for_task.config.is_some() {
|
||||
if let Some(sid) = session.id {
|
||||
let _ = agent
|
||||
.db
|
||||
let _ = db
|
||||
.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": sid },
|
||||
@@ -245,12 +247,13 @@ pub async fn create_session(
|
||||
});
|
||||
|
||||
let llm = agent.llm.clone();
|
||||
let db = agent.db.clone();
|
||||
let db_for_orchestrator = db.clone();
|
||||
let session_clone = session.clone();
|
||||
let target_clone = target.clone();
|
||||
let agent_ref = agent.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
|
||||
let orchestrator =
|
||||
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
|
||||
orchestrator
|
||||
.run_session_guarded(&session_clone, &target_clone, &initial_message)
|
||||
.await;
|
||||
@@ -292,8 +295,7 @@ pub async fn create_session(
|
||||
)
|
||||
})?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
let target = db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -310,8 +312,7 @@ pub async fn create_session(
|
||||
let mut session = PentestSession::new(target_id, strategy);
|
||||
session.repo_id = target.repo_id.clone();
|
||||
|
||||
let insert_result = agent
|
||||
.db
|
||||
let insert_result = db
|
||||
.pentest_sessions()
|
||||
.insert_one(&session)
|
||||
.await
|
||||
@@ -338,12 +339,13 @@ pub async fn create_session(
|
||||
});
|
||||
|
||||
let llm = agent.llm.clone();
|
||||
let db = agent.db.clone();
|
||||
let db_for_orchestrator = db.clone();
|
||||
let session_clone = session.clone();
|
||||
let target_clone = target.clone();
|
||||
let agent_ref = agent.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
|
||||
let orchestrator =
|
||||
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
|
||||
orchestrator
|
||||
.run_session_guarded(&session_clone, &target_clone, &initial_message)
|
||||
.await;
|
||||
@@ -373,10 +375,11 @@ fn parse_strategy(s: &str) -> PentestStrategy {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn lookup_repo(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<LookupRepoQuery>,
|
||||
) -> Result<Json<ApiResponse<serde_json::Value>>, StatusCode> {
|
||||
let repo = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let repo = db
|
||||
.repositories()
|
||||
.find_one(doc! { "git_url": ¶ms.url })
|
||||
.await
|
||||
@@ -402,9 +405,11 @@ pub async fn lookup_repo(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_sessions(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.pentest_sessions()
|
||||
@@ -438,12 +443,13 @@ pub async fn list_sessions(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_session(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let mut session = agent
|
||||
.db
|
||||
let mut session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -471,15 +477,18 @@ pub async fn get_session(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn send_message(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<SendMessageRequest>,
|
||||
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
// Verify session exists and is running
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -506,8 +515,7 @@ pub async fn send_message(
|
||||
)
|
||||
})?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
let target = db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": target_oid })
|
||||
.await
|
||||
@@ -527,13 +535,13 @@ pub async fn send_message(
|
||||
// Store user message
|
||||
let session_id = id.clone();
|
||||
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
|
||||
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
|
||||
let _ = db.pentest_messages().insert_one(&user_msg).await;
|
||||
|
||||
let response_msg = user_msg.clone();
|
||||
|
||||
// Spawn orchestrator to continue the session
|
||||
let llm = agent.llm.clone();
|
||||
let db = agent.db.clone();
|
||||
let db_for_orchestrator = db.clone();
|
||||
let message = req.message.clone();
|
||||
|
||||
// Use existing broadcast sender if available, otherwise create a new one
|
||||
@@ -548,7 +556,7 @@ pub async fn send_message(
|
||||
.unwrap_or_else(|| agent.register_session_stream(&session_id));
|
||||
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None);
|
||||
let orchestrator = PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, None);
|
||||
orchestrator
|
||||
.run_session_guarded(&session, &target, &message)
|
||||
.await;
|
||||
@@ -565,13 +573,16 @@ pub async fn send_message(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn stop_session(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -590,9 +601,7 @@ pub async fn stop_session(
|
||||
));
|
||||
}
|
||||
|
||||
agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
db.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
doc! { "$set": {
|
||||
@@ -612,8 +621,7 @@ pub async fn stop_session(
|
||||
// Clean up session resources
|
||||
agent.cleanup_session(&id);
|
||||
|
||||
let updated = agent
|
||||
.db
|
||||
let updated = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -641,13 +649,16 @@ pub async fn stop_session(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn pause_session(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -684,13 +695,16 @@ pub async fn pause_session(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn resume_session(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -727,12 +741,13 @@ pub async fn resume_session(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_attack_chain(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let nodes = match agent
|
||||
.db
|
||||
let nodes = match db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
@@ -757,21 +772,21 @@ pub async fn get_attack_chain(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_messages(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
.db
|
||||
let total = db
|
||||
.pentest_messages()
|
||||
.count_documents(doc! { "session_id": &id })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let messages = match agent
|
||||
.db
|
||||
let messages = match db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
@@ -797,21 +812,21 @@ pub async fn get_messages(
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_session_findings(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
.db
|
||||
let total = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": &id })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let findings = match agent
|
||||
.db
|
||||
let findings = match db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": -1 })
|
||||
|
||||
@@ -6,10 +6,11 @@ use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::{collect_cursor_async, ApiResponse};
|
||||
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
@@ -17,8 +18,10 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn pentest_stats(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
let running_sessions = db
|
||||
.pentest_sessions()
|
||||
|
||||
@@ -11,10 +11,11 @@ use tokio_stream::wrappers::BroadcastStream;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::collect_cursor_async;
|
||||
use super::super::dto::{collect_cursor_async, tenant_db};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
@@ -25,13 +26,14 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn session_stream(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
// Verify session exists
|
||||
let _session = agent
|
||||
.db
|
||||
let _session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -43,8 +45,7 @@ pub async fn session_stream(
|
||||
let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new();
|
||||
|
||||
// Fetch recent messages for this session
|
||||
let messages: Vec<PentestMessage> = match agent
|
||||
.db
|
||||
let messages: Vec<PentestMessage> = match db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
@@ -56,8 +57,7 @@ pub async fn session_stream(
|
||||
};
|
||||
|
||||
// Fetch recent attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
let nodes: Vec<AttackChainNode> = match db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
@@ -94,8 +94,7 @@ pub async fn session_stream(
|
||||
}
|
||||
|
||||
// Add current session status event
|
||||
let session = agent
|
||||
.db
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
|
||||
@@ -5,13 +5,16 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::*;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_repositories(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<TrackedRepository>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.repositories()
|
||||
@@ -43,6 +46,7 @@ pub async fn list_repositories(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn add_repository(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Json(req): Json<AddRepositoryRequest>,
|
||||
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
|
||||
// Validate repository access before saving
|
||||
@@ -69,17 +73,15 @@ pub async fn add_repository(
|
||||
repo.tracker_token = req.tracker_token;
|
||||
repo.scan_schedule = req.scan_schedule;
|
||||
|
||||
agent
|
||||
.db
|
||||
.repositories()
|
||||
.insert_one(&repo)
|
||||
let db = tenant_db(&agent, &tenant)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
"Repository already exists".to_string(),
|
||||
)
|
||||
})?;
|
||||
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
|
||||
db.repositories().insert_one(&repo).await.map_err(|_| {
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
"Repository already exists".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: repo,
|
||||
@@ -91,10 +93,12 @@ pub async fn add_repository(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %id))]
|
||||
pub async fn update_repository(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateRepositoryRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
|
||||
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
|
||||
|
||||
@@ -126,8 +130,7 @@ pub async fn update_repository(
|
||||
set_doc.insert("scan_schedule", schedule);
|
||||
}
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
let result = db
|
||||
.repositories()
|
||||
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
|
||||
.await
|
||||
@@ -170,11 +173,12 @@ pub async fn trigger_scan(
|
||||
/// Return the webhook secret for a repository (used by dashboard to display it)
|
||||
pub async fn get_webhook_config(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let repo = agent
|
||||
.db
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let repo = db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
@@ -196,10 +200,12 @@ pub async fn get_webhook_config(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %id))]
|
||||
pub async fn delete_repository(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
// Delete the repository
|
||||
let result = db
|
||||
|
||||
@@ -6,6 +6,7 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::SbomEntry;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
const COPYLEFT_LICENSES: &[&str] = &[
|
||||
"GPL-2.0",
|
||||
@@ -29,8 +30,10 @@ const COPYLEFT_LICENSES: &[&str] = &[
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn sbom_filters(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
let managers: Vec<String> = db
|
||||
.sbom_entries()
|
||||
@@ -61,9 +64,11 @@ pub async fn sbom_filters(
|
||||
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
|
||||
pub async fn list_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(filter): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<SbomEntry>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let mut query = doc! {};
|
||||
|
||||
if let Some(repo_id) = &filter.repo_id {
|
||||
@@ -120,9 +125,11 @@ pub async fn list_sbom(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn export_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<SbomExportParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let entries: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_id })
|
||||
@@ -236,9 +243,11 @@ pub async fn export_sbom(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn license_summary(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<LicenseSummary>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let mut query = doc! {};
|
||||
if let Some(repo_id) = ¶ms.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
@@ -285,9 +294,11 @@ pub async fn license_summary(
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn sbom_diff(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<SbomDiffParams>,
|
||||
) -> ApiResult<SbomDiffResult> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
|
||||
let entries_a: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
|
||||
@@ -4,13 +4,16 @@ use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::ScanRun;
|
||||
use compliance_core::tenant_ctx::TenantCtx;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_scan_runs(
|
||||
Extension(agent): AgentExt,
|
||||
tenant: TenantCtx,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<ScanRun>> {
|
||||
let db = &agent.db;
|
||||
let db = tenant_db(&agent, &tenant).await?;
|
||||
let db = &db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
|
||||
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::Request;
|
||||
use axum::http::HeaderValue;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use axum::{middleware, Extension};
|
||||
use tokio::sync::RwLock;
|
||||
use tower_http::cors::CorsLayer;
|
||||
@@ -8,11 +11,44 @@ use tower_http::set_header::SetResponseHeaderLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState};
|
||||
use compliance_core::{TenantContext, TenantStatus};
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::api::routes;
|
||||
use crate::error::AgentError;
|
||||
|
||||
/// Synthetic tenant id used when Keycloak isn't configured (local dev,
|
||||
/// `cargo run` against a bare Mongo). Lets the handler stack stay
|
||||
/// uniformly tenant-scoped without the operator having to spin up KC
|
||||
/// just to poke at the API. Override via `DEV_TENANT_ID`.
|
||||
const DEFAULT_DEV_TENANT_ID: &str = "dev";
|
||||
|
||||
/// Inject a synthetic [`TenantContext`] for any request that lacks one.
|
||||
/// Only mounted when Keycloak is NOT configured; with KC, the real
|
||||
/// `require_jwt_auth` middleware owns this and we never reach here
|
||||
/// without a context.
|
||||
///
|
||||
/// Public so the integration-test harness can mount it without
|
||||
/// duplicating the synthetic-context shape.
|
||||
pub async fn inject_dev_tenant(mut request: Request, next: Next) -> Response {
|
||||
if request.extensions().get::<TenantContext>().is_none() {
|
||||
let tenant_id =
|
||||
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
|
||||
let ctx = TenantContext {
|
||||
tenant_slug: tenant_id.clone(),
|
||||
tenant_id,
|
||||
org_roles: vec![],
|
||||
products: vec![],
|
||||
plan: "dev".to_string(),
|
||||
status: TenantStatus::Active,
|
||||
user_id: "dev-user".to_string(),
|
||||
user_name: None,
|
||||
};
|
||||
request.extensions_mut().insert(ctx);
|
||||
}
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> {
|
||||
let mut app = routes::build_router()
|
||||
.layer(Extension(Arc::new(agent.clone())))
|
||||
@@ -53,7 +89,14 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
|
||||
.layer(middleware::from_fn(require_jwt_auth))
|
||||
.layer(Extension(jwks_state));
|
||||
} else {
|
||||
tracing::warn!("Keycloak not configured - API endpoints are unprotected");
|
||||
let tenant_id =
|
||||
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
|
||||
tracing::warn!(
|
||||
tenant_id = %tenant_id,
|
||||
"Keycloak not configured — running unauthenticated against the dev tenant. \
|
||||
DO NOT use in any environment with real customer data."
|
||||
);
|
||||
app = app.layer(middleware::from_fn(inject_dev_tenant));
|
||||
}
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
|
||||
@@ -75,9 +75,13 @@ impl TestServer {
|
||||
|
||||
let agent = ComplianceAgent::new(config, db, db_pool);
|
||||
|
||||
// Build the router with the agent extension
|
||||
// Build the router with the agent extension. After M7.2-B every
|
||||
// handler takes a TenantCtx extractor; without KC in the test
|
||||
// harness, the dev-tenant injector mounts a synthetic context so
|
||||
// tests run end-to-end against `<db_name>_dev`.
|
||||
let app = api::routes::build_router()
|
||||
.layer(axum::extract::Extension(Arc::new(agent)))
|
||||
.layer(axum::middleware::from_fn(api::server::inject_dev_tenant))
|
||||
.layer(tower_http::cors::CorsLayer::permissive());
|
||||
|
||||
// Bind to port 0 to get a random available port
|
||||
@@ -160,10 +164,20 @@ impl TestServer {
|
||||
&self.db_name
|
||||
}
|
||||
|
||||
/// Drop the test database on cleanup
|
||||
/// Drop the test database on cleanup. Post-M7.2-B the actual data
|
||||
/// lives in `<db_name>_<tenant>` per-tenant databases; list those
|
||||
/// off the cluster and drop them too.
|
||||
pub async fn cleanup(&self) {
|
||||
if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await {
|
||||
client.database(&self.db_name).drop().await.ok();
|
||||
if let Ok(names) = client.list_database_names().await {
|
||||
let prefix = format!("{}_", self.db_name);
|
||||
for name in names {
|
||||
if name.starts_with(&prefix) {
|
||||
client.database(&name).drop().await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user