From cdfbb62f9d499e6fda2ff1a43061b1639552f2d6 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar <30073382+mighty840@users.noreply.github.com> Date: Wed, 17 Jun 2026 13:28:33 +0200 Subject: [PATCH] feat(m7.2-B): migrate API handlers to per-tenant database pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on PR M7.2-A. Every HTTP handler in compliance-agent/src/api/ now takes a TenantCtx extractor and pulls a tenant-scoped Database from agent.db_pool.for_tenant(&ctx). The query bodies are unchanged — `db.findings().find(doc! {...})` reads from the tenant's own physical database, so the filter doc cannot leak data across tenants because the wrong tenant's data is literally on a different db handle. Changes - New `dto::tenant_db(&agent, &tenant) -> Result` helper. Every migrated handler calls it at the top of the body instead of `let db = &agent.db;`. 500 on the rare pool failure; 4xx auth failures are already handled by the M7.1 status gate. - New `api::server::inject_dev_tenant` middleware mounted only when Keycloak is NOT configured. Synthesizes a TenantContext with tenant_id = $DEV_TENANT_ID (default `dev`) so `cargo run` against a bare Mongo + no KC still serves the API. Logged loudly as "DO NOT use in any environment with real customer data". - Test harness: TestServer mounts inject_dev_tenant so existing E2E tests reach handlers; cleanup() now drops every _* per-tenant database, not just the legacy . Files migrated (handler count, all pass `cargo build`): - chat.rs (3) — also rewires RagPipeline + EmbeddingStore to the tenant DB's inner() so vector search is per-tenant - dast.rs (5) - findings.rs (5) - graph.rs (7) — also rewires GraphStore inside trigger_build's spawn to the tenant DB - health.rs (1) — stats_overview migrated; public /health stays un-scoped - issues.rs (1) - notifications.rs (5) - pentest_handlers/session.rs (12) — both wizard + legacy paths, plus pause/resume/stop/get_attack_chain/get_messages/ get_session_findings/lookup_repo. PentestOrchestrator now gets the tenant DB clone in its spawn. - pentest_handlers/export.rs (1) — fans out across sessions, attack_chain_nodes, dast_findings, findings, sbom_entries, graph_nodes from a single tenant_db acquisition - pentest_handlers/stats.rs (1) - pentest_handlers/stream.rs (1) — SSE handler verifies session via the tenant DB before subscribing - repos.rs (6) - sbom.rs (5) - scans.rs (1) help_chat.rs has no DB queries and was skipped. Test plan - cargo fmt --all clean - cargo clippy --workspace --exclude compliance-dashboard -- -D warnings clean - cargo test -p compliance-core --lib — 7 pass - cargo test -p compliance-agent --lib — 228 pass - cargo test -p compliance-agent --test tenant_isolation — 5 pass (driver-level isolation still holds post-handler migration) - cargo test -p compliance-agent --test tenant_status_middleware — 6 pass What's not yet migrated (PR-C / PR-D) - scheduler.rs (6 sites), pipeline/orchestrator.rs (14), pentest/orchestrator.rs (13), webhooks (gitea/github/gitlab), trackers/jira.rs, pipeline/dedup.rs etc. — background paths without a JWT-derived tenant context. - agent.db is still in the ComplianceAgent struct as a transitional handle for those paths. PR-D removes it once PR-C migrates the background paths. Co-Authored-By: Claude Opus 4.7 --- compliance-agent/src/api/handlers/chat.rs | 56 +++++---- compliance-agent/src/api/handlers/dast.rs | 31 +++-- compliance-agent/src/api/handlers/dto.rs | 21 ++++ compliance-agent/src/api/handlers/findings.rs | 27 ++-- compliance-agent/src/api/handlers/graph.rs | 34 +++-- compliance-agent/src/api/handlers/health.rs | 9 +- compliance-agent/src/api/handlers/issues.rs | 5 +- .../src/api/handlers/notifications.rs | 31 +++-- .../api/handlers/pentest_handlers/export.rs | 29 ++--- .../api/handlers/pentest_handlers/session.rs | 117 ++++++++++-------- .../api/handlers/pentest_handlers/stats.rs | 7 +- .../api/handlers/pentest_handlers/stream.rs | 17 ++- compliance-agent/src/api/handlers/repos.rs | 38 +++--- compliance-agent/src/api/handlers/sbom.rs | 21 +++- compliance-agent/src/api/handlers/scans.rs | 5 +- compliance-agent/src/api/server.rs | 45 ++++++- compliance-agent/tests/common/mod.rs | 18 ++- 17 files changed, 334 insertions(+), 177 deletions(-) diff --git a/compliance-agent/src/api/handlers/chat.rs b/compliance-agent/src/api/handlers/chat.rs index 2f82d13..9583ed8 100644 --- a/compliance-agent/src/api/handlers/chat.rs +++ b/compliance-agent/src/api/handlers/chat.rs @@ -7,11 +7,13 @@ use mongodb::bson::doc; use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference}; use compliance_core::models::embedding::EmbeddingBuildRun; +use compliance_core::tenant_ctx::TenantCtx; use compliance_graph::graph::embedding_store::EmbeddingStore; use crate::agent::ComplianceAgent; use crate::rag::pipeline::RagPipeline; +use super::dto::tenant_db; use super::ApiResponse; type AgentExt = Extension>; @@ -20,10 +22,12 @@ type AgentExt = Extension>; #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn chat( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, Json(req): Json, ) -> Result>, StatusCode> { - let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner()); + let db = tenant_db(&agent, &tenant).await?; + let pipeline = RagPipeline::new(agent.llm.clone(), db.inner()); // Step 1: Embed the user's message let query_vectors = agent @@ -133,12 +137,15 @@ pub async fn chat( #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn build_embeddings( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result, StatusCode> { + // Resolve the tenant DB up front so we can move it into the spawn; + // the JWT/dev context isn't available inside detached tasks. + let db = tenant_db(&agent, &tenant).await?; let agent_clone = (*agent).clone(); tokio::spawn(async move { - let repo = match agent_clone - .db + let repo = match db .repositories() .find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() }) .await @@ -151,8 +158,7 @@ pub async fn build_embeddings( }; // Get latest graph build - let build = match agent_clone - .db + let build = match db .graph_builds() .find_one(doc! { "repo_id": &repo_id }) .sort(doc! { "started_at": -1 }) @@ -171,26 +177,22 @@ pub async fn build_embeddings( .unwrap_or_else(|| "unknown".to_string()); // Get nodes - let nodes: Vec = match agent_clone - .db - .graph_nodes() - .find(doc! { "repo_id": &repo_id }) - .await - { - Ok(cursor) => { - use futures_util::StreamExt; - let mut items = Vec::new(); - let mut cursor = cursor; - while let Some(Ok(item)) = cursor.next().await { - items.push(item); + let nodes: Vec = + match db.graph_nodes().find(doc! { "repo_id": &repo_id }).await { + Ok(cursor) => { + use futures_util::StreamExt; + let mut items = Vec::new(); + let mut cursor = cursor; + while let Some(Ok(item)) = cursor.next().await { + items.push(item); + } + items } - items - } - Err(e) => { - tracing::error!("[{repo_id}] Failed to fetch nodes: {e}"); - return; - } - }; + Err(e) => { + tracing::error!("[{repo_id}] Failed to fetch nodes: {e}"); + return; + } + }; let creds = crate::pipeline::git::RepoCredentials { ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()), @@ -207,7 +209,7 @@ pub async fn build_embeddings( } }; - let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner()); + let pipeline = RagPipeline::new(agent_clone.llm.clone(), db.inner()); match pipeline .build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes) .await @@ -234,9 +236,11 @@ pub async fn build_embeddings( #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn embedding_status( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result>>, StatusCode> { - let store = EmbeddingStore::new(agent.db.inner()); + let db = tenant_db(&agent, &tenant).await?; + let store = EmbeddingStore::new(db.inner()); let build = store.get_latest_build(&repo_id).await.map_err(|e| { tracing::error!("Failed to get embedding status: {e}"); StatusCode::INTERNAL_SERVER_ERROR diff --git a/compliance-agent/src/api/handlers/dast.rs b/compliance-agent/src/api/handlers/dast.rs index fc74113..ecd9fff 100644 --- a/compliance-agent/src/api/handlers/dast.rs +++ b/compliance-agent/src/api/handlers/dast.rs @@ -7,9 +7,11 @@ use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType}; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; +use super::dto::tenant_db; use super::{collect_cursor_async, ApiResponse, PaginationParams}; type AgentExt = Extension>; @@ -45,9 +47,11 @@ fn default_rate_limit() -> u32 { #[tracing::instrument(skip_all)] pub async fn list_targets( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .dast_targets() @@ -80,6 +84,7 @@ pub async fn list_targets( #[tracing::instrument(skip_all)] pub async fn add_target( Extension(agent): AgentExt, + tenant: TenantCtx, Json(req): Json, ) -> Result>, StatusCode> { let mut target = DastTarget::new(req.name, req.base_url, req.target_type); @@ -89,9 +94,8 @@ pub async fn add_target( target.rate_limit = req.rate_limit; target.allow_destructive = req.allow_destructive; - agent - .db - .dast_targets() + let db = tenant_db(&agent, &tenant).await?; + db.dast_targets() .insert_one(&target) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -107,19 +111,19 @@ pub async fn add_target( #[tracing::instrument(skip_all, fields(target_id = %id))] pub async fn trigger_scan( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let target = agent - .db + let target = db .dast_targets() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; - let db = agent.db.clone(); tokio::spawn(async move { let orchestrator = compliance_dast::DastOrchestrator::new(100); match orchestrator.run_scan(&target, Vec::new()).await { @@ -147,9 +151,11 @@ pub async fn trigger_scan( #[tracing::instrument(skip_all)] pub async fn list_scan_runs( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .dast_scan_runs() @@ -183,9 +189,11 @@ pub async fn list_scan_runs( #[tracing::instrument(skip_all)] pub async fn list_findings( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .dast_findings() @@ -219,12 +227,13 @@ pub async fn list_findings( #[tracing::instrument(skip_all, fields(finding_id = %id))] pub async fn get_finding( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let finding = agent - .db + let finding = db .dast_findings() .find_one(doc! { "_id": oid }) .await diff --git a/compliance-agent/src/api/handlers/dto.rs b/compliance-agent/src/api/handlers/dto.rs index b6c0992..a4a0ef4 100644 --- a/compliance-agent/src/api/handlers/dto.rs +++ b/compliance-agent/src/api/handlers/dto.rs @@ -180,6 +180,27 @@ pub struct SbomVersionDiff { pub(crate) type AgentExt = axum::extract::Extension>; pub(crate) type ApiResult = Result>, axum::http::StatusCode>; +/// Resolve a tenant-scoped [`Database`] from the request's +/// [`TenantContext`] (inserted by the M7.1 JWT middleware, or by the +/// dev fallback in unsecured environments). The pool ensures the +/// tenant's indexes idempotently. +/// +/// Returns 500 on the rare path where Mongo refuses the database +/// handle — the M7.1 auth/status middleware already rejects every +/// other failure mode with 4xx before we get here. +pub(crate) async fn tenant_db( + agent: &crate::agent::ComplianceAgent, + tenant: &compliance_core::tenant_ctx::TenantCtx, +) -> Result { + agent.db_pool.for_tenant(&tenant.0).await.map_err(|e| { + tracing::error!( + tenant_id = %tenant.0.tenant_id, + "Failed to acquire tenant database: {e}" + ); + axum::http::StatusCode::INTERNAL_SERVER_ERROR + }) +} + pub(crate) async fn collect_cursor_async( mut cursor: mongodb::Cursor, ) -> Vec { diff --git a/compliance-agent/src/api/handlers/findings.rs b/compliance-agent/src/api/handlers/findings.rs index d20a5e9..a8af6b8 100644 --- a/compliance-agent/src/api/handlers/findings.rs +++ b/compliance-agent/src/api/handlers/findings.rs @@ -5,13 +5,16 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::Finding; +use compliance_core::tenant_ctx::TenantCtx; #[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))] pub async fn list_findings( Extension(agent): AgentExt, + tenant: TenantCtx, Query(filter): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let mut query = doc! {}; if let Some(repo_id) = &filter.repo_id { query.insert("repo_id", repo_id); @@ -81,11 +84,12 @@ pub async fn list_findings( #[tracing::instrument(skip_all, fields(finding_id = %id))] pub async fn get_finding( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let finding = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let finding = db .findings() .find_one(doc! { "_id": oid }) .await @@ -102,14 +106,14 @@ pub async fn get_finding( #[tracing::instrument(skip_all, fields(finding_id = %id))] pub async fn update_finding_status( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Json(req): Json, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - agent - .db - .findings() + db.findings() .update_one( doc! { "_id": oid }, doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, @@ -123,6 +127,7 @@ pub async fn update_finding_status( #[tracing::instrument(skip_all)] pub async fn bulk_update_finding_status( Extension(agent): AgentExt, + tenant: TenantCtx, Json(req): Json, ) -> Result, StatusCode> { let oids: Vec = req @@ -135,8 +140,8 @@ pub async fn bulk_update_finding_status( return Err(StatusCode::BAD_REQUEST); } - let result = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let result = db .findings() .update_many( doc! { "_id": { "$in": oids } }, @@ -153,14 +158,14 @@ pub async fn bulk_update_finding_status( #[tracing::instrument(skip_all)] pub async fn update_finding_feedback( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Json(req): Json, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - agent - .db - .findings() + db.findings() .update_one( doc! { "_id": oid }, doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } }, diff --git a/compliance-agent/src/api/handlers/graph.rs b/compliance-agent/src/api/handlers/graph.rs index 43a79b9..8c10c9c 100644 --- a/compliance-agent/src/api/handlers/graph.rs +++ b/compliance-agent/src/api/handlers/graph.rs @@ -7,9 +7,11 @@ use mongodb::bson::doc; use serde::{Deserialize, Serialize}; use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis}; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; +use super::dto::tenant_db; use super::{collect_cursor_async, ApiResponse}; type AgentExt = Extension>; @@ -36,9 +38,11 @@ fn default_search_limit() -> usize { #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn get_graph( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; // Get latest build let build: Option = db @@ -98,9 +102,11 @@ pub async fn get_graph( #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn get_nodes( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let filter = doc! { "repo_id": &repo_id }; let nodes: Vec = match db.graph_nodes().find(filter).await { @@ -123,9 +129,11 @@ pub async fn get_nodes( #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn get_communities( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let filter = doc! { "repo_id": &repo_id }; let nodes: Vec = match db.graph_nodes().find(filter).await { @@ -176,9 +184,11 @@ pub struct CommunityInfo { #[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))] pub async fn get_impact( Extension(agent): AgentExt, + tenant: TenantCtx, Path((repo_id, finding_id)): Path<(String, String)>, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id }; let impact = db @@ -198,10 +208,12 @@ pub async fn get_impact( #[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))] pub async fn search_symbols( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, Query(params): Query, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; // Simple text search on qualified_name and name fields let filter = doc! { @@ -234,10 +246,12 @@ pub async fn search_symbols( #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn get_file_content( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, Query(params): Query, ) -> Result>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; // Look up the repository to get repo name let repo = db @@ -296,12 +310,13 @@ pub struct FileContent { #[tracing::instrument(skip_all, fields(repo_id = %repo_id))] pub async fn trigger_build( Extension(agent): AgentExt, + tenant: TenantCtx, Path(repo_id): Path, ) -> Result, StatusCode> { + let db = tenant_db(&agent, &tenant).await?; let agent_clone = (*agent).clone(); tokio::spawn(async move { - let repo = match agent_clone - .db + let repo = match db .repositories() .find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() }) .await @@ -333,8 +348,7 @@ pub async fn trigger_build( match engine.build_graph(&repo_path, &repo_id, &graph_build_id) { Ok((code_graph, build_run)) => { - let store = - compliance_graph::graph::persistence::GraphStore::new(agent_clone.db.inner()); + let store = compliance_graph::graph::persistence::GraphStore::new(db.inner()); let _ = store.delete_repo_graph(&repo_id).await; let _ = store .store_graph(&build_run, &code_graph.nodes, &code_graph.edges) diff --git a/compliance-agent/src/api/handlers/health.rs b/compliance-agent/src/api/handlers/health.rs index 264ae10..ef3b030 100644 --- a/compliance-agent/src/api/handlers/health.rs +++ b/compliance-agent/src/api/handlers/health.rs @@ -3,6 +3,7 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::ScanRun; +use compliance_core::tenant_ctx::TenantCtx; #[tracing::instrument(skip_all)] pub async fn health() -> Json { @@ -10,8 +11,12 @@ pub async fn health() -> Json { } #[tracing::instrument(skip_all)] -pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult { - let db = &agent.db; +pub async fn stats_overview( + axum::extract::Extension(agent): AgentExt, + tenant: TenantCtx, +) -> ApiResult { + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let total_repositories = db .repositories() diff --git a/compliance-agent/src/api/handlers/issues.rs b/compliance-agent/src/api/handlers/issues.rs index d808445..2896b77 100644 --- a/compliance-agent/src/api/handlers/issues.rs +++ b/compliance-agent/src/api/handlers/issues.rs @@ -4,13 +4,16 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::TrackerIssue; +use compliance_core::tenant_ctx::TenantCtx; #[tracing::instrument(skip_all)] pub async fn list_issues( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .tracker_issues() diff --git a/compliance-agent/src/api/handlers/notifications.rs b/compliance-agent/src/api/handlers/notifications.rs index 2f82b8c..97d5f43 100644 --- a/compliance-agent/src/api/handlers/notifications.rs +++ b/compliance-agent/src/api/handlers/notifications.rs @@ -5,15 +5,18 @@ use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::notification::CveNotification; +use compliance_core::tenant_ctx::TenantCtx; -use super::dto::{AgentExt, ApiResponse}; +use super::dto::{tenant_db, AgentExt, ApiResponse}; /// GET /api/v1/notifications — List CVE notifications (newest first) #[tracing::instrument(skip_all)] pub async fn list_notifications( Extension(agent): AgentExt, + tenant: TenantCtx, axum::extract::Query(params): axum::extract::Query, ) -> Result>>, StatusCode> { + let db = tenant_db(&agent, &tenant).await?; let mut filter = doc! {}; // Filter by status (default: show new + read, exclude dismissed) @@ -41,15 +44,13 @@ pub async fn list_notifications( let limit = params.limit.unwrap_or(50).min(200); let skip = (page - 1) * limit as u64; - let total = agent - .db + let total = db .cve_notifications() .count_documents(filter.clone()) .await .unwrap_or(0); - let notifications: Vec = match agent - .db + let notifications: Vec = match db .cve_notifications() .find(filter) .sort(doc! { "created_at": -1 }) @@ -83,9 +84,10 @@ pub async fn list_notifications( #[tracing::instrument(skip_all)] pub async fn notification_count( Extension(agent): AgentExt, + tenant: TenantCtx, ) -> Result, StatusCode> { - let count = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let count = db .cve_notifications() .count_documents(doc! { "status": "new" }) .await @@ -98,12 +100,13 @@ pub async fn notification_count( #[tracing::instrument(skip_all, fields(id = %id))] pub async fn mark_read( Extension(agent): AgentExt, + tenant: TenantCtx, axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let result = agent - .db + let result = db .cve_notifications() .update_one( doc! { "_id": oid }, @@ -125,12 +128,13 @@ pub async fn mark_read( #[tracing::instrument(skip_all, fields(id = %id))] pub async fn dismiss_notification( Extension(agent): AgentExt, + tenant: TenantCtx, axum::extract::Path(id): axum::extract::Path, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let result = agent - .db + let result = db .cve_notifications() .update_one( doc! { "_id": oid }, @@ -149,9 +153,10 @@ pub async fn dismiss_notification( #[tracing::instrument(skip_all)] pub async fn mark_all_read( Extension(agent): AgentExt, + tenant: TenantCtx, ) -> Result, StatusCode> { - let result = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let result = db .cve_notifications() .update_many( doc! { "status": "new" }, diff --git a/compliance-agent/src/api/handlers/pentest_handlers/export.rs b/compliance-agent/src/api/handlers/pentest_handlers/export.rs index bad7ac3..60f75f0 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/export.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/export.rs @@ -13,10 +13,11 @@ use compliance_core::models::dast::DastFinding; use compliance_core::models::finding::Finding; use compliance_core::models::pentest::*; use compliance_core::models::sbom::SbomEntry; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; -use super::super::dto::collect_cursor_async; +use super::super::dto::{collect_cursor_async, tenant_db}; type AgentExt = Extension>; @@ -35,11 +36,15 @@ pub struct ExportBody { #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn export_session_report( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Json(body): Json, ) -> Result { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; if body.password.len() < 8 { return Err(( @@ -49,8 +54,7 @@ pub async fn export_session_report( } // Fetch session - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -64,9 +68,7 @@ pub async fn export_session_report( // Resolve target name let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) { - agent - .db - .dast_targets() + db.dast_targets() .find_one(doc! { "_id": tid }) .await .ok() @@ -84,8 +86,7 @@ pub async fn export_session_report( .unwrap_or_default(); // Fetch attack chain nodes - let nodes: Vec = match agent - .db + let nodes: Vec = match db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) @@ -96,8 +97,7 @@ pub async fn export_session_report( }; // Fetch DAST findings for this session, then deduplicate - let raw_findings: Vec = match agent - .db + let raw_findings: Vec = match db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "severity": -1, "created_at": -1 }) @@ -122,8 +122,7 @@ pub async fn export_session_report( .or_else(|| target.as_ref().and_then(|t| t.repo_id.clone())); let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id { - let sast: Vec = match agent - .db + let sast: Vec = match db .findings() .find(doc! { "repo_id": rid, @@ -143,8 +142,7 @@ pub async fn export_session_report( Err(_) => Vec::new(), }; - let sbom: Vec = match agent - .db + let sbom: Vec = match db .sbom_entries() .find(doc! { "repo_id": rid, @@ -164,8 +162,7 @@ pub async fn export_session_report( }; // Build code context from graph nodes - let code_ctx: Vec = match agent - .db + let code_ctx: Vec = match db .graph_nodes() .find(doc! { "repo_id": rid, "is_entry_point": true }) .limit(50) diff --git a/compliance-agent/src/api/handlers/pentest_handlers/session.rs b/compliance-agent/src/api/handlers/pentest_handlers/session.rs index 351ec01..91c8c58 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/session.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/session.rs @@ -7,11 +7,12 @@ use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::pentest::*; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; use crate::pentest::PentestOrchestrator; -use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams}; +use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse, PaginationParams}; type AgentExt = Extension>; @@ -43,6 +44,7 @@ pub struct LookupRepoQuery { #[tracing::instrument(skip_all)] pub async fn create_session( Extension(agent): AgentExt, + tenant: TenantCtx, Json(req): Json, ) -> Result>, (StatusCode, String)> { // Try to acquire a concurrency permit @@ -57,6 +59,10 @@ pub async fn create_session( ) })?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; + if let Some(ref config) = req.config { // ── Wizard path ────────────────────────────────────────────── if !config.disclaimer_accepted { @@ -67,8 +73,7 @@ pub async fn create_session( } // Look up or auto-create DastTarget by app_url - let target = match agent - .db + let target = match db .dast_targets() .find_one(doc! { "base_url": &config.app_url }) .await @@ -87,7 +92,7 @@ pub async fn create_session( } t.allow_destructive = config.allow_destructive; t.excluded_paths = config.scope_exclusions.clone(); - let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| { + let res = db.dast_targets().insert_one(&t).await.map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create target: {e}"), @@ -110,8 +115,7 @@ pub async fn create_session( // Resolve repo_id from git_repo_url if provided if let Some(ref git_url) = config.git_repo_url { - if let Ok(Some(repo)) = agent - .db + if let Ok(Some(repo)) = db .repositories() .find_one(doc! { "git_url": git_url }) .await @@ -120,8 +124,7 @@ pub async fn create_session( } } - let insert_result = agent - .db + let insert_result = db .pentest_sessions() .insert_one(&session) .await @@ -212,8 +215,7 @@ pub async fn create_session( // Persist encrypted credentials to DB if session_for_task.config.is_some() { if let Some(sid) = session.id { - let _ = agent - .db + let _ = db .pentest_sessions() .update_one( doc! { "_id": sid }, @@ -245,12 +247,13 @@ pub async fn create_session( }); let llm = agent.llm.clone(); - let db = agent.db.clone(); + let db_for_orchestrator = db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); let agent_ref = agent.clone(); tokio::spawn(async move { - let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); + let orchestrator = + PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx)); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; @@ -292,8 +295,7 @@ pub async fn create_session( ) })?; - let target = agent - .db + let target = db .dast_targets() .find_one(doc! { "_id": oid }) .await @@ -310,8 +312,7 @@ pub async fn create_session( let mut session = PentestSession::new(target_id, strategy); session.repo_id = target.repo_id.clone(); - let insert_result = agent - .db + let insert_result = db .pentest_sessions() .insert_one(&session) .await @@ -338,12 +339,13 @@ pub async fn create_session( }); let llm = agent.llm.clone(); - let db = agent.db.clone(); + let db_for_orchestrator = db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); let agent_ref = agent.clone(); tokio::spawn(async move { - let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); + let orchestrator = + PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx)); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; @@ -373,10 +375,11 @@ fn parse_strategy(s: &str) -> PentestStrategy { #[tracing::instrument(skip_all)] pub async fn lookup_repo( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result>, StatusCode> { - let repo = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let repo = db .repositories() .find_one(doc! { "git_url": ¶ms.url }) .await @@ -402,9 +405,11 @@ pub async fn lookup_repo( #[tracing::instrument(skip_all)] pub async fn list_sessions( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result>>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .pentest_sessions() @@ -438,12 +443,13 @@ pub async fn list_sessions( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let mut session = agent - .db + let mut session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -471,15 +477,18 @@ pub async fn get_session( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn send_message( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; // Verify session exists and is running - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -506,8 +515,7 @@ pub async fn send_message( ) })?; - let target = agent - .db + let target = db .dast_targets() .find_one(doc! { "_id": target_oid }) .await @@ -527,13 +535,13 @@ pub async fn send_message( // Store user message let session_id = id.clone(); let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); - let _ = agent.db.pentest_messages().insert_one(&user_msg).await; + let _ = db.pentest_messages().insert_one(&user_msg).await; let response_msg = user_msg.clone(); // Spawn orchestrator to continue the session let llm = agent.llm.clone(); - let db = agent.db.clone(); + let db_for_orchestrator = db.clone(); let message = req.message.clone(); // Use existing broadcast sender if available, otherwise create a new one @@ -548,7 +556,7 @@ pub async fn send_message( .unwrap_or_else(|| agent.register_session_stream(&session_id)); tokio::spawn(async move { - let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None); + let orchestrator = PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, None); orchestrator .run_session_guarded(&session, &target, &message) .await; @@ -565,13 +573,16 @@ pub async fn send_message( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn stop_session( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -590,9 +601,7 @@ pub async fn stop_session( )); } - agent - .db - .pentest_sessions() + db.pentest_sessions() .update_one( doc! { "_id": oid }, doc! { "$set": { @@ -612,8 +621,7 @@ pub async fn stop_session( // Clean up session resources agent.cleanup_session(&id); - let updated = agent - .db + let updated = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -641,13 +649,16 @@ pub async fn stop_session( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn pause_session( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -684,13 +695,16 @@ pub async fn pause_session( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn resume_session( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + let db = tenant_db(&agent, &tenant) + .await + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -727,12 +741,13 @@ pub async fn resume_session( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_attack_chain( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; - let nodes = match agent - .db + let nodes = match db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) @@ -757,21 +772,21 @@ pub async fn get_attack_chain( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_messages( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; - let total = agent - .db + let total = db .pentest_messages() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); - let messages = match agent - .db + let messages = match db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) @@ -797,21 +812,21 @@ pub async fn get_messages( #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session_findings( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; - let total = agent - .db + let total = db .dast_findings() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); - let findings = match agent - .db + let findings = match db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": -1 }) diff --git a/compliance-agent/src/api/handlers/pentest_handlers/stats.rs b/compliance-agent/src/api/handlers/pentest_handlers/stats.rs index 6333408..d849627 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/stats.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/stats.rs @@ -6,10 +6,11 @@ use axum::Json; use mongodb::bson::doc; use compliance_core::models::pentest::*; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; -use super::super::dto::{collect_cursor_async, ApiResponse}; +use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse}; type AgentExt = Extension>; @@ -17,8 +18,10 @@ type AgentExt = Extension>; #[tracing::instrument(skip_all)] pub async fn pentest_stats( Extension(agent): AgentExt, + tenant: TenantCtx, ) -> Result>, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let running_sessions = db .pentest_sessions() diff --git a/compliance-agent/src/api/handlers/pentest_handlers/stream.rs b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs index 015c288..6080470 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/stream.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs @@ -11,10 +11,11 @@ use tokio_stream::wrappers::BroadcastStream; use tokio_stream::StreamExt; use compliance_core::models::pentest::*; +use compliance_core::tenant_ctx::TenantCtx; use crate::agent::ComplianceAgent; -use super::super::dto::collect_cursor_async; +use super::super::dto::{collect_cursor_async, tenant_db}; type AgentExt = Extension>; @@ -25,13 +26,14 @@ type AgentExt = Extension>; #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn session_stream( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result>>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; // Verify session exists - let _session = agent - .db + let _session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await @@ -43,8 +45,7 @@ pub async fn session_stream( let mut initial_events: Vec> = Vec::new(); // Fetch recent messages for this session - let messages: Vec = match agent - .db + let messages: Vec = match db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) @@ -56,8 +57,7 @@ pub async fn session_stream( }; // Fetch recent attack chain nodes - let nodes: Vec = match agent - .db + let nodes: Vec = match db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) @@ -94,8 +94,7 @@ pub async fn session_stream( } // Add current session status event - let session = agent - .db + let session = db .pentest_sessions() .find_one(doc! { "_id": oid }) .await diff --git a/compliance-agent/src/api/handlers/repos.rs b/compliance-agent/src/api/handlers/repos.rs index 891bcd5..f4b2397 100644 --- a/compliance-agent/src/api/handlers/repos.rs +++ b/compliance-agent/src/api/handlers/repos.rs @@ -5,13 +5,16 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::*; +use compliance_core::tenant_ctx::TenantCtx; #[tracing::instrument(skip_all)] pub async fn list_repositories( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .repositories() @@ -43,6 +46,7 @@ pub async fn list_repositories( #[tracing::instrument(skip_all)] pub async fn add_repository( Extension(agent): AgentExt, + tenant: TenantCtx, Json(req): Json, ) -> Result>, (StatusCode, String)> { // Validate repository access before saving @@ -69,17 +73,15 @@ pub async fn add_repository( repo.tracker_token = req.tracker_token; repo.scan_schedule = req.scan_schedule; - agent - .db - .repositories() - .insert_one(&repo) + let db = tenant_db(&agent, &tenant) .await - .map_err(|_| { - ( - StatusCode::CONFLICT, - "Repository already exists".to_string(), - ) - })?; + .map_err(|s| (s, "failed to acquire tenant database".to_string()))?; + db.repositories().insert_one(&repo).await.map_err(|_| { + ( + StatusCode::CONFLICT, + "Repository already exists".to_string(), + ) + })?; Ok(Json(ApiResponse { data: repo, @@ -91,10 +93,12 @@ pub async fn add_repository( #[tracing::instrument(skip_all, fields(repo_id = %id))] pub async fn update_repository( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, Json(req): Json, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = tenant_db(&agent, &tenant).await?; let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() }; @@ -126,8 +130,7 @@ pub async fn update_repository( set_doc.insert("scan_schedule", schedule); } - let result = agent - .db + let result = db .repositories() .update_one(doc! { "_id": oid }, doc! { "$set": set_doc }) .await @@ -170,11 +173,12 @@ pub async fn trigger_scan( /// Return the webhook secret for a repository (used by dashboard to display it) pub async fn get_webhook_config( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let repo = agent - .db + let db = tenant_db(&agent, &tenant).await?; + let repo = db .repositories() .find_one(doc! { "_id": oid }) .await @@ -196,10 +200,12 @@ pub async fn get_webhook_config( #[tracing::instrument(skip_all, fields(repo_id = %id))] pub async fn delete_repository( Extension(agent): AgentExt, + tenant: TenantCtx, Path(id): Path, ) -> Result, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; // Delete the repository let result = db diff --git a/compliance-agent/src/api/handlers/sbom.rs b/compliance-agent/src/api/handlers/sbom.rs index e9ec8ff..dbafe84 100644 --- a/compliance-agent/src/api/handlers/sbom.rs +++ b/compliance-agent/src/api/handlers/sbom.rs @@ -6,6 +6,7 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::SbomEntry; +use compliance_core::tenant_ctx::TenantCtx; const COPYLEFT_LICENSES: &[&str] = &[ "GPL-2.0", @@ -29,8 +30,10 @@ const COPYLEFT_LICENSES: &[&str] = &[ #[tracing::instrument(skip_all)] pub async fn sbom_filters( Extension(agent): AgentExt, + tenant: TenantCtx, ) -> Result, StatusCode> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let managers: Vec = db .sbom_entries() @@ -61,9 +64,11 @@ pub async fn sbom_filters( #[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))] pub async fn list_sbom( Extension(agent): AgentExt, + tenant: TenantCtx, Query(filter): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let mut query = doc! {}; if let Some(repo_id) = &filter.repo_id { @@ -120,9 +125,11 @@ pub async fn list_sbom( #[tracing::instrument(skip_all)] pub async fn export_sbom( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> Result { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let entries: Vec = match db .sbom_entries() .find(doc! { "repo_id": ¶ms.repo_id }) @@ -236,9 +243,11 @@ pub async fn export_sbom( #[tracing::instrument(skip_all)] pub async fn license_summary( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let mut query = doc! {}; if let Some(repo_id) = ¶ms.repo_id { query.insert("repo_id", repo_id); @@ -285,9 +294,11 @@ pub async fn license_summary( #[tracing::instrument(skip_all)] pub async fn sbom_diff( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> ApiResult { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let entries_a: Vec = match db .sbom_entries() diff --git a/compliance-agent/src/api/handlers/scans.rs b/compliance-agent/src/api/handlers/scans.rs index ec16468..79553dd 100644 --- a/compliance-agent/src/api/handlers/scans.rs +++ b/compliance-agent/src/api/handlers/scans.rs @@ -4,13 +4,16 @@ use mongodb::bson::doc; use super::dto::*; use compliance_core::models::ScanRun; +use compliance_core::tenant_ctx::TenantCtx; #[tracing::instrument(skip_all)] pub async fn list_scan_runs( Extension(agent): AgentExt, + tenant: TenantCtx, Query(params): Query, ) -> ApiResult> { - let db = &agent.db; + let db = tenant_db(&agent, &tenant).await?; + let db = &db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0); diff --git a/compliance-agent/src/api/server.rs b/compliance-agent/src/api/server.rs index 8b65894..e8a3f82 100644 --- a/compliance-agent/src/api/server.rs +++ b/compliance-agent/src/api/server.rs @@ -1,6 +1,9 @@ use std::sync::Arc; +use axum::extract::Request; use axum::http::HeaderValue; +use axum::middleware::Next; +use axum::response::Response; use axum::{middleware, Extension}; use tokio::sync::RwLock; use tower_http::cors::CorsLayer; @@ -8,11 +11,44 @@ use tower_http::set_header::SetResponseHeaderLayer; use tower_http::trace::TraceLayer; use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState}; +use compliance_core::{TenantContext, TenantStatus}; use crate::agent::ComplianceAgent; use crate::api::routes; use crate::error::AgentError; +/// Synthetic tenant id used when Keycloak isn't configured (local dev, +/// `cargo run` against a bare Mongo). Lets the handler stack stay +/// uniformly tenant-scoped without the operator having to spin up KC +/// just to poke at the API. Override via `DEV_TENANT_ID`. +const DEFAULT_DEV_TENANT_ID: &str = "dev"; + +/// Inject a synthetic [`TenantContext`] for any request that lacks one. +/// Only mounted when Keycloak is NOT configured; with KC, the real +/// `require_jwt_auth` middleware owns this and we never reach here +/// without a context. +/// +/// Public so the integration-test harness can mount it without +/// duplicating the synthetic-context shape. +pub async fn inject_dev_tenant(mut request: Request, next: Next) -> Response { + if request.extensions().get::().is_none() { + let tenant_id = + std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string()); + let ctx = TenantContext { + tenant_slug: tenant_id.clone(), + tenant_id, + org_roles: vec![], + products: vec![], + plan: "dev".to_string(), + status: TenantStatus::Active, + user_id: "dev-user".to_string(), + user_name: None, + }; + request.extensions_mut().insert(ctx); + } + next.run(request).await +} + pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> { let mut app = routes::build_router() .layer(Extension(Arc::new(agent.clone()))) @@ -53,7 +89,14 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A .layer(middleware::from_fn(require_jwt_auth)) .layer(Extension(jwks_state)); } else { - tracing::warn!("Keycloak not configured - API endpoints are unprotected"); + let tenant_id = + std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string()); + tracing::warn!( + tenant_id = %tenant_id, + "Keycloak not configured — running unauthenticated against the dev tenant. \ + DO NOT use in any environment with real customer data." + ); + app = app.layer(middleware::from_fn(inject_dev_tenant)); } let addr = format!("0.0.0.0:{port}"); diff --git a/compliance-agent/tests/common/mod.rs b/compliance-agent/tests/common/mod.rs index 756ef34..cd1d307 100644 --- a/compliance-agent/tests/common/mod.rs +++ b/compliance-agent/tests/common/mod.rs @@ -75,9 +75,13 @@ impl TestServer { let agent = ComplianceAgent::new(config, db, db_pool); - // Build the router with the agent extension + // Build the router with the agent extension. After M7.2-B every + // handler takes a TenantCtx extractor; without KC in the test + // harness, the dev-tenant injector mounts a synthetic context so + // tests run end-to-end against `_dev`. let app = api::routes::build_router() .layer(axum::extract::Extension(Arc::new(agent))) + .layer(axum::middleware::from_fn(api::server::inject_dev_tenant)) .layer(tower_http::cors::CorsLayer::permissive()); // Bind to port 0 to get a random available port @@ -160,10 +164,20 @@ impl TestServer { &self.db_name } - /// Drop the test database on cleanup + /// Drop the test database on cleanup. Post-M7.2-B the actual data + /// lives in `_` per-tenant databases; list those + /// off the cluster and drop them too. pub async fn cleanup(&self) { if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await { client.database(&self.db_name).drop().await.ok(); + if let Ok(names) = client.list_database_names().await { + let prefix = format!("{}_", self.db_name); + for name in names { + if name.starts_with(&prefix) { + client.database(&name).drop().await.ok(); + } + } + } } } }