Compare commits

..

1 Commits

Author SHA1 Message Date
Sharang Parnerkar f474699279 fix(core): JWKS refresh-on-failure in M7.1 auth middleware
CI / Check (pull_request) Successful in 8m17s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
Without this, every Keycloak signing-key rotation produces a silent
401 storm against every request until the agent restarts — the cached
JWKS is held forever and never reconciled against KC.

Now: when `kid` isn't in the cached JWKS or the matching key fails
signature verification, we classify the failure as Stale, force a JWKS
refresh, and retry once. Anything else (expired, malformed, missing
tenant_id) is Permanent and short-circuits straight to 401.

* Splits the path into a pure `try_validate(token, header, kid, jwks)`
  helper returning a `ValidationError { Stale | Permanent }` enum.
* `fetch_or_get_jwks(state, force)` takes a force flag and holds the
  write lock across the network fetch so concurrent refreshers don't
  all hammer Keycloak when keys rotate (the second writer reuses what
  the first put in cache).
* Adds a unit test for the kid-not-found Stale classification.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-04 16:40:55 +02:00
26 changed files with 301 additions and 870 deletions
Generated
-1
View File
@@ -687,7 +687,6 @@ dependencies = [
"tokio-cron-scheduler",
"tokio-stream",
"tokio-tungstenite 0.26.2",
"tower",
"tower-http",
"tracing",
"tracing-subscriber",
+2 -3
View File
@@ -7,7 +7,7 @@ edition = "2021"
workspace = true
[dependencies]
compliance-core = { workspace = true, features = ["mongodb", "telemetry", "axum"] }
compliance-core = { workspace = true, features = ["mongodb", "telemetry"] }
compliance-graph = { path = "../compliance-graph" }
compliance-dast = { path = "../compliance-dast" }
serde = { workspace = true }
@@ -44,8 +44,7 @@ dashmap = { workspace = true }
tokio-stream = { workspace = true }
[dev-dependencies]
compliance-core = { workspace = true, features = ["mongodb", "axum"] }
tower = { version = "0.5", features = ["util"] }
compliance-core = { workspace = true, features = ["mongodb"] }
reqwest = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
+2 -9
View File
@@ -6,7 +6,7 @@ use tokio::sync::{broadcast, watch, Semaphore};
use compliance_core::models::pentest::PentestEvent;
use compliance_core::AgentConfig;
use crate::database::{Database, DatabasePool};
use crate::database::Database;
use crate::llm::LlmClient;
use crate::pipeline::orchestrator::PipelineOrchestrator;
@@ -16,13 +16,7 @@ const DEFAULT_MAX_CONCURRENT_SESSIONS: usize = 5;
#[derive(Clone)]
pub struct ComplianceAgent {
pub config: AgentConfig,
/// Transitional single-database handle. Used by handlers that have
/// not yet been migrated to `db_pool.for_tenant(&ctx)` (M7.2-B/C).
/// Will be removed once every call site is tenant-scoped (M7.2-D).
pub db: Database,
/// Per-tenant Mongo broker introduced in M7.2-A. Handlers should
/// prefer this and obtain a tenant-scoped [`Database`] from it.
pub db_pool: DatabasePool,
pub llm: Arc<LlmClient>,
pub http: reqwest::Client,
/// Per-session broadcast senders for SSE streaming.
@@ -34,7 +28,7 @@ pub struct ComplianceAgent {
}
impl ComplianceAgent {
pub fn new(config: AgentConfig, db: Database, db_pool: DatabasePool) -> Self {
pub fn new(config: AgentConfig, db: Database) -> Self {
let llm = Arc::new(LlmClient::new(
config.litellm_url.clone(),
config.litellm_api_key.clone(),
@@ -49,7 +43,6 @@ impl ComplianceAgent {
Self {
config,
db,
db_pool,
llm,
http,
session_streams: Arc::new(DashMap::new()),
+113
View File
@@ -0,0 +1,113 @@
use std::sync::Arc;
use axum::{
extract::Request,
middleware::Next,
response::{IntoResponse, Response},
};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use reqwest::StatusCode;
use serde::Deserialize;
use tokio::sync::RwLock;
/// Cached JWKS from Keycloak for token validation.
#[derive(Clone)]
pub struct JwksState {
pub jwks: Arc<RwLock<Option<JwkSet>>>,
pub jwks_url: String,
}
#[derive(Debug, Deserialize)]
struct Claims {
#[allow(dead_code)]
sub: String,
}
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
///
/// Skips validation for health check endpoints.
/// If `JwksState` is not present as an extension (keycloak not configured),
/// all requests pass through.
pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
let path = request.uri().path();
if PUBLIC_ENDPOINTS.contains(&path) {
return next.run(request).await;
}
let jwks_state = match request.extensions().get::<JwksState>() {
Some(s) => s.clone(),
None => return next.run(request).await,
};
let auth_header = match request.headers().get("authorization") {
Some(h) => h,
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
};
let token = match auth_header.to_str() {
Ok(s) if s.starts_with("Bearer ") => &s[7..],
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
};
match validate_token(token, &jwks_state).await {
Ok(()) => next.run(request).await,
Err(e) => {
tracing::warn!("JWT validation failed: {e}");
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
}
}
}
async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
let kid = header
.kid
.ok_or_else(|| "JWT missing kid header".to_string())?;
let jwks = fetch_or_get_jwks(state).await?;
let jwk = jwks
.keys
.iter()
.find(|k| k.common.key_id.as_deref() == Some(&kid))
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
let decoding_key =
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
let mut validation = Validation::new(header.alg);
validation.validate_exp = true;
validation.validate_aud = false;
decode::<Claims>(token, &decoding_key, &validation)
.map_err(|e| format!("token validation failed: {e}"))?;
Ok(())
}
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
{
let cached = state.jwks.read().await;
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
let resp = reqwest::get(&state.jwks_url)
.await
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
let jwks: JwkSet = resp
.json()
.await
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
let mut cached = state.jwks.write().await;
*cached = Some(jwks.clone());
Ok(jwks)
}
+26 -30
View File
@@ -7,13 +7,11 @@ use mongodb::bson::doc;
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
use compliance_core::models::embedding::EmbeddingBuildRun;
use compliance_core::tenant_ctx::TenantCtx;
use compliance_graph::graph::embedding_store::EmbeddingStore;
use crate::agent::ComplianceAgent;
use crate::rag::pipeline::RagPipeline;
use super::dto::tenant_db;
use super::ApiResponse;
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -22,12 +20,10 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn chat(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Json(req): Json<ChatRequest>,
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let pipeline = RagPipeline::new(agent.llm.clone(), db.inner());
let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
// Step 1: Embed the user's message
let query_vectors = agent
@@ -137,15 +133,12 @@ pub async fn chat(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn build_embeddings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
// Resolve the tenant DB up front so we can move it into the spawn;
// the JWT/dev context isn't available inside detached tasks.
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone();
tokio::spawn(async move {
let repo = match db
let repo = match agent_clone
.db
.repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await
@@ -158,7 +151,8 @@ pub async fn build_embeddings(
};
// Get latest graph build
let build = match db
let build = match agent_clone
.db
.graph_builds()
.find_one(doc! { "repo_id": &repo_id })
.sort(doc! { "started_at": -1 })
@@ -177,22 +171,26 @@ pub async fn build_embeddings(
.unwrap_or_else(|| "unknown".to_string());
// Get nodes
let nodes: Vec<compliance_core::models::graph::CodeNode> =
match db.graph_nodes().find(doc! { "repo_id": &repo_id }).await {
Ok(cursor) => {
use futures_util::StreamExt;
let mut items = Vec::new();
let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
}
items
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
.db
.graph_nodes()
.find(doc! { "repo_id": &repo_id })
.await
{
Ok(cursor) => {
use futures_util::StreamExt;
let mut items = Vec::new();
let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
}
Err(e) => {
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
return;
}
};
items
}
Err(e) => {
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
return;
}
};
let creds = crate::pipeline::git::RepoCredentials {
ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()),
@@ -209,7 +207,7 @@ pub async fn build_embeddings(
}
};
let pipeline = RagPipeline::new(agent_clone.llm.clone(), db.inner());
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
match pipeline
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
.await
@@ -236,11 +234,9 @@ pub async fn build_embeddings(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn embedding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let store = EmbeddingStore::new(db.inner());
let store = EmbeddingStore::new(agent.db.inner());
let build = store.get_latest_build(&repo_id).await.map_err(|e| {
tracing::error!("Failed to get embedding status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
+11 -20
View File
@@ -7,11 +7,9 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -47,11 +45,9 @@ fn default_rate_limit() -> u32 {
#[tracing::instrument(skip_all)]
pub async fn list_targets(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastTarget>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_targets()
@@ -84,7 +80,6 @@ pub async fn list_targets(
#[tracing::instrument(skip_all)]
pub async fn add_target(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddTargetRequest>,
) -> Result<Json<ApiResponse<DastTarget>>, StatusCode> {
let mut target = DastTarget::new(req.name, req.base_url, req.target_type);
@@ -94,8 +89,9 @@ pub async fn add_target(
target.rate_limit = req.rate_limit;
target.allow_destructive = req.allow_destructive;
let db = tenant_db(&agent, &tenant).await?;
db.dast_targets()
agent
.db
.dast_targets()
.insert_one(&target)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -111,19 +107,19 @@ pub async fn add_target(
#[tracing::instrument(skip_all, fields(target_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let target = db
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let db = agent.db.clone();
tokio::spawn(async move {
let orchestrator = compliance_dast::DastOrchestrator::new(100);
match orchestrator.run_scan(&target, Vec::new()).await {
@@ -151,11 +147,9 @@ pub async fn trigger_scan(
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastScanRun>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_scan_runs()
@@ -189,11 +183,9 @@ pub async fn list_scan_runs(
#[tracing::instrument(skip_all)]
pub async fn list_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_findings()
@@ -227,13 +219,12 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let finding = db
let finding = agent
.db
.dast_findings()
.find_one(doc! { "_id": oid })
.await
-21
View File
@@ -180,27 +180,6 @@ pub struct SbomVersionDiff {
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
/// Resolve a tenant-scoped [`Database`] from the request's
/// [`TenantContext`] (inserted by the M7.1 JWT middleware, or by the
/// dev fallback in unsecured environments). The pool ensures the
/// tenant's indexes idempotently.
///
/// Returns 500 on the rare path where Mongo refuses the database
/// handle — the M7.1 auth/status middleware already rejects every
/// other failure mode with 4xx before we get here.
pub(crate) async fn tenant_db(
agent: &crate::agent::ComplianceAgent,
tenant: &compliance_core::tenant_ctx::TenantCtx,
) -> Result<crate::database::Database, axum::http::StatusCode> {
agent.db_pool.for_tenant(&tenant.0).await.map_err(|e| {
tracing::error!(
tenant_id = %tenant.0.tenant_id,
"Failed to acquire tenant database: {e}"
);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
})
}
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
+11 -16
View File
@@ -5,16 +5,13 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::Finding;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
@@ -84,12 +81,11 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let finding = db
let finding = agent
.db
.findings()
.find_one(doc! { "_id": oid })
.await
@@ -106,14 +102,14 @@ pub async fn get_finding(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
db.findings()
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
@@ -127,7 +123,6 @@ pub async fn update_finding_status(
#[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req
@@ -140,8 +135,8 @@ pub async fn bulk_update_finding_status(
return Err(StatusCode::BAD_REQUEST);
}
let db = tenant_db(&agent, &tenant).await?;
let result = db
let result = agent
.db
.findings()
.update_many(
doc! { "_id": { "$in": oids } },
@@ -158,14 +153,14 @@ pub async fn bulk_update_finding_status(
#[tracing::instrument(skip_all)]
pub async fn update_finding_feedback(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
db.findings()
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
+10 -24
View File
@@ -7,11 +7,9 @@ use mongodb::bson::doc;
use serde::{Deserialize, Serialize};
use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -38,11 +36,9 @@ fn default_search_limit() -> usize {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_graph(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<GraphData>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
// Get latest build
let build: Option<GraphBuildRun> = db
@@ -102,11 +98,9 @@ pub async fn get_graph(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_nodes(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -129,11 +123,9 @@ pub async fn get_nodes(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_communities(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CommunityInfo>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -184,11 +176,9 @@ pub struct CommunityInfo {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))]
pub async fn get_impact(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path((repo_id, finding_id)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Option<ImpactAnalysis>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id };
let impact = db
@@ -208,12 +198,10 @@ pub async fn get_impact(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))]
pub async fn search_symbols(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Query(params): Query<SearchParams>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
// Simple text search on qualified_name and name fields
let filter = doc! {
@@ -246,12 +234,10 @@ pub async fn search_symbols(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_file_content(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Query(params): Query<FileContentParams>,
) -> Result<Json<ApiResponse<FileContent>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
// Look up the repository to get repo name
let repo = db
@@ -310,13 +296,12 @@ pub struct FileContent {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn trigger_build(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone();
tokio::spawn(async move {
let repo = match db
let repo = match agent_clone
.db
.repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await
@@ -348,7 +333,8 @@ pub async fn trigger_build(
match engine.build_graph(&repo_path, &repo_id, &graph_build_id) {
Ok((code_graph, build_run)) => {
let store = compliance_graph::graph::persistence::GraphStore::new(db.inner());
let store =
compliance_graph::graph::persistence::GraphStore::new(agent_clone.db.inner());
let _ = store.delete_repo_graph(&repo_id).await;
let _ = store
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
+2 -7
View File
@@ -3,7 +3,6 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> {
@@ -11,12 +10,8 @@ pub async fn health() -> Json<serde_json::Value> {
}
#[tracing::instrument(skip_all)]
pub async fn stats_overview(
axum::extract::Extension(agent): AgentExt,
tenant: TenantCtx,
) -> ApiResult<OverviewStats> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
let db = &agent.db;
let total_repositories = db
.repositories()
+1 -4
View File
@@ -4,16 +4,13 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::TrackerIssue;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_issues(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.tracker_issues()
@@ -5,18 +5,15 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::notification::CveNotification;
use compliance_core::tenant_ctx::TenantCtx;
use super::dto::{tenant_db, AgentExt, ApiResponse};
use super::dto::{AgentExt, ApiResponse};
/// GET /api/v1/notifications — List CVE notifications (newest first)
#[tracing::instrument(skip_all)]
pub async fn list_notifications(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Query(params): axum::extract::Query<NotificationFilter>,
) -> Result<Json<ApiResponse<Vec<CveNotification>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let mut filter = doc! {};
// Filter by status (default: show new + read, exclude dismissed)
@@ -44,13 +41,15 @@ pub async fn list_notifications(
let limit = params.limit.unwrap_or(50).min(200);
let skip = (page - 1) * limit as u64;
let total = db
let total = agent
.db
.cve_notifications()
.count_documents(filter.clone())
.await
.unwrap_or(0);
let notifications: Vec<CveNotification> = match db
let notifications: Vec<CveNotification> = match agent
.db
.cve_notifications()
.find(filter)
.sort(doc! { "created_at": -1 })
@@ -84,10 +83,9 @@ pub async fn list_notifications(
#[tracing::instrument(skip_all)]
pub async fn notification_count(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let count = db
let count = agent
.db
.cve_notifications()
.count_documents(doc! { "status": "new" })
.await
@@ -100,13 +98,12 @@ pub async fn notification_count(
#[tracing::instrument(skip_all, fields(id = %id))]
pub async fn mark_read(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = db
let result = agent
.db
.cve_notifications()
.update_one(
doc! { "_id": oid },
@@ -128,13 +125,12 @@ pub async fn mark_read(
#[tracing::instrument(skip_all, fields(id = %id))]
pub async fn dismiss_notification(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = db
let result = agent
.db
.cve_notifications()
.update_one(
doc! { "_id": oid },
@@ -153,10 +149,9 @@ pub async fn dismiss_notification(
#[tracing::instrument(skip_all)]
pub async fn mark_all_read(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let result = db
let result = agent
.db
.cve_notifications()
.update_many(
doc! { "status": "new" },
@@ -13,11 +13,10 @@ use compliance_core::models::dast::DastFinding;
use compliance_core::models::finding::Finding;
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db};
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -36,15 +35,11 @@ pub struct ExportBody {
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if body.password.len() < 8 {
return Err((
@@ -54,7 +49,8 @@ pub async fn export_session_report(
}
// Fetch session
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -68,7 +64,9 @@ pub async fn export_session_report(
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
db.dast_targets()
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
@@ -86,7 +84,8 @@ pub async fn export_session_report(
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match db
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -97,7 +96,8 @@ pub async fn export_session_report(
};
// Fetch DAST findings for this session, then deduplicate
let raw_findings: Vec<DastFinding> = match db
let raw_findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
@@ -122,7 +122,8 @@ pub async fn export_session_report(
.or_else(|| target.as_ref().and_then(|t| t.repo_id.clone()));
let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id {
let sast: Vec<Finding> = match db
let sast: Vec<Finding> = match agent
.db
.findings()
.find(doc! {
"repo_id": rid,
@@ -142,7 +143,8 @@ pub async fn export_session_report(
Err(_) => Vec::new(),
};
let sbom: Vec<SbomEntry> = match db
let sbom: Vec<SbomEntry> = match agent
.db
.sbom_entries()
.find(doc! {
"repo_id": rid,
@@ -162,7 +164,8 @@ pub async fn export_session_report(
};
// Build code context from graph nodes
let code_ctx: Vec<CodeContextHint> = match db
let code_ctx: Vec<CodeContextHint> = match agent
.db
.graph_nodes()
.find(doc! { "repo_id": rid, "is_entry_point": true })
.limit(50)
@@ -7,12 +7,11 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse, PaginationParams};
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -44,7 +43,6 @@ pub struct LookupRepoQuery {
#[tracing::instrument(skip_all)]
pub async fn create_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
// Try to acquire a concurrency permit
@@ -59,10 +57,6 @@ pub async fn create_session(
)
})?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if let Some(ref config) = req.config {
// ── Wizard path ──────────────────────────────────────────────
if !config.disclaimer_accepted {
@@ -73,7 +67,8 @@ pub async fn create_session(
}
// Look up or auto-create DastTarget by app_url
let target = match db
let target = match agent
.db
.dast_targets()
.find_one(doc! { "base_url": &config.app_url })
.await
@@ -92,7 +87,7 @@ pub async fn create_session(
}
t.allow_destructive = config.allow_destructive;
t.excluded_paths = config.scope_exclusions.clone();
let res = db.dast_targets().insert_one(&t).await.map_err(|e| {
let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create target: {e}"),
@@ -115,7 +110,8 @@ pub async fn create_session(
// Resolve repo_id from git_repo_url if provided
if let Some(ref git_url) = config.git_repo_url {
if let Ok(Some(repo)) = db
if let Ok(Some(repo)) = agent
.db
.repositories()
.find_one(doc! { "git_url": git_url })
.await
@@ -124,7 +120,8 @@ pub async fn create_session(
}
}
let insert_result = db
let insert_result = agent
.db
.pentest_sessions()
.insert_one(&session)
.await
@@ -215,7 +212,8 @@ pub async fn create_session(
// Persist encrypted credentials to DB
if session_for_task.config.is_some() {
if let Some(sid) = session.id {
let _ = db
let _ = agent
.db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
@@ -247,13 +245,12 @@ pub async fn create_session(
});
let llm = agent.llm.clone();
let db_for_orchestrator = db.clone();
let db = agent.db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
let agent_ref = agent.clone();
tokio::spawn(async move {
let orchestrator =
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message)
.await;
@@ -295,7 +292,8 @@ pub async fn create_session(
)
})?;
let target = db
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
@@ -312,7 +310,8 @@ pub async fn create_session(
let mut session = PentestSession::new(target_id, strategy);
session.repo_id = target.repo_id.clone();
let insert_result = db
let insert_result = agent
.db
.pentest_sessions()
.insert_one(&session)
.await
@@ -339,13 +338,12 @@ pub async fn create_session(
});
let llm = agent.llm.clone();
let db_for_orchestrator = db.clone();
let db = agent.db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
let agent_ref = agent.clone();
tokio::spawn(async move {
let orchestrator =
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message)
.await;
@@ -375,11 +373,10 @@ fn parse_strategy(s: &str) -> PentestStrategy {
#[tracing::instrument(skip_all)]
pub async fn lookup_repo(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<LookupRepoQuery>,
) -> Result<Json<ApiResponse<serde_json::Value>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let repo = db
let repo = agent
.db
.repositories()
.find_one(doc! { "git_url": &params.url })
.await
@@ -405,11 +402,9 @@ pub async fn lookup_repo(
#[tracing::instrument(skip_all)]
pub async fn list_sessions(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.pentest_sessions()
@@ -443,13 +438,12 @@ pub async fn list_sessions(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut session = db
let mut session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -477,18 +471,15 @@ pub async fn get_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
// Verify session exists and is running
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -515,7 +506,8 @@ pub async fn send_message(
)
})?;
let target = db
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": target_oid })
.await
@@ -535,13 +527,13 @@ pub async fn send_message(
// Store user message
let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = db.pentest_messages().insert_one(&user_msg).await;
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session
let llm = agent.llm.clone();
let db_for_orchestrator = db.clone();
let db = agent.db.clone();
let message = req.message.clone();
// Use existing broadcast sender if available, otherwise create a new one
@@ -556,7 +548,7 @@ pub async fn send_message(
.unwrap_or_else(|| agent.register_session_stream(&session_id));
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, None);
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None);
orchestrator
.run_session_guarded(&session, &target, &message)
.await;
@@ -573,16 +565,13 @@ pub async fn send_message(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -601,7 +590,9 @@ pub async fn stop_session(
));
}
db.pentest_sessions()
agent
.db
.pentest_sessions()
.update_one(
doc! { "_id": oid },
doc! { "$set": {
@@ -621,7 +612,8 @@ pub async fn stop_session(
// Clean up session resources
agent.cleanup_session(&id);
let updated = db
let updated = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -649,16 +641,13 @@ pub async fn stop_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn pause_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -695,16 +684,13 @@ pub async fn pause_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn resume_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -741,13 +727,12 @@ pub async fn resume_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let nodes = match db
let nodes = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -772,21 +757,21 @@ pub async fn get_attack_chain(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
let total = agent
.db
.pentest_messages()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let messages = match db
let messages = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
@@ -812,21 +797,21 @@ pub async fn get_messages(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
let total = agent
.db
.dast_findings()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let findings = match db
let findings = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 })
@@ -6,11 +6,10 @@ use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse};
use super::super::dto::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -18,10 +17,8 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
@@ -11,11 +11,10 @@ use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db};
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -26,14 +25,13 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
// Verify session exists
let _session = db
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -45,7 +43,8 @@ pub async fn session_stream(
let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new();
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match db
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
@@ -57,7 +56,8 @@ pub async fn session_stream(
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match db
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -94,7 +94,8 @@ pub async fn session_stream(
}
// Add current session status event
let session = db
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
+16 -22
View File
@@ -5,16 +5,13 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::*;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_repositories(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.repositories()
@@ -46,7 +43,6 @@ pub async fn list_repositories(
#[tracing::instrument(skip_all)]
pub async fn add_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
// Validate repository access before saving
@@ -73,15 +69,17 @@ pub async fn add_repository(
repo.tracker_token = req.tracker_token;
repo.scan_schedule = req.scan_schedule;
let db = tenant_db(&agent, &tenant)
agent
.db
.repositories()
.insert_one(&repo)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
db.repositories().insert_one(&repo).await.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: repo,
@@ -93,12 +91,10 @@ pub async fn add_repository(
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
@@ -130,7 +126,8 @@ pub async fn update_repository(
set_doc.insert("scan_schedule", schedule);
}
let result = db
let result = agent
.db
.repositories()
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
.await
@@ -173,12 +170,11 @@ pub async fn trigger_scan(
/// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let repo = db
let repo = agent
.db
.repositories()
.find_one(doc! { "_id": oid })
.await
@@ -200,12 +196,10 @@ pub async fn get_webhook_config(
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
// Delete the repository
let result = db
+5 -16
View File
@@ -6,7 +6,6 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0",
@@ -30,10 +29,8 @@ const COPYLEFT_LICENSES: &[&str] = &[
#[tracing::instrument(skip_all)]
pub async fn sbom_filters(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let managers: Vec<String> = db
.sbom_entries()
@@ -64,11 +61,9 @@ pub async fn sbom_filters(
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
@@ -125,11 +120,9 @@ pub async fn list_sbom(
#[tracing::instrument(skip_all)]
pub async fn export_sbom(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let entries: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_id })
@@ -243,11 +236,9 @@ pub async fn export_sbom(
#[tracing::instrument(skip_all)]
pub async fn license_summary(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id);
@@ -294,11 +285,9 @@ pub async fn license_summary(
#[tracing::instrument(skip_all)]
pub async fn sbom_diff(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let entries_a: Vec<SbomEntry> = match db
.sbom_entries()
+1 -4
View File
@@ -4,16 +4,13 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
+1
View File
@@ -1,3 +1,4 @@
pub mod auth_middleware;
pub mod handlers;
pub mod routes;
pub mod server;
+4 -52
View File
@@ -1,54 +1,17 @@
use std::sync::Arc;
use axum::extract::Request;
use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;
use axum::{middleware, Extension};
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;
use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState};
use compliance_core::{TenantContext, TenantStatus};
use crate::agent::ComplianceAgent;
use crate::api::auth_middleware::{require_jwt_auth, JwksState};
use crate::api::routes;
use crate::error::AgentError;
/// Synthetic tenant id used when Keycloak isn't configured (local dev,
/// `cargo run` against a bare Mongo). Lets the handler stack stay
/// uniformly tenant-scoped without the operator having to spin up KC
/// just to poke at the API. Override via `DEV_TENANT_ID`.
const DEFAULT_DEV_TENANT_ID: &str = "dev";
/// Inject a synthetic [`TenantContext`] for any request that lacks one.
/// Only mounted when Keycloak is NOT configured; with KC, the real
/// `require_jwt_auth` middleware owns this and we never reach here
/// without a context.
///
/// Public so the integration-test harness can mount it without
/// duplicating the synthetic-context shape.
pub async fn inject_dev_tenant(mut request: Request, next: Next) -> Response {
if request.extensions().get::<TenantContext>().is_none() {
let tenant_id =
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
let ctx = TenantContext {
tenant_slug: tenant_id.clone(),
tenant_id,
org_roles: vec![],
products: vec![],
plan: "dev".to_string(),
status: TenantStatus::Active,
user_id: "dev-user".to_string(),
user_name: None,
};
request.extensions_mut().insert(ctx);
}
next.run(request).await
}
pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> {
let mut app = routes::build_router()
.layer(Extension(Arc::new(agent.clone())))
@@ -81,22 +44,11 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
jwks_url,
};
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
// Layers execute outermost-first. Extension(jwks_state) must run
// before require_jwt_auth so the middleware can read it; the
// status gate runs after JWT so TenantContext is in extensions.
app = app
.layer(middleware::from_fn(require_tenant_status))
.layer(middleware::from_fn(require_jwt_auth))
.layer(Extension(jwks_state));
.layer(Extension(jwks_state))
.layer(middleware::from_fn(require_jwt_auth));
} else {
let tenant_id =
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
tracing::warn!(
tenant_id = %tenant_id,
"Keycloak not configured — running unauthenticated against the dev tenant. \
DO NOT use in any environment with real customer data."
);
app = app.layer(middleware::from_fn(inject_dev_tenant));
tracing::warn!("Keycloak not configured - API endpoints are unprotected");
}
let addr = format!("0.0.0.0:{port}");
-146
View File
@@ -1,151 +1,11 @@
use std::sync::Arc;
use dashmap::DashMap;
use mongodb::bson::doc;
use mongodb::options::IndexOptions;
use mongodb::{Client, Collection, IndexModel};
use sha2::{Digest, Sha256};
use compliance_core::models::*;
use compliance_core::TenantContext;
use crate::error::AgentError;
/// Mongo enforces a 63-byte cap on database names (older clusters: 64
/// on Linux, 63 on Windows; we target the conservative limit).
const MAX_DB_NAME_LEN: usize = 63;
/// Hex length of the SHA-256 truncation used for the hash fallback
/// tenant DB name (16 bytes → 32 hex chars). 16 bytes gives ~2^64
/// birthday-collision resistance — at our 10s-100s tenant scale this
/// is effectively impossible to hit.
const HASH_HEX_LEN: usize = 32;
/// Largest `db_prefix` that still guarantees the hash-fallback name
/// fits in the 63-byte cap: `prefix + "_" + 32 hex chars`.
const MAX_PREFIX_LEN: usize = MAX_DB_NAME_LEN - 1 - HASH_HEX_LEN;
/// Per-tenant Mongo connection broker (M7.2 isolation model).
///
/// Holds one [`Client`] and hands out [`Database`] handles physically
/// scoped to `<db_prefix>_<tenant_id>`. The driver is the isolation
/// boundary — a handle for tenant A cannot see tenant B's documents
/// because it is connected to a different database, not because of an
/// application-level filter.
///
/// Index creation runs idempotently the first time each tenant is seen
/// in the process's lifetime. Mongo's `createIndex` is itself idempotent
/// by index name; the in-memory `ensured` set just skips the round-trip.
#[derive(Clone, Debug)]
pub struct DatabasePool {
client: Client,
db_prefix: String,
ensured: Arc<DashMap<String, ()>>,
}
impl DatabasePool {
/// Connect to the cluster and prepare to hand out tenant databases
/// named `<db_prefix>_<tenant_id>`.
///
/// Validates `db_prefix.len() <= MAX_PREFIX_LEN` so the
/// hash-fallback path is provably within Mongo's 63-byte db-name
/// cap. Refuses to construct a pool that could ever produce an
/// over-long name.
pub async fn connect(uri: &str, db_prefix: &str) -> Result<Self, AgentError> {
if db_prefix.len() > MAX_PREFIX_LEN {
return Err(AgentError::Other(format!(
"db_prefix '{db_prefix}' is {} chars; max is {MAX_PREFIX_LEN} so the \
hash-fallback tenant DB name fits Mongo's {MAX_DB_NAME_LEN}-byte cap",
db_prefix.len()
)));
}
let client = Client::with_uri_str(uri).await?;
client
.database("admin")
.run_command(doc! { "ping": 1 })
.await?;
tracing::info!(
"MongoDB cluster reachable; per-tenant pool ready (db prefix '{db_prefix}')"
);
Ok(Self {
client,
db_prefix: db_prefix.to_string(),
ensured: Arc::new(DashMap::new()),
})
}
/// Return a [`Database`] scoped to this tenant. Ensures indexes on
/// first call per tenant (per process). Cheap on the hot path —
/// subsequent calls skip the round-trip.
pub async fn for_tenant(&self, ctx: &TenantContext) -> Result<Database, AgentError> {
let db_name = self.tenant_db_name(&ctx.tenant_id);
let db = Database::from_database(self.client.database(&db_name));
// `DashMap::insert` returns the previous value; `None` means we
// were the first writer for this tenant_id and own the
// index-ensure work.
if self.ensured.insert(ctx.tenant_id.clone(), ()).is_none() {
if let Err(e) = db.ensure_indexes().await {
// Roll the marker back so the next request retries.
self.ensured.remove(&ctx.tenant_id);
return Err(e);
}
tracing::debug!(
tenant_id = %ctx.tenant_id,
db_name = %db_name,
"Indexes ensured for tenant database"
);
}
Ok(db)
}
/// Compute the Mongo database name for a tenant. Public for tests
/// and tenant offboarding (`pool.client().database(name).drop()`).
///
/// Format: `<prefix>_<sanitized_tenant_id>` if it fits the 63-byte
/// cap, else `<prefix>_<sha256-16-byte-hex-of-tenant_id>`. The
/// `db_prefix` length invariant established at [`Self::connect`]
/// guarantees the hash-fallback name always fits — no runtime
/// assertion needed.
///
/// Collision resistance: the hash fallback is a 16-byte SHA-256
/// truncation, which gives ~2^64 birthday-collision resistance. At
/// our 10s100s tenant scale the probability of two tenant_ids
/// colliding is effectively zero. (8-byte truncation would have
/// been ~2^32 — too close for comfort on a regulated product.)
pub fn tenant_db_name(&self, tenant_id: &str) -> String {
let sanitized = sanitize_tenant_id(tenant_id);
let natural = format!("{}_{}", self.db_prefix, sanitized);
if natural.len() <= MAX_DB_NAME_LEN {
natural
} else {
let mut hasher = Sha256::new();
hasher.update(tenant_id.as_bytes());
let digest = hasher.finalize();
let suffix = hex::encode(&digest[..HASH_HEX_LEN / 2]);
format!("{}_{}", self.db_prefix, suffix)
}
}
/// Raw client handle. Reserved for cross-tenant admin flows that
/// must opt in explicitly (tenant listing, drop-on-offboard).
pub fn client(&self) -> &Client {
&self.client
}
}
/// Mongo database names disallow `/`, `\`, `.`, `"`, `$`, ` `, and NUL.
/// breakpilot-dev tenant_ids are UUIDs so this is belt-and-braces, but
/// it lets the pool tolerate any future tenant_id shape without surprise.
fn sanitize_tenant_id(tenant_id: &str) -> String {
tenant_id
.chars()
.map(|c| match c {
'/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_',
c => c,
})
.collect()
}
#[derive(Clone, Debug)]
pub struct Database {
inner: mongodb::Database,
@@ -160,12 +20,6 @@ impl Database {
Ok(Self { inner: db })
}
/// Wrap an already-resolved Mongo database. Used by [`DatabasePool`]
/// to hand out tenant-scoped handles without a fresh client per tenant.
pub(crate) fn from_database(inner: mongodb::Database) -> Self {
Self { inner }
}
pub async fn ensure_indexes(&self) -> Result<(), AgentError> {
// repositories: unique git_url
self.repositories()
+1 -7
View File
@@ -28,13 +28,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db = database::Database::connect(&config.mongodb_uri, &config.mongodb_database).await?;
db.ensure_indexes().await?;
// M7.2-A: per-tenant pool. Uses `mongodb_database` as the db-name
// prefix so tenant databases land as `<prefix>_<tenant_id>` next to
// the legacy single-tenant database.
let db_pool =
database::DatabasePool::connect(&config.mongodb_uri, &config.mongodb_database).await?;
let agent = agent::ComplianceAgent::new(config.clone(), db.clone(), db_pool);
let agent = agent::ComplianceAgent::new(config.clone(), db.clone());
tracing::info!("Starting scheduler...");
let scheduler_agent = agent.clone();
+4 -22
View File
@@ -7,7 +7,7 @@ use std::sync::Arc;
use compliance_agent::agent::ComplianceAgent;
use compliance_agent::api;
use compliance_agent::database::{Database, DatabasePool};
use compliance_agent::database::Database;
use compliance_core::AgentConfig;
use secrecy::SecretString;
@@ -33,10 +33,6 @@ impl TestServer {
.expect("Failed to connect to MongoDB — is it running?");
db.ensure_indexes().await.expect("Failed to create indexes");
let db_pool = DatabasePool::connect(&mongodb_uri, &db_name)
.await
.expect("Failed to build DatabasePool");
let config = AgentConfig {
mongodb_uri: mongodb_uri.clone(),
mongodb_database: db_name.clone(),
@@ -73,15 +69,11 @@ impl TestServer {
pentest_imap_password: None,
};
let agent = ComplianceAgent::new(config, db, db_pool);
let agent = ComplianceAgent::new(config, db);
// Build the router with the agent extension. After M7.2-B every
// handler takes a TenantCtx extractor; without KC in the test
// harness, the dev-tenant injector mounts a synthetic context so
// tests run end-to-end against `<db_name>_dev`.
// Build the router with the agent extension
let app = api::routes::build_router()
.layer(axum::extract::Extension(Arc::new(agent)))
.layer(axum::middleware::from_fn(api::server::inject_dev_tenant))
.layer(tower_http::cors::CorsLayer::permissive());
// Bind to port 0 to get a random available port
@@ -164,20 +156,10 @@ impl TestServer {
&self.db_name
}
/// Drop the test database on cleanup. Post-M7.2-B the actual data
/// lives in `<db_name>_<tenant>` per-tenant databases; list those
/// off the cluster and drop them too.
/// Drop the test database on cleanup
pub async fn cleanup(&self) {
if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await {
client.database(&self.db_name).drop().await.ok();
if let Ok(names) = client.list_database_names().await {
let prefix = format!("{}_", self.db_name);
for name in names {
if name.starts_with(&prefix) {
client.database(&name).drop().await.ok();
}
}
}
}
}
}
-234
View File
@@ -1,234 +0,0 @@
//! M7.2-A — `DatabasePool` isolation proof.
//!
//! Two `TenantContext`s, two databases, one client. Insert on A, query
//! on B → empty. Insert on B, query on A → only A's docs. Proves that
//! the per-tenant database split actually isolates at the driver level
//! and not at "we hope we filter."
//!
//! Requires MongoDB. Set `TEST_MONGODB_URI` to override the default
//! `mongodb://root:example@localhost:27017/?authSource=admin`.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use compliance_agent::database::DatabasePool;
use compliance_core::models::TrackedRepository;
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use mongodb::bson::doc;
fn ctx(tenant_id: &str, slug: &str) -> TenantContext {
TenantContext {
tenant_id: tenant_id.to_string(),
tenant_slug: slug.to_string(),
org_roles: vec![OrgRole::ItAdmin],
products: vec!["compliance-scanner".to_string()],
plan: "starter".to_string(),
status: TenantStatus::Active,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn fixture_repo(name: &str, git_url: &str) -> TrackedRepository {
TrackedRepository {
id: None,
name: name.to_string(),
git_url: git_url.to_string(),
default_branch: "main".to_string(),
local_path: None,
scan_schedule: None,
webhook_enabled: false,
webhook_secret: None,
tracker_type: None,
tracker_owner: None,
tracker_repo: None,
tracker_token: None,
auth_token: None,
auth_username: None,
last_scanned_commit: None,
findings_count: 0,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
}
}
#[tokio::test]
async fn pool_isolates_tenants_at_driver_level() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// Unique per run so parallel test invocations don't collide. Kept
// short because Mongo caps db names at 63 bytes (prefix + tenant_id).
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix)
.await
.expect("Failed to connect to MongoDB — is it running?");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
let globex = ctx("00000000-0000-0000-0000-0000globex000", "globex");
let acme_db = pool.for_tenant(&acme).await.expect("acme db");
let globex_db = pool.for_tenant(&globex).await.expect("globex db");
// Write distinct repos into each tenant's database.
acme_db
.repositories()
.insert_one(fixture_repo("acme-app", "git@example.com:acme/app.git"))
.await
.expect("insert acme");
globex_db
.repositories()
.insert_one(fixture_repo(
"globex-platform",
"git@example.com:globex/platform.git",
))
.await
.expect("insert globex");
// The point of the whole exercise: acme can ONLY see acme's repo
// and globex can ONLY see globex's, with no filter doc anywhere
// because the isolation is at the database handle, not in the query.
let acme_seen = collect(&acme_db).await;
let globex_seen = collect(&globex_db).await;
assert_eq!(acme_seen.len(), 1, "acme should see exactly its own repo");
assert_eq!(acme_seen[0].name, "acme-app");
assert_eq!(
globex_seen.len(),
1,
"globex should see exactly its own repo"
);
assert_eq!(globex_seen[0].name, "globex-platform");
// Sanity: the two databases really are different by name.
let acme_db_name = pool.tenant_db_name(&acme.tenant_id);
let globex_db_name = pool.tenant_db_name(&globex.tenant_id);
assert_ne!(acme_db_name, globex_db_name);
assert!(acme_db_name.starts_with(&prefix));
// Cleanup — drop both per-tenant databases.
pool.client()
.database(&acme_db_name)
.drop()
.await
.expect("drop acme");
pool.client()
.database(&globex_db_name)
.drop()
.await
.expect("drop globex");
}
#[tokio::test]
async fn for_tenant_is_idempotent_index_creation() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix).await.expect("connect");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
// Second call must not fail (ensure_indexes already ran, in-memory
// marker is set, Mongo's createIndex is idempotent by name anyway).
let _ = pool.for_tenant(&acme).await.expect("first call");
let _ = pool.for_tenant(&acme).await.expect("second call");
let _ = pool.for_tenant(&acme).await.expect("third call");
// Cleanup
let db_name = pool.tenant_db_name(&acme.tenant_id);
pool.client().database(&db_name).drop().await.expect("drop");
}
#[tokio::test]
async fn tenant_db_name_sanitizes_unsafe_characters() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_sanitize")
.await
.expect("connect");
// Mongo db names cannot contain `/ \ . " $ <space> NUL`. The pool
// must rewrite these without exploding on connect.
let funky = "te/n.a\\nt$id\" with spaces";
let name = pool.tenant_db_name(funky);
for c in ['/', '\\', '.', '"', '$', ' '] {
assert!(
!name.contains(c),
"sanitized db name still contains {c:?}: {name}"
);
}
}
#[tokio::test]
async fn tenant_db_name_falls_back_to_hash_when_too_long() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_long")
.await
.expect("connect");
// 100-byte tenant_id would overflow the 63-byte db-name cap with
// any reasonable prefix. The pool must hash it down.
let huge = "x".repeat(100);
let name = pool.tenant_db_name(&huge);
assert!(name.len() <= 63, "hashed name should fit: {name}");
assert!(name.starts_with("m72a_long_"));
// The hash suffix is 32 hex chars (16-byte SHA-256 truncation).
let suffix = name.trim_start_matches("m72a_long_");
assert_eq!(
suffix.len(),
32,
"expected 32-hex suffix (16-byte hash), got {suffix:?}"
);
assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
// Stable: same input → same output.
assert_eq!(name, pool.tenant_db_name(&huge));
// Different inputs → different outputs (collision check on a tiny
// sample — full birthday-resistance is a proof not a test).
let huge2 = "y".repeat(100);
assert_ne!(pool.tenant_db_name(&huge), pool.tenant_db_name(&huge2));
}
#[tokio::test]
async fn connect_rejects_overlong_db_prefix() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// MAX_PREFIX_LEN is 30 (= 63 - 1 - 32). A 31-char prefix MUST be
// rejected at construction so the hash-fallback path can never
// produce an over-long db name at runtime.
let too_long = "a".repeat(31);
let err = DatabasePool::connect(&uri, &too_long).await.unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("max is 30") || msg.contains(&too_long),
"error should explain the cap: {msg}"
);
// Exactly 30 chars is the inclusive bound — must succeed.
let just_right = "a".repeat(30);
let _ = DatabasePool::connect(&uri, &just_right)
.await
.expect("30-char prefix should be accepted");
}
/// Short UUID slug for keeping test prefixes well under Mongo's 63-byte
/// db-name cap.
fn short_id() -> String {
uuid::Uuid::new_v4().simple().to_string()[..8].to_string()
}
/// Drain a `repositories` find cursor on the given tenant database.
async fn collect(db: &compliance_agent::database::Database) -> Vec<TrackedRepository> {
let mut cursor = db
.repositories()
.find(doc! {})
.await
.expect("find repositories");
let mut out = Vec::new();
while cursor.advance().await.expect("advance") {
out.push(cursor.deserialize_current().expect("deserialize"));
}
out
}
@@ -1,122 +0,0 @@
//! M7.1 — integration tests for `compliance_core::auth::require_tenant_status`.
//!
//! Exercises the middleware end-to-end through an Axum router so we
//! catch wiring bugs (extension propagation, method matching) that pure
//! unit tests would miss.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
middleware::{from_fn, Next},
response::Response,
routing::{get, post},
Router,
};
use compliance_core::{auth::require_tenant_status, TenantContext, TenantStatus};
use tower::ServiceExt;
fn ctx_with(status: TenantStatus) -> TenantContext {
TenantContext {
tenant_id: "t-1".to_string(),
tenant_slug: "acme".to_string(),
org_roles: vec![],
products: vec![],
plan: "starter".to_string(),
status,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn router_with_ctx(ctx: Option<TenantContext>) -> Router {
let injector = move |mut req: Request, next: Next| {
let ctx = ctx.clone();
async move {
if let Some(c) = ctx {
req.extensions_mut().insert(c);
}
next.run(req).await
}
};
Router::new()
.route("/r", get(|| async { "read" }))
.route("/w", post(|| async { "write" }))
.layer(from_fn(require_tenant_status))
.layer(from_fn(injector))
}
async fn call(router: Router, method: Method, path: &str) -> Response {
let req = Request::builder()
.method(method)
.uri(path)
.body(Body::empty())
.expect("request build");
router.oneshot(req).await.expect("oneshot")
}
#[tokio::test]
async fn active_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Active)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn trial_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Trial)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn demo_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Demo)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn frozen_tenant_can_read_but_not_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Frozen)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(
call(r, Method::POST, "/w").await.status(),
StatusCode::PAYMENT_REQUIRED
);
}
#[tokio::test]
async fn archived_tenant_is_gone_on_every_method() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Archived)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::GONE
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::GONE);
}
#[tokio::test]
async fn no_context_passes_through() {
let r = router_with_ctx(None);
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}