Compare commits

..

1 Commits

Author SHA1 Message Date
Sharang Parnerkar f474699279 fix(core): JWKS refresh-on-failure in M7.1 auth middleware
CI / Check (pull_request) Successful in 8m17s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
Without this, every Keycloak signing-key rotation produces a silent
401 storm against every request until the agent restarts — the cached
JWKS is held forever and never reconciled against KC.

Now: when `kid` isn't in the cached JWKS or the matching key fails
signature verification, we classify the failure as Stale, force a JWKS
refresh, and retry once. Anything else (expired, malformed, missing
tenant_id) is Permanent and short-circuits straight to 401.

* Splits the path into a pure `try_validate(token, header, kid, jwks)`
  helper returning a `ValidationError { Stale | Permanent }` enum.
* `fetch_or_get_jwks(state, force)` takes a force flag and holds the
  write lock across the network fetch so concurrent refreshers don't
  all hammer Keycloak when keys rotate (the second writer reuses what
  the first put in cache).
* Adds a unit test for the kid-not-found Stale classification.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-04 16:40:55 +02:00
55 changed files with 905 additions and 2327 deletions
Generated
-5
View File
@@ -676,7 +676,6 @@ dependencies = [
"jsonwebtoken", "jsonwebtoken",
"mongodb", "mongodb",
"octocrab", "octocrab",
"rand 0.9.2",
"regex", "regex",
"reqwest", "reqwest",
"secrecy", "secrecy",
@@ -688,7 +687,6 @@ dependencies = [
"tokio-cron-scheduler", "tokio-cron-scheduler",
"tokio-stream", "tokio-stream",
"tokio-tungstenite 0.26.2", "tokio-tungstenite 0.26.2",
"tower",
"tower-http", "tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
@@ -819,15 +817,12 @@ dependencies = [
"bson", "bson",
"chrono", "chrono",
"compliance-core", "compliance-core",
"dashmap",
"dotenvy", "dotenvy",
"hex",
"mongodb", "mongodb",
"rmcp", "rmcp",
"schemars 1.2.1", "schemars 1.2.1",
"serde", "serde",
"serde_json", "serde_json",
"sha2",
"thiserror 2.0.18", "thiserror 2.0.18",
"tokio", "tokio",
"tower-http", "tower-http",
-2
View File
@@ -34,5 +34,3 @@ zip = { version = "2", features = ["aes-crypto", "deflate"] }
dashmap = "6" dashmap = "6"
tokio-stream = { version = "0.1", features = ["sync"] } tokio-stream = { version = "0.1", features = ["sync"] }
aes-gcm = "0.10" aes-gcm = "0.10"
rand = "0.9"
base64 = "0.22"
+2 -4
View File
@@ -7,7 +7,7 @@ edition = "2021"
workspace = true workspace = true
[dependencies] [dependencies]
compliance-core = { workspace = true, features = ["mongodb", "telemetry", "axum"] } compliance-core = { workspace = true, features = ["mongodb", "telemetry"] }
compliance-graph = { path = "../compliance-graph" } compliance-graph = { path = "../compliance-graph" }
compliance-dast = { path = "../compliance-dast" } compliance-dast = { path = "../compliance-dast" }
serde = { workspace = true } serde = { workspace = true }
@@ -42,11 +42,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] }
futures-core = "0.3" futures-core = "0.3"
dashmap = { workspace = true } dashmap = { workspace = true }
tokio-stream = { workspace = true } tokio-stream = { workspace = true }
rand = { workspace = true }
[dev-dependencies] [dev-dependencies]
compliance-core = { workspace = true, features = ["mongodb", "axum"] } compliance-core = { workspace = true, features = ["mongodb"] }
tower = { version = "0.5", features = ["util"] }
reqwest = { workspace = true } reqwest = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
+18 -16
View File
@@ -6,7 +6,7 @@ use tokio::sync::{broadcast, watch, Semaphore};
use compliance_core::models::pentest::PentestEvent; use compliance_core::models::pentest::PentestEvent;
use compliance_core::AgentConfig; use compliance_core::AgentConfig;
use crate::database::DatabasePool; use crate::database::Database;
use crate::llm::LlmClient; use crate::llm::LlmClient;
use crate::pipeline::orchestrator::PipelineOrchestrator; use crate::pipeline::orchestrator::PipelineOrchestrator;
@@ -16,10 +16,7 @@ const DEFAULT_MAX_CONCURRENT_SESSIONS: usize = 5;
#[derive(Clone)] #[derive(Clone)]
pub struct ComplianceAgent { pub struct ComplianceAgent {
pub config: AgentConfig, pub config: AgentConfig,
/// Per-tenant Mongo broker. Every code path must obtain a pub db: Database,
/// tenant-scoped [`crate::database::Database`] from this pool —
/// there is no single shared database any more.
pub db_pool: DatabasePool,
pub llm: Arc<LlmClient>, pub llm: Arc<LlmClient>,
pub http: reqwest::Client, pub http: reqwest::Client,
/// Per-session broadcast senders for SSE streaming. /// Per-session broadcast senders for SSE streaming.
@@ -31,7 +28,7 @@ pub struct ComplianceAgent {
} }
impl ComplianceAgent { impl ComplianceAgent {
pub fn new(config: AgentConfig, db_pool: DatabasePool) -> Self { pub fn new(config: AgentConfig, db: Database) -> Self {
let llm = Arc::new(LlmClient::new( let llm = Arc::new(LlmClient::new(
config.litellm_url.clone(), config.litellm_url.clone(),
config.litellm_api_key.clone(), config.litellm_api_key.clone(),
@@ -45,7 +42,7 @@ impl ComplianceAgent {
.unwrap_or_default(); .unwrap_or_default();
Self { Self {
config, config,
db_pool, db,
llm, llm,
http, http,
session_streams: Arc::new(DashMap::new()), session_streams: Arc::new(DashMap::new()),
@@ -56,27 +53,28 @@ impl ComplianceAgent {
pub async fn run_scan( pub async fn run_scan(
&self, &self,
tenant_id: &str,
repo_id: &str, repo_id: &str,
trigger: compliance_core::models::ScanTrigger, trigger: compliance_core::models::ScanTrigger,
) -> Result<(), crate::error::AgentError> { ) -> Result<(), crate::error::AgentError> {
let db = self.db_pool.for_tenant_id(tenant_id).await?; let orchestrator = PipelineOrchestrator::new(
let orchestrator = self.config.clone(),
PipelineOrchestrator::new(self.config.clone(), db, self.llm.clone(), self.http.clone()); self.db.clone(),
self.llm.clone(),
self.http.clone(),
);
orchestrator.run(repo_id, trigger).await orchestrator.run(repo_id, trigger).await
} }
/// Run a PR review: scan the diff and post review comments. /// Run a PR review: scan the diff and post review comments.
pub async fn run_pr_review( pub async fn run_pr_review(
&self, &self,
tenant_id: &str,
repo_id: &str, repo_id: &str,
pr_number: u64, pr_number: u64,
base_sha: &str, base_sha: &str,
head_sha: &str, head_sha: &str,
) -> Result<(), crate::error::AgentError> { ) -> Result<(), crate::error::AgentError> {
let db = self.db_pool.for_tenant_id(tenant_id).await?; let repo = self
let repo = db .db
.repositories() .repositories()
.find_one(mongodb::bson::doc! { .find_one(mongodb::bson::doc! {
"_id": mongodb::bson::oid::ObjectId::parse_str(repo_id) "_id": mongodb::bson::oid::ObjectId::parse_str(repo_id)
@@ -87,8 +85,12 @@ impl ComplianceAgent {
crate::error::AgentError::Other(format!("Repository {repo_id} not found")) crate::error::AgentError::Other(format!("Repository {repo_id} not found"))
})?; })?;
let orchestrator = let orchestrator = PipelineOrchestrator::new(
PipelineOrchestrator::new(self.config.clone(), db, self.llm.clone(), self.http.clone()); self.config.clone(),
self.db.clone(),
self.llm.clone(),
self.http.clone(),
);
orchestrator orchestrator
.run_pr_review(&repo, repo_id, pr_number, base_sha, head_sha) .run_pr_review(&repo, repo_id, pr_number, base_sha, head_sha)
.await .await
+113
View File
@@ -0,0 +1,113 @@
use std::sync::Arc;
use axum::{
extract::Request,
middleware::Next,
response::{IntoResponse, Response},
};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use reqwest::StatusCode;
use serde::Deserialize;
use tokio::sync::RwLock;
/// Cached JWKS from Keycloak for token validation.
#[derive(Clone)]
pub struct JwksState {
pub jwks: Arc<RwLock<Option<JwkSet>>>,
pub jwks_url: String,
}
#[derive(Debug, Deserialize)]
struct Claims {
#[allow(dead_code)]
sub: String,
}
const PUBLIC_ENDPOINTS: &[&str] = &["/api/v1/health"];
/// Middleware that validates Bearer JWT tokens against Keycloak's JWKS.
///
/// Skips validation for health check endpoints.
/// If `JwksState` is not present as an extension (keycloak not configured),
/// all requests pass through.
pub async fn require_jwt_auth(request: Request, next: Next) -> Response {
let path = request.uri().path();
if PUBLIC_ENDPOINTS.contains(&path) {
return next.run(request).await;
}
let jwks_state = match request.extensions().get::<JwksState>() {
Some(s) => s.clone(),
None => return next.run(request).await,
};
let auth_header = match request.headers().get("authorization") {
Some(h) => h,
None => return (StatusCode::UNAUTHORIZED, "Missing authorization header").into_response(),
};
let token = match auth_header.to_str() {
Ok(s) if s.starts_with("Bearer ") => &s[7..],
_ => return (StatusCode::UNAUTHORIZED, "Invalid authorization header").into_response(),
};
match validate_token(token, &jwks_state).await {
Ok(()) => next.run(request).await,
Err(e) => {
tracing::warn!("JWT validation failed: {e}");
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
}
}
}
async fn validate_token(token: &str, state: &JwksState) -> Result<(), String> {
let header = decode_header(token).map_err(|e| format!("failed to decode JWT header: {e}"))?;
let kid = header
.kid
.ok_or_else(|| "JWT missing kid header".to_string())?;
let jwks = fetch_or_get_jwks(state).await?;
let jwk = jwks
.keys
.iter()
.find(|k| k.common.key_id.as_deref() == Some(&kid))
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
let decoding_key =
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
let mut validation = Validation::new(header.alg);
validation.validate_exp = true;
validation.validate_aud = false;
decode::<Claims>(token, &decoding_key, &validation)
.map_err(|e| format!("token validation failed: {e}"))?;
Ok(())
}
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
{
let cached = state.jwks.read().await;
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
let resp = reqwest::get(&state.jwks_url)
.await
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
let jwks: JwkSet = resp
.json()
.await
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
let mut cached = state.jwks.write().await;
*cached = Some(jwks.clone());
Ok(jwks)
}
+26 -30
View File
@@ -7,13 +7,11 @@ use mongodb::bson::doc;
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference}; use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
use compliance_core::models::embedding::EmbeddingBuildRun; use compliance_core::models::embedding::EmbeddingBuildRun;
use compliance_core::tenant_ctx::TenantCtx;
use compliance_graph::graph::embedding_store::EmbeddingStore; use compliance_graph::graph::embedding_store::EmbeddingStore;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use crate::rag::pipeline::RagPipeline; use crate::rag::pipeline::RagPipeline;
use super::dto::tenant_db;
use super::ApiResponse; use super::ApiResponse;
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -22,12 +20,10 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn chat( pub async fn chat(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> { ) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
let pipeline = RagPipeline::new(agent.llm.clone(), db.inner());
// Step 1: Embed the user's message // Step 1: Embed the user's message
let query_vectors = agent let query_vectors = agent
@@ -137,15 +133,12 @@ pub async fn chat(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn build_embeddings( pub async fn build_embeddings(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
// Resolve the tenant DB up front so we can move it into the spawn;
// the JWT/dev context isn't available inside detached tasks.
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
tokio::spawn(async move { tokio::spawn(async move {
let repo = match db let repo = match agent_clone
.db
.repositories() .repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() }) .find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await .await
@@ -158,7 +151,8 @@ pub async fn build_embeddings(
}; };
// Get latest graph build // Get latest graph build
let build = match db let build = match agent_clone
.db
.graph_builds() .graph_builds()
.find_one(doc! { "repo_id": &repo_id }) .find_one(doc! { "repo_id": &repo_id })
.sort(doc! { "started_at": -1 }) .sort(doc! { "started_at": -1 })
@@ -177,22 +171,26 @@ pub async fn build_embeddings(
.unwrap_or_else(|| "unknown".to_string()); .unwrap_or_else(|| "unknown".to_string());
// Get nodes // Get nodes
let nodes: Vec<compliance_core::models::graph::CodeNode> = let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
match db.graph_nodes().find(doc! { "repo_id": &repo_id }).await { .db
Ok(cursor) => { .graph_nodes()
use futures_util::StreamExt; .find(doc! { "repo_id": &repo_id })
let mut items = Vec::new(); .await
let mut cursor = cursor; {
while let Some(Ok(item)) = cursor.next().await { Ok(cursor) => {
items.push(item); use futures_util::StreamExt;
} let mut items = Vec::new();
items let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
} }
Err(e) => { items
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}"); }
return; Err(e) => {
} tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
}; return;
}
};
let creds = crate::pipeline::git::RepoCredentials { let creds = crate::pipeline::git::RepoCredentials {
ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()), ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()),
@@ -209,7 +207,7 @@ pub async fn build_embeddings(
} }
}; };
let pipeline = RagPipeline::new(agent_clone.llm.clone(), db.inner()); let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
match pipeline match pipeline
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes) .build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
.await .await
@@ -236,11 +234,9 @@ pub async fn build_embeddings(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn embedding_status( pub async fn embedding_status(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> { ) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let store = EmbeddingStore::new(agent.db.inner());
let store = EmbeddingStore::new(db.inner());
let build = store.get_latest_build(&repo_id).await.map_err(|e| { let build = store.get_latest_build(&repo_id).await.map_err(|e| {
tracing::error!("Failed to get embedding status: {e}"); tracing::error!("Failed to get embedding status: {e}");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
+11 -20
View File
@@ -7,11 +7,9 @@ use mongodb::bson::doc;
use serde::Deserialize; use serde::Deserialize;
use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType}; use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse, PaginationParams}; use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -47,11 +45,9 @@ fn default_rate_limit() -> u32 {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_targets( pub async fn list_targets(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastTarget>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<DastTarget>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.dast_targets() .dast_targets()
@@ -84,7 +80,6 @@ pub async fn list_targets(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn add_target( pub async fn add_target(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddTargetRequest>, Json(req): Json<AddTargetRequest>,
) -> Result<Json<ApiResponse<DastTarget>>, StatusCode> { ) -> Result<Json<ApiResponse<DastTarget>>, StatusCode> {
let mut target = DastTarget::new(req.name, req.base_url, req.target_type); let mut target = DastTarget::new(req.name, req.base_url, req.target_type);
@@ -94,8 +89,9 @@ pub async fn add_target(
target.rate_limit = req.rate_limit; target.rate_limit = req.rate_limit;
target.allow_destructive = req.allow_destructive; target.allow_destructive = req.allow_destructive;
let db = tenant_db(&agent, &tenant).await?; agent
db.dast_targets() .db
.dast_targets()
.insert_one(&target) .insert_one(&target)
.await .await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -111,19 +107,19 @@ pub async fn add_target(
#[tracing::instrument(skip_all, fields(target_id = %id))] #[tracing::instrument(skip_all, fields(target_id = %id))]
pub async fn trigger_scan( pub async fn trigger_scan(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let target = db let target = agent
.db
.dast_targets() .dast_targets()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?; .ok_or(StatusCode::NOT_FOUND)?;
let db = agent.db.clone();
tokio::spawn(async move { tokio::spawn(async move {
let orchestrator = compliance_dast::DastOrchestrator::new(100); let orchestrator = compliance_dast::DastOrchestrator::new(100);
match orchestrator.run_scan(&target, Vec::new()).await { match orchestrator.run_scan(&target, Vec::new()).await {
@@ -151,11 +147,9 @@ pub async fn trigger_scan(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_scan_runs( pub async fn list_scan_runs(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastScanRun>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<DastScanRun>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.dast_scan_runs() .dast_scan_runs()
@@ -189,11 +183,9 @@ pub async fn list_scan_runs(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_findings( pub async fn list_findings(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.dast_findings() .dast_findings()
@@ -227,13 +219,12 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))] #[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding( pub async fn get_finding(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> { ) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let finding = db let finding = agent
.db
.dast_findings() .dast_findings()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
-21
View File
@@ -180,27 +180,6 @@ pub struct SbomVersionDiff {
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>; pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>; pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
/// Resolve a tenant-scoped [`Database`] from the request's
/// [`TenantContext`] (inserted by the M7.1 JWT middleware, or by the
/// dev fallback in unsecured environments). The pool ensures the
/// tenant's indexes idempotently.
///
/// Returns 500 on the rare path where Mongo refuses the database
/// handle — the M7.1 auth/status middleware already rejects every
/// other failure mode with 4xx before we get here.
pub(crate) async fn tenant_db(
agent: &crate::agent::ComplianceAgent,
tenant: &compliance_core::tenant_ctx::TenantCtx,
) -> Result<crate::database::Database, axum::http::StatusCode> {
agent.db_pool.for_tenant(&tenant.0).await.map_err(|e| {
tracing::error!(
tenant_id = %tenant.0.tenant_id,
"Failed to acquire tenant database: {e}"
);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
})
}
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>( pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>, mut cursor: mongodb::Cursor<T>,
) -> Vec<T> { ) -> Vec<T> {
+11 -16
View File
@@ -5,16 +5,13 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::Finding; use compliance_core::models::Finding;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))] #[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings( pub async fn list_findings(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<FindingsFilter>, Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> { ) -> ApiResult<Vec<Finding>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let mut query = doc! {}; let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id { if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id); query.insert("repo_id", repo_id);
@@ -84,12 +81,11 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))] #[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding( pub async fn get_finding(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> { ) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?; let finding = agent
let finding = db .db
.findings() .findings()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -106,14 +102,14 @@ pub async fn get_finding(
#[tracing::instrument(skip_all, fields(finding_id = %id))] #[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status( pub async fn update_finding_status(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>, Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
db.findings() agent
.db
.findings()
.update_one( .update_one(
doc! { "_id": oid }, doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
@@ -127,7 +123,6 @@ pub async fn update_finding_status(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status( pub async fn bulk_update_finding_status(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<BulkUpdateStatusRequest>, Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req let oids: Vec<mongodb::bson::oid::ObjectId> = req
@@ -140,8 +135,8 @@ pub async fn bulk_update_finding_status(
return Err(StatusCode::BAD_REQUEST); return Err(StatusCode::BAD_REQUEST);
} }
let db = tenant_db(&agent, &tenant).await?; let result = agent
let result = db .db
.findings() .findings()
.update_many( .update_many(
doc! { "_id": { "$in": oids } }, doc! { "_id": { "$in": oids } },
@@ -158,14 +153,14 @@ pub async fn bulk_update_finding_status(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn update_finding_feedback( pub async fn update_finding_feedback(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>, Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
db.findings() agent
.db
.findings()
.update_one( .update_one(
doc! { "_id": oid }, doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } }, doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
+10 -24
View File
@@ -7,11 +7,9 @@ use mongodb::bson::doc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis}; use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse}; use super::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -38,11 +36,9 @@ fn default_search_limit() -> usize {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_graph( pub async fn get_graph(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<GraphData>>, StatusCode> { ) -> Result<Json<ApiResponse<GraphData>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
// Get latest build // Get latest build
let build: Option<GraphBuildRun> = db let build: Option<GraphBuildRun> = db
@@ -102,11 +98,9 @@ pub async fn get_graph(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_nodes( pub async fn get_nodes(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let filter = doc! { "repo_id": &repo_id }; let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await { let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -129,11 +123,9 @@ pub async fn get_nodes(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_communities( pub async fn get_communities(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CommunityInfo>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<CommunityInfo>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let filter = doc! { "repo_id": &repo_id }; let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await { let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -184,11 +176,9 @@ pub struct CommunityInfo {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))]
pub async fn get_impact( pub async fn get_impact(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path((repo_id, finding_id)): Path<(String, String)>, Path((repo_id, finding_id)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Option<ImpactAnalysis>>>, StatusCode> { ) -> Result<Json<ApiResponse<Option<ImpactAnalysis>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id }; let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id };
let impact = db let impact = db
@@ -208,12 +198,10 @@ pub async fn get_impact(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))]
pub async fn search_symbols( pub async fn search_symbols(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
Query(params): Query<SearchParams>, Query(params): Query<SearchParams>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
// Simple text search on qualified_name and name fields // Simple text search on qualified_name and name fields
let filter = doc! { let filter = doc! {
@@ -246,12 +234,10 @@ pub async fn search_symbols(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_file_content( pub async fn get_file_content(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
Query(params): Query<FileContentParams>, Query(params): Query<FileContentParams>,
) -> Result<Json<ApiResponse<FileContent>>, StatusCode> { ) -> Result<Json<ApiResponse<FileContent>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
// Look up the repository to get repo name // Look up the repository to get repo name
let repo = db let repo = db
@@ -310,13 +296,12 @@ pub struct FileContent {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))] #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn trigger_build( pub async fn trigger_build(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>, Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
tokio::spawn(async move { tokio::spawn(async move {
let repo = match db let repo = match agent_clone
.db
.repositories() .repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() }) .find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await .await
@@ -348,7 +333,8 @@ pub async fn trigger_build(
match engine.build_graph(&repo_path, &repo_id, &graph_build_id) { match engine.build_graph(&repo_path, &repo_id, &graph_build_id) {
Ok((code_graph, build_run)) => { Ok((code_graph, build_run)) => {
let store = compliance_graph::graph::persistence::GraphStore::new(db.inner()); let store =
compliance_graph::graph::persistence::GraphStore::new(agent_clone.db.inner());
let _ = store.delete_repo_graph(&repo_id).await; let _ = store.delete_repo_graph(&repo_id).await;
let _ = store let _ = store
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges) .store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
+2 -7
View File
@@ -3,7 +3,6 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::ScanRun; use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> { pub async fn health() -> Json<serde_json::Value> {
@@ -11,12 +10,8 @@ pub async fn health() -> Json<serde_json::Value> {
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn stats_overview( pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
axum::extract::Extension(agent): AgentExt, let db = &agent.db;
tenant: TenantCtx,
) -> ApiResult<OverviewStats> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let total_repositories = db let total_repositories = db
.repositories() .repositories()
+1 -4
View File
@@ -4,16 +4,13 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::TrackerIssue; use compliance_core::models::TrackerIssue;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_issues( pub async fn list_issues(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> { ) -> ApiResult<Vec<TrackerIssue>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.tracker_issues() .tracker_issues()
@@ -1,186 +0,0 @@
//! `/api/v1/mcp-tokens` — per-tenant API tokens for the MCP server.
//!
//! These are opaque static bearers issued via the dashboard (or a
//! direct curl with a KC JWT) and copied into LLM clients (Claude
//! Desktop / Cursor / ChatGPT). The MCP server hashes incoming bearers
//! and looks them up in the cross-tenant `<prefix>__admin.mcp_tokens`
//! collection to derive the tenant_id for routing.
//!
//! The raw token is shown to the caller exactly once at creation; the
//! database only ever stores the SHA-256 hash. Revocation is a soft
//! delete (sets `revoked: true`) so the audit log keeps the record.
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::Json;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use compliance_core::models::{McpToken, McpTokenView};
use compliance_core::tenant_ctx::TenantCtx;
use mongodb::bson::doc;
use rand::RngCore;
use sha2::{Digest, Sha256};
use super::dto::{AgentExt, ApiResponse};
/// Mongo collection name inside the admin DB.
const COLLECTION: &str = "mcp_tokens";
/// Token prefix the MCP server expects on every bearer.
const TOKEN_PREFIX: &str = "mcpt_";
/// Bytes of randomness behind each token. 32 → ~256 bits.
/// Encoded as URL-safe base64 without padding → 43 chars.
/// Combined with `mcpt_` → 48-char tokens.
const TOKEN_RAND_BYTES: usize = 32;
#[derive(serde::Deserialize)]
pub struct CreateMcpTokenRequest {
pub name: String,
}
/// Returned exactly once at creation. The `token` field is gone from
/// the listing endpoint — the user must save it now.
#[derive(serde::Serialize)]
pub struct CreateMcpTokenResponse {
pub token: String,
pub view: McpTokenView,
}
/// `POST /api/v1/mcp-tokens` — mint a new token for the caller's tenant.
#[tracing::instrument(skip_all)]
pub async fn create_mcp_token(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<CreateMcpTokenRequest>,
) -> Result<Json<CreateMcpTokenResponse>, StatusCode> {
if req.name.trim().is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
let raw = generate_token();
let token_hash = sha256_hex(&raw);
let token_prefix: String = raw.chars().take(12).collect();
let mut token = McpToken {
id: None,
token_hash,
token_prefix,
tenant_id: tenant.0.tenant_id.clone(),
name: req.name.trim().to_string(),
created_by: tenant.0.user_id.clone(),
created_at: chrono::Utc::now(),
last_used_at: None,
revoked: false,
};
let col = agent.db_pool.admin_db().collection::<McpToken>(COLLECTION);
let res = col.insert_one(&token).await.map_err(|e| {
tracing::error!("Failed to insert MCP token: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
token.id = res.inserted_id.as_object_id();
Ok(Json(CreateMcpTokenResponse {
view: McpTokenView::from(&token),
token: raw,
}))
}
/// `GET /api/v1/mcp-tokens` — list tokens for the caller's tenant.
/// Hash is never returned; only metadata + the 12-char prefix so the
/// user can identify which row is which.
#[tracing::instrument(skip_all)]
pub async fn list_mcp_tokens(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<ApiResponse<Vec<McpTokenView>>>, StatusCode> {
let col = agent.db_pool.admin_db().collection::<McpToken>(COLLECTION);
let mut cursor = col
.find(doc! { "tenant_id": &tenant.0.tenant_id })
.sort(doc! { "created_at": -1 })
.await
.map_err(|e| {
tracing::error!("Failed to list MCP tokens: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let mut out = Vec::new();
while cursor.advance().await.map_err(|e| {
tracing::warn!("MCP tokens cursor advance failed: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})? {
match cursor.deserialize_current() {
Ok(t) => out.push(McpTokenView::from(&t)),
Err(e) => tracing::warn!("Failed to deserialize MCP token: {e}"),
}
}
Ok(Json(ApiResponse {
data: out,
total: None,
page: None,
}))
}
/// `DELETE /api/v1/mcp-tokens/{id}` — revoke (soft delete).
/// Scoped to the caller's tenant: a user can't revoke another tenant's
/// token even if they guess its id.
#[tracing::instrument(skip_all, fields(id = %id))]
pub async fn revoke_mcp_token(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let col = agent.db_pool.admin_db().collection::<McpToken>(COLLECTION);
let result = col
.update_one(
doc! { "_id": oid, "tenant_id": &tenant.0.tenant_id },
doc! { "$set": { "revoked": true } },
)
.await
.map_err(|e| {
tracing::error!("Failed to revoke MCP token: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if result.matched_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
Ok(Json(serde_json::json!({ "status": "revoked" })))
}
/// 32 bytes random → URL-safe base64 → 43 chars, no padding.
/// Prefixed with `mcpt_` so the MCP server can sniff the format
/// before bothering with the DB lookup.
fn generate_token() -> String {
let mut bytes = [0u8; TOKEN_RAND_BYTES];
rand::rng().fill_bytes(&mut bytes);
format!("{TOKEN_PREFIX}{}", URL_SAFE_NO_PAD.encode(bytes))
}
fn sha256_hex(s: &str) -> String {
let mut h = Sha256::new();
h.update(s.as_bytes());
hex::encode(h.finalize())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generated_tokens_are_unique_and_prefixed() {
let a = generate_token();
let b = generate_token();
assert_ne!(a, b);
assert!(a.starts_with(TOKEN_PREFIX));
assert!(b.starts_with(TOKEN_PREFIX));
// 5 + 43 = 48 chars
assert_eq!(a.len(), 5 + 43);
}
#[test]
fn sha256_is_stable_and_64_hex() {
let h = sha256_hex("mcpt_abc");
assert_eq!(h.len(), 64);
assert!(h.chars().all(|c| c.is_ascii_hexdigit()));
assert_eq!(sha256_hex("mcpt_abc"), h);
}
}
-1
View File
@@ -6,7 +6,6 @@ pub mod graph;
pub mod health; pub mod health;
pub mod help_chat; pub mod help_chat;
pub mod issues; pub mod issues;
pub mod mcp_tokens;
pub mod notifications; pub mod notifications;
pub mod pentest_handlers; pub mod pentest_handlers;
pub use pentest_handlers as pentest; pub use pentest_handlers as pentest;
@@ -5,18 +5,15 @@ use mongodb::bson::doc;
use serde::Deserialize; use serde::Deserialize;
use compliance_core::models::notification::CveNotification; use compliance_core::models::notification::CveNotification;
use compliance_core::tenant_ctx::TenantCtx;
use super::dto::{tenant_db, AgentExt, ApiResponse}; use super::dto::{AgentExt, ApiResponse};
/// GET /api/v1/notifications — List CVE notifications (newest first) /// GET /api/v1/notifications — List CVE notifications (newest first)
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_notifications( pub async fn list_notifications(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Query(params): axum::extract::Query<NotificationFilter>, axum::extract::Query(params): axum::extract::Query<NotificationFilter>,
) -> Result<Json<ApiResponse<Vec<CveNotification>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<CveNotification>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let mut filter = doc! {}; let mut filter = doc! {};
// Filter by status (default: show new + read, exclude dismissed) // Filter by status (default: show new + read, exclude dismissed)
@@ -44,13 +41,15 @@ pub async fn list_notifications(
let limit = params.limit.unwrap_or(50).min(200); let limit = params.limit.unwrap_or(50).min(200);
let skip = (page - 1) * limit as u64; let skip = (page - 1) * limit as u64;
let total = db let total = agent
.db
.cve_notifications() .cve_notifications()
.count_documents(filter.clone()) .count_documents(filter.clone())
.await .await
.unwrap_or(0); .unwrap_or(0);
let notifications: Vec<CveNotification> = match db let notifications: Vec<CveNotification> = match agent
.db
.cve_notifications() .cve_notifications()
.find(filter) .find(filter)
.sort(doc! { "created_at": -1 }) .sort(doc! { "created_at": -1 })
@@ -84,10 +83,9 @@ pub async fn list_notifications(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn notification_count( pub async fn notification_count(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let count = agent
let count = db .db
.cve_notifications() .cve_notifications()
.count_documents(doc! { "status": "new" }) .count_documents(doc! { "status": "new" })
.await .await
@@ -100,13 +98,12 @@ pub async fn notification_count(
#[tracing::instrument(skip_all, fields(id = %id))] #[tracing::instrument(skip_all, fields(id = %id))]
pub async fn mark_read( pub async fn mark_read(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>, axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = db let result = agent
.db
.cve_notifications() .cve_notifications()
.update_one( .update_one(
doc! { "_id": oid }, doc! { "_id": oid },
@@ -128,13 +125,12 @@ pub async fn mark_read(
#[tracing::instrument(skip_all, fields(id = %id))] #[tracing::instrument(skip_all, fields(id = %id))]
pub async fn dismiss_notification( pub async fn dismiss_notification(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>, axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = db let result = agent
.db
.cve_notifications() .cve_notifications()
.update_one( .update_one(
doc! { "_id": oid }, doc! { "_id": oid },
@@ -153,10 +149,9 @@ pub async fn dismiss_notification(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn mark_all_read( pub async fn mark_all_read(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let result = agent
let result = db .db
.cve_notifications() .cve_notifications()
.update_many( .update_many(
doc! { "status": "new" }, doc! { "status": "new" },
@@ -13,11 +13,10 @@ use compliance_core::models::dast::DastFinding;
use compliance_core::models::finding::Finding; use compliance_core::models::finding::Finding;
use compliance_core::models::pentest::*; use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry; use compliance_core::models::sbom::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db}; use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -36,15 +35,11 @@ pub struct ExportBody {
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report( pub async fn export_session_report(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Json(body): Json<ExportBody>, Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> { ) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id) let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if body.password.len() < 8 { if body.password.len() < 8 {
return Err(( return Err((
@@ -54,7 +49,8 @@ pub async fn export_session_report(
} }
// Fetch session // Fetch session
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -68,7 +64,9 @@ pub async fn export_session_report(
// Resolve target name // Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) { let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
db.dast_targets() agent
.db
.dast_targets()
.find_one(doc! { "_id": tid }) .find_one(doc! { "_id": tid })
.await .await
.ok() .ok()
@@ -86,7 +84,8 @@ pub async fn export_session_report(
.unwrap_or_default(); .unwrap_or_default();
// Fetch attack chain nodes // Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match db let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes() .attack_chain_nodes()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 }) .sort(doc! { "started_at": 1 })
@@ -97,7 +96,8 @@ pub async fn export_session_report(
}; };
// Fetch DAST findings for this session, then deduplicate // Fetch DAST findings for this session, then deduplicate
let raw_findings: Vec<DastFinding> = match db let raw_findings: Vec<DastFinding> = match agent
.db
.dast_findings() .dast_findings()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 }) .sort(doc! { "severity": -1, "created_at": -1 })
@@ -122,7 +122,8 @@ pub async fn export_session_report(
.or_else(|| target.as_ref().and_then(|t| t.repo_id.clone())); .or_else(|| target.as_ref().and_then(|t| t.repo_id.clone()));
let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id { let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id {
let sast: Vec<Finding> = match db let sast: Vec<Finding> = match agent
.db
.findings() .findings()
.find(doc! { .find(doc! {
"repo_id": rid, "repo_id": rid,
@@ -142,7 +143,8 @@ pub async fn export_session_report(
Err(_) => Vec::new(), Err(_) => Vec::new(),
}; };
let sbom: Vec<SbomEntry> = match db let sbom: Vec<SbomEntry> = match agent
.db
.sbom_entries() .sbom_entries()
.find(doc! { .find(doc! {
"repo_id": rid, "repo_id": rid,
@@ -162,7 +164,8 @@ pub async fn export_session_report(
}; };
// Build code context from graph nodes // Build code context from graph nodes
let code_ctx: Vec<CodeContextHint> = match db let code_ctx: Vec<CodeContextHint> = match agent
.db
.graph_nodes() .graph_nodes()
.find(doc! { "repo_id": rid, "is_entry_point": true }) .find(doc! { "repo_id": rid, "is_entry_point": true })
.limit(50) .limit(50)
@@ -7,12 +7,11 @@ use mongodb::bson::doc;
use serde::Deserialize; use serde::Deserialize;
use compliance_core::models::pentest::*; use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator; use crate::pentest::PentestOrchestrator;
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse, PaginationParams}; use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -44,7 +43,6 @@ pub struct LookupRepoQuery {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn create_session( pub async fn create_session(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<CreateSessionRequest>, Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
// Try to acquire a concurrency permit // Try to acquire a concurrency permit
@@ -59,10 +57,6 @@ pub async fn create_session(
) )
})?; })?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if let Some(ref config) = req.config { if let Some(ref config) = req.config {
// ── Wizard path ────────────────────────────────────────────── // ── Wizard path ──────────────────────────────────────────────
if !config.disclaimer_accepted { if !config.disclaimer_accepted {
@@ -73,7 +67,8 @@ pub async fn create_session(
} }
// Look up or auto-create DastTarget by app_url // Look up or auto-create DastTarget by app_url
let target = match db let target = match agent
.db
.dast_targets() .dast_targets()
.find_one(doc! { "base_url": &config.app_url }) .find_one(doc! { "base_url": &config.app_url })
.await .await
@@ -92,7 +87,7 @@ pub async fn create_session(
} }
t.allow_destructive = config.allow_destructive; t.allow_destructive = config.allow_destructive;
t.excluded_paths = config.scope_exclusions.clone(); t.excluded_paths = config.scope_exclusions.clone();
let res = db.dast_targets().insert_one(&t).await.map_err(|e| { let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| {
( (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create target: {e}"), format!("Failed to create target: {e}"),
@@ -115,7 +110,8 @@ pub async fn create_session(
// Resolve repo_id from git_repo_url if provided // Resolve repo_id from git_repo_url if provided
if let Some(ref git_url) = config.git_repo_url { if let Some(ref git_url) = config.git_repo_url {
if let Ok(Some(repo)) = db if let Ok(Some(repo)) = agent
.db
.repositories() .repositories()
.find_one(doc! { "git_url": git_url }) .find_one(doc! { "git_url": git_url })
.await .await
@@ -124,7 +120,8 @@ pub async fn create_session(
} }
} }
let insert_result = db let insert_result = agent
.db
.pentest_sessions() .pentest_sessions()
.insert_one(&session) .insert_one(&session)
.await .await
@@ -215,7 +212,8 @@ pub async fn create_session(
// Persist encrypted credentials to DB // Persist encrypted credentials to DB
if session_for_task.config.is_some() { if session_for_task.config.is_some() {
if let Some(sid) = session.id { if let Some(sid) = session.id {
let _ = db let _ = agent
.db
.pentest_sessions() .pentest_sessions()
.update_one( .update_one(
doc! { "_id": sid }, doc! { "_id": sid },
@@ -247,13 +245,12 @@ pub async fn create_session(
}); });
let llm = agent.llm.clone(); let llm = agent.llm.clone();
let db_for_orchestrator = db.clone(); let db = agent.db.clone();
let session_clone = session.clone(); let session_clone = session.clone();
let target_clone = target.clone(); let target_clone = target.clone();
let agent_ref = agent.clone(); let agent_ref = agent.clone();
tokio::spawn(async move { tokio::spawn(async move {
let orchestrator = let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
orchestrator orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message) .run_session_guarded(&session_clone, &target_clone, &initial_message)
.await; .await;
@@ -295,7 +292,8 @@ pub async fn create_session(
) )
})?; })?;
let target = db let target = agent
.db
.dast_targets() .dast_targets()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -312,7 +310,8 @@ pub async fn create_session(
let mut session = PentestSession::new(target_id, strategy); let mut session = PentestSession::new(target_id, strategy);
session.repo_id = target.repo_id.clone(); session.repo_id = target.repo_id.clone();
let insert_result = db let insert_result = agent
.db
.pentest_sessions() .pentest_sessions()
.insert_one(&session) .insert_one(&session)
.await .await
@@ -339,13 +338,12 @@ pub async fn create_session(
}); });
let llm = agent.llm.clone(); let llm = agent.llm.clone();
let db_for_orchestrator = db.clone(); let db = agent.db.clone();
let session_clone = session.clone(); let session_clone = session.clone();
let target_clone = target.clone(); let target_clone = target.clone();
let agent_ref = agent.clone(); let agent_ref = agent.clone();
tokio::spawn(async move { tokio::spawn(async move {
let orchestrator = let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
orchestrator orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message) .run_session_guarded(&session_clone, &target_clone, &initial_message)
.await; .await;
@@ -375,11 +373,10 @@ fn parse_strategy(s: &str) -> PentestStrategy {
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn lookup_repo( pub async fn lookup_repo(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<LookupRepoQuery>, Query(params): Query<LookupRepoQuery>,
) -> Result<Json<ApiResponse<serde_json::Value>>, StatusCode> { ) -> Result<Json<ApiResponse<serde_json::Value>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let repo = agent
let repo = db .db
.repositories() .repositories()
.find_one(doc! { "git_url": &params.url }) .find_one(doc! { "git_url": &params.url })
.await .await
@@ -405,11 +402,9 @@ pub async fn lookup_repo(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_sessions( pub async fn list_sessions(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.pentest_sessions() .pentest_sessions()
@@ -443,13 +438,12 @@ pub async fn list_sessions(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session( pub async fn get_session(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> { ) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut session = db let mut session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -477,18 +471,15 @@ pub async fn get_session(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message( pub async fn send_message(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Json(req): Json<SendMessageRequest>, Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id) let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
// Verify session exists and is running // Verify session exists and is running
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -515,7 +506,8 @@ pub async fn send_message(
) )
})?; })?;
let target = db let target = agent
.db
.dast_targets() .dast_targets()
.find_one(doc! { "_id": target_oid }) .find_one(doc! { "_id": target_oid })
.await .await
@@ -535,13 +527,13 @@ pub async fn send_message(
// Store user message // Store user message
let session_id = id.clone(); let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = db.pentest_messages().insert_one(&user_msg).await; let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone(); let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session // Spawn orchestrator to continue the session
let llm = agent.llm.clone(); let llm = agent.llm.clone();
let db_for_orchestrator = db.clone(); let db = agent.db.clone();
let message = req.message.clone(); let message = req.message.clone();
// Use existing broadcast sender if available, otherwise create a new one // Use existing broadcast sender if available, otherwise create a new one
@@ -556,7 +548,7 @@ pub async fn send_message(
.unwrap_or_else(|| agent.register_session_stream(&session_id)); .unwrap_or_else(|| agent.register_session_stream(&session_id));
tokio::spawn(async move { tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, None); let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None);
orchestrator orchestrator
.run_session_guarded(&session, &target, &message) .run_session_guarded(&session, &target, &message)
.await; .await;
@@ -573,16 +565,13 @@ pub async fn send_message(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session( pub async fn stop_session(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id) let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -601,7 +590,9 @@ pub async fn stop_session(
)); ));
} }
db.pentest_sessions() agent
.db
.pentest_sessions()
.update_one( .update_one(
doc! { "_id": oid }, doc! { "_id": oid },
doc! { "$set": { doc! { "$set": {
@@ -621,7 +612,8 @@ pub async fn stop_session(
// Clean up session resources // Clean up session resources
agent.cleanup_session(&id); agent.cleanup_session(&id);
let updated = db let updated = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -649,16 +641,13 @@ pub async fn stop_session(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn pause_session( pub async fn pause_session(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id) let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -695,16 +684,13 @@ pub async fn pause_session(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn resume_session( pub async fn resume_session(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id) let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -741,13 +727,12 @@ pub async fn resume_session(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain( pub async fn get_attack_chain(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let nodes = match db let nodes = match agent
.db
.attack_chain_nodes() .attack_chain_nodes()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 }) .sort(doc! { "started_at": 1 })
@@ -772,21 +757,21 @@ pub async fn get_attack_chain(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages( pub async fn get_messages(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = agent
.db
.pentest_messages() .pentest_messages()
.count_documents(doc! { "session_id": &id }) .count_documents(doc! { "session_id": &id })
.await .await
.unwrap_or(0); .unwrap_or(0);
let messages = match db let messages = match agent
.db
.pentest_messages() .pentest_messages()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 }) .sort(doc! { "created_at": 1 })
@@ -812,21 +797,21 @@ pub async fn get_messages(
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings( pub async fn get_session_findings(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> { ) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = agent
.db
.dast_findings() .dast_findings()
.count_documents(doc! { "session_id": &id }) .count_documents(doc! { "session_id": &id })
.await .await
.unwrap_or(0); .unwrap_or(0);
let findings = match db let findings = match agent
.db
.dast_findings() .dast_findings()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 }) .sort(doc! { "created_at": -1 })
@@ -6,11 +6,10 @@ use axum::Json;
use mongodb::bson::doc; use mongodb::bson::doc;
use compliance_core::models::pentest::*; use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse}; use super::super::dto::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -18,10 +17,8 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn pentest_stats( pub async fn pentest_stats(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> { ) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let running_sessions = db let running_sessions = db
.pentest_sessions() .pentest_sessions()
@@ -11,11 +11,10 @@ use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use compliance_core::models::pentest::*; use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, tenant_db}; use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>; type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -26,14 +25,13 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(session_id = %id))] #[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream( pub async fn session_stream(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> { ) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
// Verify session exists // Verify session exists
let _session = db let _session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -45,7 +43,8 @@ pub async fn session_stream(
let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new(); let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new();
// Fetch recent messages for this session // Fetch recent messages for this session
let messages: Vec<PentestMessage> = match db let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages() .pentest_messages()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 }) .sort(doc! { "created_at": 1 })
@@ -57,7 +56,8 @@ pub async fn session_stream(
}; };
// Fetch recent attack chain nodes // Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match db let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes() .attack_chain_nodes()
.find(doc! { "session_id": &id }) .find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 }) .sort(doc! { "started_at": 1 })
@@ -94,7 +94,8 @@ pub async fn session_stream(
} }
// Add current session status event // Add current session status event
let session = db let session = agent
.db
.pentest_sessions() .pentest_sessions()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
+17 -28
View File
@@ -5,16 +5,13 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::*; use compliance_core::models::*;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_repositories( pub async fn list_repositories(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> { ) -> ApiResult<Vec<TrackedRepository>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db let total = db
.repositories() .repositories()
@@ -46,7 +43,6 @@ pub async fn list_repositories(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn add_repository( pub async fn add_repository(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddRepositoryRequest>, Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> { ) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
// Validate repository access before saving // Validate repository access before saving
@@ -73,15 +69,17 @@ pub async fn add_repository(
repo.tracker_token = req.tracker_token; repo.tracker_token = req.tracker_token;
repo.scan_schedule = req.scan_schedule; repo.scan_schedule = req.scan_schedule;
let db = tenant_db(&agent, &tenant) agent
.db
.repositories()
.insert_one(&repo)
.await .await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?; .map_err(|_| {
db.repositories().insert_one(&repo).await.map_err(|_| { (
( StatusCode::CONFLICT,
StatusCode::CONFLICT, "Repository already exists".to_string(),
"Repository already exists".to_string(), )
) })?;
})?;
Ok(Json(ApiResponse { Ok(Json(ApiResponse {
data: repo, data: repo,
@@ -93,12 +91,10 @@ pub async fn add_repository(
#[tracing::instrument(skip_all, fields(repo_id = %id))] #[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository( pub async fn update_repository(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
Json(req): Json<UpdateRepositoryRequest>, Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() }; let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
@@ -130,7 +126,8 @@ pub async fn update_repository(
set_doc.insert("scan_schedule", schedule); set_doc.insert("scan_schedule", schedule);
} }
let result = db let result = agent
.db
.repositories() .repositories()
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc }) .update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
.await .await
@@ -158,16 +155,11 @@ pub async fn get_ssh_public_key(
#[tracing::instrument(skip_all, fields(repo_id = %id))] #[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn trigger_scan( pub async fn trigger_scan(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
let tenant_id = tenant.0.tenant_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = agent_clone if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
.run_scan(&tenant_id, &id, ScanTrigger::Manual)
.await
{
tracing::error!("Manual scan failed for {id}: {e}"); tracing::error!("Manual scan failed for {id}: {e}");
} }
}); });
@@ -178,12 +170,11 @@ pub async fn trigger_scan(
/// Return the webhook secret for a repository (used by dashboard to display it) /// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config( pub async fn get_webhook_config(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?; let repo = agent
let repo = db .db
.repositories() .repositories()
.find_one(doc! { "_id": oid }) .find_one(doc! { "_id": oid })
.await .await
@@ -205,12 +196,10 @@ pub async fn get_webhook_config(
#[tracing::instrument(skip_all, fields(repo_id = %id))] #[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository( pub async fn delete_repository(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>, Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
// Delete the repository // Delete the repository
let result = db let result = db
+5 -16
View File
@@ -6,7 +6,6 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::SbomEntry; use compliance_core::models::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
const COPYLEFT_LICENSES: &[&str] = &[ const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0", "GPL-2.0",
@@ -30,10 +29,8 @@ const COPYLEFT_LICENSES: &[&str] = &[
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn sbom_filters( pub async fn sbom_filters(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> { ) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let managers: Vec<String> = db let managers: Vec<String> = db
.sbom_entries() .sbom_entries()
@@ -64,11 +61,9 @@ pub async fn sbom_filters(
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))] #[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom( pub async fn list_sbom(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<SbomFilter>, Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> { ) -> ApiResult<Vec<SbomEntry>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let mut query = doc! {}; let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id { if let Some(repo_id) = &filter.repo_id {
@@ -125,11 +120,9 @@ pub async fn list_sbom(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn export_sbom( pub async fn export_sbom(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomExportParams>, Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> { ) -> Result<impl IntoResponse, StatusCode> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let entries: Vec<SbomEntry> = match db let entries: Vec<SbomEntry> = match db
.sbom_entries() .sbom_entries()
.find(doc! { "repo_id": &params.repo_id }) .find(doc! { "repo_id": &params.repo_id })
@@ -243,11 +236,9 @@ pub async fn export_sbom(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn license_summary( pub async fn license_summary(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomFilter>, Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> { ) -> ApiResult<Vec<LicenseSummary>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let mut query = doc! {}; let mut query = doc! {};
if let Some(repo_id) = &params.repo_id { if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id); query.insert("repo_id", repo_id);
@@ -294,11 +285,9 @@ pub async fn license_summary(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn sbom_diff( pub async fn sbom_diff(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomDiffParams>, Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> { ) -> ApiResult<SbomDiffResult> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let entries_a: Vec<SbomEntry> = match db let entries_a: Vec<SbomEntry> = match db
.sbom_entries() .sbom_entries()
+1 -4
View File
@@ -4,16 +4,13 @@ use mongodb::bson::doc;
use super::dto::*; use super::dto::*;
use compliance_core::models::ScanRun; use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn list_scan_runs( pub async fn list_scan_runs(
Extension(agent): AgentExt, Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>, Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> { ) -> ApiResult<Vec<ScanRun>> {
let db = tenant_db(&agent, &tenant).await?; let db = &agent.db;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64; let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0); let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
+1
View File
@@ -1,3 +1,4 @@
pub mod auth_middleware;
pub mod handlers; pub mod handlers;
pub mod routes; pub mod routes;
pub mod server; pub mod server;
-9
View File
@@ -47,15 +47,6 @@ pub fn build_router() -> Router {
.route("/api/v1/sbom/diff", get(handlers::sbom_diff)) .route("/api/v1/sbom/diff", get(handlers::sbom_diff))
.route("/api/v1/issues", get(handlers::list_issues)) .route("/api/v1/issues", get(handlers::list_issues))
.route("/api/v1/scan-runs", get(handlers::list_scan_runs)) .route("/api/v1/scan-runs", get(handlers::list_scan_runs))
// MCP token management (per-tenant API tokens for the MCP server)
.route(
"/api/v1/mcp-tokens",
get(handlers::mcp_tokens::list_mcp_tokens).post(handlers::mcp_tokens::create_mcp_token),
)
.route(
"/api/v1/mcp-tokens/{id}",
delete(handlers::mcp_tokens::revoke_mcp_token),
)
// Graph API endpoints // Graph API endpoints
.route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph)) .route("/api/v1/graph/{repo_id}", get(handlers::graph::get_graph))
.route( .route(
+4 -52
View File
@@ -1,54 +1,17 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::Request;
use axum::http::HeaderValue; use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;
use axum::{middleware, Extension}; use axum::{middleware, Extension};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tower_http::cors::CorsLayer; use tower_http::cors::CorsLayer;
use tower_http::set_header::SetResponseHeaderLayer; use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState};
use compliance_core::{TenantContext, TenantStatus};
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use crate::api::auth_middleware::{require_jwt_auth, JwksState};
use crate::api::routes; use crate::api::routes;
use crate::error::AgentError; use crate::error::AgentError;
/// Synthetic tenant id used when Keycloak isn't configured (local dev,
/// `cargo run` against a bare Mongo). Lets the handler stack stay
/// uniformly tenant-scoped without the operator having to spin up KC
/// just to poke at the API. Override via `DEV_TENANT_ID`.
const DEFAULT_DEV_TENANT_ID: &str = "dev";
/// Inject a synthetic [`TenantContext`] for any request that lacks one.
/// Only mounted when Keycloak is NOT configured; with KC, the real
/// `require_jwt_auth` middleware owns this and we never reach here
/// without a context.
///
/// Public so the integration-test harness can mount it without
/// duplicating the synthetic-context shape.
pub async fn inject_dev_tenant(mut request: Request, next: Next) -> Response {
if request.extensions().get::<TenantContext>().is_none() {
let tenant_id =
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
let ctx = TenantContext {
tenant_slug: tenant_id.clone(),
tenant_id,
org_roles: vec![],
products: vec![],
plan: "dev".to_string(),
status: TenantStatus::Active,
user_id: "dev-user".to_string(),
user_name: None,
};
request.extensions_mut().insert(ctx);
}
next.run(request).await
}
pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> { pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> {
let mut app = routes::build_router() let mut app = routes::build_router()
.layer(Extension(Arc::new(agent.clone()))) .layer(Extension(Arc::new(agent.clone())))
@@ -81,22 +44,11 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
jwks_url, jwks_url,
}; };
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'"); tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
// Layers execute outermost-first. Extension(jwks_state) must run
// before require_jwt_auth so the middleware can read it; the
// status gate runs after JWT so TenantContext is in extensions.
app = app app = app
.layer(middleware::from_fn(require_tenant_status)) .layer(Extension(jwks_state))
.layer(middleware::from_fn(require_jwt_auth)) .layer(middleware::from_fn(require_jwt_auth));
.layer(Extension(jwks_state));
} else { } else {
let tenant_id = tracing::warn!("Keycloak not configured - API endpoints are unprotected");
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
tracing::warn!(
tenant_id = %tenant_id,
"Keycloak not configured — running unauthenticated against the dev tenant. \
DO NOT use in any environment with real customer data."
);
app = app.layer(middleware::from_fn(inject_dev_tenant));
} }
let addr = format!("0.0.0.0:{port}"); let addr = format!("0.0.0.0:{port}");
-211
View File
@@ -1,216 +1,11 @@
use std::sync::Arc;
use dashmap::DashMap;
use mongodb::bson::doc; use mongodb::bson::doc;
use mongodb::options::IndexOptions; use mongodb::options::IndexOptions;
use mongodb::{Client, Collection, IndexModel}; use mongodb::{Client, Collection, IndexModel};
use sha2::{Digest, Sha256};
use compliance_core::models::*; use compliance_core::models::*;
use compliance_core::TenantContext;
use crate::error::AgentError; use crate::error::AgentError;
/// Mongo enforces a 63-byte cap on database names (older clusters: 64
/// on Linux, 63 on Windows; we target the conservative limit).
const MAX_DB_NAME_LEN: usize = 63;
/// Hex length of the SHA-256 truncation used for the hash fallback
/// tenant DB name (16 bytes → 32 hex chars). 16 bytes gives ~2^64
/// birthday-collision resistance — at our 10s-100s tenant scale this
/// is effectively impossible to hit.
const HASH_HEX_LEN: usize = 32;
/// Largest `db_prefix` that still guarantees the hash-fallback name
/// fits in the 63-byte cap: `prefix + "_" + 32 hex chars`.
const MAX_PREFIX_LEN: usize = MAX_DB_NAME_LEN - 1 - HASH_HEX_LEN;
/// Per-tenant Mongo connection broker (M7.2 isolation model).
///
/// Holds one [`Client`] and hands out [`Database`] handles physically
/// scoped to `<db_prefix>_<tenant_id>`. The driver is the isolation
/// boundary — a handle for tenant A cannot see tenant B's documents
/// because it is connected to a different database, not because of an
/// application-level filter.
///
/// Index creation runs idempotently the first time each tenant is seen
/// in the process's lifetime. Mongo's `createIndex` is itself idempotent
/// by index name; the in-memory `ensured` set just skips the round-trip.
#[derive(Clone, Debug)]
pub struct DatabasePool {
client: Client,
db_prefix: String,
ensured: Arc<DashMap<String, ()>>,
}
impl DatabasePool {
/// Connect to the cluster and prepare to hand out tenant databases
/// named `<db_prefix>_<tenant_id>`.
///
/// Validates `db_prefix.len() <= MAX_PREFIX_LEN` so the
/// hash-fallback path is provably within Mongo's 63-byte db-name
/// cap. Refuses to construct a pool that could ever produce an
/// over-long name.
pub async fn connect(uri: &str, db_prefix: &str) -> Result<Self, AgentError> {
if db_prefix.len() > MAX_PREFIX_LEN {
return Err(AgentError::Other(format!(
"db_prefix '{db_prefix}' is {} chars; max is {MAX_PREFIX_LEN} so the \
hash-fallback tenant DB name fits Mongo's {MAX_DB_NAME_LEN}-byte cap",
db_prefix.len()
)));
}
let client = Client::with_uri_str(uri).await?;
client
.database("admin")
.run_command(doc! { "ping": 1 })
.await?;
tracing::info!(
"MongoDB cluster reachable; per-tenant pool ready (db prefix '{db_prefix}')"
);
Ok(Self {
client,
db_prefix: db_prefix.to_string(),
ensured: Arc::new(DashMap::new()),
})
}
/// Return a [`Database`] scoped to this tenant. Ensures indexes on
/// first call per tenant (per process). Cheap on the hot path —
/// subsequent calls skip the round-trip.
pub async fn for_tenant(&self, ctx: &TenantContext) -> Result<Database, AgentError> {
self.for_tenant_id(&ctx.tenant_id).await
}
/// Like [`Self::for_tenant`] but accepts a bare tenant_id.
/// For background paths (scheduler, webhooks, pipeline orchestrators)
/// that don't have a full [`TenantContext`] but know which tenant
/// they're operating on (typically resolved from a URL path, a job
/// argument, or the registry).
pub async fn for_tenant_id(&self, tenant_id: &str) -> Result<Database, AgentError> {
let db_name = self.tenant_db_name(tenant_id);
let db = Database::from_database(self.client.database(&db_name));
// `DashMap::insert` returns the previous value; `None` means we
// were the first writer for this tenant_id and own the
// index-ensure work.
if self.ensured.insert(tenant_id.to_string(), ()).is_none() {
if let Err(e) = db.ensure_indexes().await {
// Roll the marker back so the next request retries.
self.ensured.remove(tenant_id);
return Err(e);
}
tracing::debug!(
tenant_id = %tenant_id,
db_name = %db_name,
"Indexes ensured for tenant database"
);
}
Ok(db)
}
/// Compute the Mongo database name for a tenant. Public for tests
/// and tenant offboarding (`pool.client().database(name).drop()`).
///
/// Format: `<prefix>_<sanitized_tenant_id>` if it fits the 63-byte
/// cap, else `<prefix>_<sha256-16-byte-hex-of-tenant_id>`. The
/// `db_prefix` length invariant established at [`Self::connect`]
/// guarantees the hash-fallback name always fits — no runtime
/// assertion needed.
///
/// Collision resistance: the hash fallback is a 16-byte SHA-256
/// truncation, which gives ~2^64 birthday-collision resistance. At
/// our 10s100s tenant scale the probability of two tenant_ids
/// colliding is effectively zero. (8-byte truncation would have
/// been ~2^32 — too close for comfort on a regulated product.)
pub fn tenant_db_name(&self, tenant_id: &str) -> String {
let sanitized = sanitize_tenant_id(tenant_id);
let natural = format!("{}_{}", self.db_prefix, sanitized);
if natural.len() <= MAX_DB_NAME_LEN {
natural
} else {
let mut hasher = Sha256::new();
hasher.update(tenant_id.as_bytes());
let digest = hasher.finalize();
let suffix = hex::encode(&digest[..HASH_HEX_LEN / 2]);
format!("{}_{}", self.db_prefix, suffix)
}
}
/// Raw client handle. Reserved for cross-tenant admin flows that
/// must opt in explicitly (tenant listing, drop-on-offboard).
pub fn client(&self) -> &Client {
&self.client
}
/// Cross-tenant admin database used by features that intentionally
/// span tenants (today: MCP bearer tokens — each token row carries
/// a `tenant_id` and the MCP server reads them to route requests).
///
/// The name `<db_prefix>__admin` (double underscore) is reserved —
/// the sanitizer never produces it for a normal tenant DB because
/// the natural format is `<db_prefix>_<sanitized_tenant_id>` (one
/// underscore) and tenant_ids would have to start with `_admin` to
/// collide. New tenant provisioning should reject such ids.
pub fn admin_db(&self) -> mongodb::Database {
self.client.database(&self.admin_db_name())
}
/// Name of the admin database — public so tests / operators can
/// drop it via the raw client.
pub fn admin_db_name(&self) -> String {
format!("{}__admin", self.db_prefix)
}
/// List every Mongo database currently belonging to this pool,
/// identified by the `<db_prefix>_` prefix. The result is the raw
/// database names — opening one for offboarding/cleanup goes
/// through [`Self::client`].
///
/// Note: hashed-fallback names (very long tenant_ids) lose the
/// original tenant_id at the cluster level — we know a database
/// exists for *some* tenant but not which one. In practice
/// tenant_ids are UUIDs (36 chars) and never hit the fallback,
/// so this is a theoretical concern, not an operational one.
pub async fn list_tenant_db_names(&self) -> Result<Vec<String>, AgentError> {
let prefix = format!("{}_", self.db_prefix);
let names = self.client.list_database_names().await?;
Ok(names
.into_iter()
.filter(|n| n.starts_with(&prefix))
.collect())
}
/// Drop the database for a specific tenant. Used by GDPR delete
/// and tenant offboarding. Idempotent — dropping a non-existent
/// database is a no-op at the driver level.
///
/// Also evicts the tenant from the in-memory `ensured` set so a
/// later re-provision triggers fresh `ensure_indexes`.
pub async fn drop_tenant(&self, tenant_id: &str) -> Result<(), AgentError> {
let db_name = self.tenant_db_name(tenant_id);
self.client.database(&db_name).drop().await?;
self.ensured.remove(tenant_id);
tracing::info!(
tenant_id = %tenant_id,
db_name = %db_name,
"Dropped tenant database"
);
Ok(())
}
}
/// Mongo database names disallow `/`, `\`, `.`, `"`, `$`, ` `, and NUL.
/// breakpilot-dev tenant_ids are UUIDs so this is belt-and-braces, but
/// it lets the pool tolerate any future tenant_id shape without surprise.
fn sanitize_tenant_id(tenant_id: &str) -> String {
tenant_id
.chars()
.map(|c| match c {
'/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_',
c => c,
})
.collect()
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Database { pub struct Database {
inner: mongodb::Database, inner: mongodb::Database,
@@ -225,12 +20,6 @@ impl Database {
Ok(Self { inner: db }) Ok(Self { inner: db })
} }
/// Wrap an already-resolved Mongo database. Used by [`DatabasePool`]
/// to hand out tenant-scoped handles without a fresh client per tenant.
pub(crate) fn from_database(inner: mongodb::Database) -> Self {
Self { inner }
}
pub async fn ensure_indexes(&self) -> Result<(), AgentError> { pub async fn ensure_indexes(&self) -> Result<(), AgentError> {
// repositories: unique git_url // repositories: unique git_url
self.repositories() self.repositories()
+3 -6
View File
@@ -25,13 +25,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
tracing::info!("Connecting to MongoDB..."); tracing::info!("Connecting to MongoDB...");
// Per-tenant pool only — the agent has no shared "default" database let db = database::Database::connect(&config.mongodb_uri, &config.mongodb_database).await?;
// after M7.2-D. `mongodb_database` is now the db-name prefix used db.ensure_indexes().await?;
// for tenant databases (`<prefix>_<tenant_id>`).
let db_pool =
database::DatabasePool::connect(&config.mongodb_uri, &config.mongodb_database).await?;
let agent = agent::ComplianceAgent::new(config.clone(), db_pool); let agent = agent::ComplianceAgent::new(config.clone(), db.clone());
tracing::info!("Starting scheduler..."); tracing::info!("Starting scheduler...");
let scheduler_agent = agent.clone(); let scheduler_agent = agent.clone();
+22 -77
View File
@@ -4,14 +4,8 @@ use tokio_cron_scheduler::{Job, JobScheduler};
use compliance_core::models::ScanTrigger; use compliance_core::models::ScanTrigger;
use crate::agent::ComplianceAgent; use crate::agent::ComplianceAgent;
use crate::database::Database;
use crate::error::AgentError; use crate::error::AgentError;
/// Default tenant the scheduler runs against when `SCHEDULER_TENANT_IDS`
/// isn't set. Matches the dev-injector default so a bare `cargo run` has
/// the scheduler scanning whatever lives in `<prefix>_dev`.
const DEFAULT_SCHEDULER_TENANT_ID: &str = "dev";
pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError> { pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError> {
let sched = JobScheduler::new() let sched = JobScheduler::new()
.await .await
@@ -24,9 +18,7 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
let agent = scan_agent.clone(); let agent = scan_agent.clone();
Box::pin(async move { Box::pin(async move {
tracing::info!("Scheduled scan triggered"); tracing::info!("Scheduled scan triggered");
for tenant_id in scheduler_tenants() { scan_all_repos(&agent).await;
scan_all_repos(&agent, &tenant_id).await;
}
}) })
}) })
.map_err(|e| AgentError::Scheduler(format!("Failed to create scan job: {e}")))?; .map_err(|e| AgentError::Scheduler(format!("Failed to create scan job: {e}")))?;
@@ -42,9 +34,7 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
let agent = cve_agent.clone(); let agent = cve_agent.clone();
Box::pin(async move { Box::pin(async move {
tracing::info!("CVE monitor triggered"); tracing::info!("CVE monitor triggered");
for tenant_id in scheduler_tenants() { monitor_cves(&agent).await;
monitor_cves(&agent, &tenant_id).await;
}
}) })
}) })
.map_err(|e| AgentError::Scheduler(format!("Failed to create CVE monitor job: {e}")))?; .map_err(|e| AgentError::Scheduler(format!("Failed to create CVE monitor job: {e}")))?;
@@ -58,9 +48,8 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
.await .await
.map_err(|e| AgentError::Scheduler(format!("Failed to start scheduler: {e}")))?; .map_err(|e| AgentError::Scheduler(format!("Failed to start scheduler: {e}")))?;
let tenants = scheduler_tenants();
tracing::info!( tracing::info!(
"Scheduler started: scans='{}', CVE monitor='{}', tenants={tenants:?}", "Scheduler started: scans='{}', CVE monitor='{}'",
agent.config.scan_schedule, agent.config.scan_schedule,
agent.config.cve_monitor_schedule, agent.config.cve_monitor_schedule,
); );
@@ -71,47 +60,13 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
} }
} }
/// Tenants the scheduler iterates each tick. From `SCHEDULER_TENANT_IDS` async fn scan_all_repos(agent: &ComplianceAgent) {
/// (comma-separated), or `DEFAULT_SCHEDULER_TENANT_ID` if unset. M7.2-D
/// will replace this with a pull from the tenant-registry.
fn scheduler_tenants() -> Vec<String> {
std::env::var("SCHEDULER_TENANT_IDS")
.ok()
.map(|s| {
s.split(',')
.map(str::trim)
.filter(|s| !s.is_empty())
.map(String::from)
.collect::<Vec<_>>()
})
.filter(|v| !v.is_empty())
.unwrap_or_else(|| vec![DEFAULT_SCHEDULER_TENANT_ID.to_string()])
}
/// Resolve the per-tenant database. Logs and returns `None` on failure
/// so the loop in the caller can continue with other tenants.
async fn tenant_db(agent: &ComplianceAgent, tenant_id: &str) -> Option<Database> {
match agent.db_pool.for_tenant_id(tenant_id).await {
Ok(db) => Some(db),
Err(e) => {
tracing::error!("Scheduler: cannot open tenant database '{tenant_id}': {e}");
None
}
}
}
async fn scan_all_repos(agent: &ComplianceAgent, tenant_id: &str) {
use futures_util::StreamExt; use futures_util::StreamExt;
let db = match tenant_db(agent, tenant_id).await { let cursor = match agent.db.repositories().find(doc! {}).await {
Some(db) => db,
None => return,
};
let cursor = match db.repositories().find(doc! {}).await {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
tracing::error!("Failed to list repos for tenant '{tenant_id}': {e}"); tracing::error!("Failed to list repos for scheduled scan: {e}");
return; return;
} }
}; };
@@ -120,44 +75,33 @@ async fn scan_all_repos(agent: &ComplianceAgent, tenant_id: &str) {
for repo in repos { for repo in repos {
let repo_id = repo.id.map(|id| id.to_hex()).unwrap_or_default(); let repo_id = repo.id.map(|id| id.to_hex()).unwrap_or_default();
if let Err(e) = agent if let Err(e) = agent.run_scan(&repo_id, ScanTrigger::Scheduled).await {
.run_scan(tenant_id, &repo_id, ScanTrigger::Scheduled) tracing::error!("Scheduled scan failed for {}: {e}", repo.name);
.await
{
tracing::error!(
"Scheduled scan failed for {} (tenant '{tenant_id}'): {e}",
repo.name
);
} }
} }
} }
async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) { async fn monitor_cves(agent: &ComplianceAgent) {
use compliance_core::models::notification::{parse_severity, CveNotification}; use compliance_core::models::notification::{parse_severity, CveNotification};
use compliance_core::models::SbomEntry; use compliance_core::models::SbomEntry;
use futures_util::StreamExt; use futures_util::StreamExt;
let db = match tenant_db(agent, tenant_id).await {
Some(db) => db,
None => return,
};
// Fetch all SBOM entries grouped by repo // Fetch all SBOM entries grouped by repo
let cursor = match db.sbom_entries().find(doc! {}).await { let cursor = match agent.db.sbom_entries().find(doc! {}).await {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
tracing::error!("CVE monitor: failed to list SBOM entries for '{tenant_id}': {e}"); tracing::error!("CVE monitor: failed to list SBOM entries: {e}");
return; return;
} }
}; };
let entries: Vec<SbomEntry> = cursor.filter_map(|r| async { r.ok() }).collect().await; let entries: Vec<SbomEntry> = cursor.filter_map(|r| async { r.ok() }).collect().await;
if entries.is_empty() { if entries.is_empty() {
tracing::debug!("CVE monitor: no SBOM entries for tenant '{tenant_id}', skipping"); tracing::debug!("CVE monitor: no SBOM entries, skipping");
return; return;
} }
tracing::info!( tracing::info!(
"CVE monitor: checking {} dependencies for new CVEs (tenant '{tenant_id}')", "CVE monitor: checking {} dependencies for new CVEs",
entries.len() entries.len()
); );
@@ -168,7 +112,7 @@ async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
std::collections::HashMap::new(); std::collections::HashMap::new();
for rid in &repo_ids { for rid in &repo_ids {
if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(rid) { if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(rid) {
if let Ok(Some(repo)) = db.repositories().find_one(doc! { "_id": oid }).await { if let Ok(Some(repo)) = agent.db.repositories().find_one(doc! { "_id": oid }).await {
repo_names.insert(rid.clone(), repo.name.clone()); repo_names.insert(rid.clone(), repo.name.clone());
} }
} }
@@ -216,7 +160,8 @@ async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
for alert in &alerts { for alert in &alerts {
let filter = doc! { "cve_id": &alert.cve_id, "repo_id": &alert.repo_id }; let filter = doc! { "cve_id": &alert.cve_id, "repo_id": &alert.repo_id };
let update = doc! { "$setOnInsert": mongodb::bson::to_bson(alert).unwrap_or_default() }; let update = doc! { "$setOnInsert": mongodb::bson::to_bson(alert).unwrap_or_default() };
let _ = db let _ = agent
.db
.cve_alerts() .cve_alerts()
.update_one(filter, update) .update_one(filter, update)
.upsert(true) .upsert(true)
@@ -229,7 +174,8 @@ async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
continue; continue;
} }
if let Some(entry_id) = &entry.id { if let Some(entry_id) = &entry.id {
let _ = db let _ = agent
.db
.sbom_entries() .sbom_entries()
.update_one( .update_one(
doc! { "_id": entry_id }, doc! { "_id": entry_id },
@@ -267,7 +213,8 @@ async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
let update = doc! { let update = doc! {
"$setOnInsert": mongodb::bson::to_bson(&notification).unwrap_or_default() "$setOnInsert": mongodb::bson::to_bson(&notification).unwrap_or_default()
}; };
match db match agent
.db
.cve_notifications() .cve_notifications()
.update_one(filter, update) .update_one(filter, update)
.upsert(true) .upsert(true)
@@ -285,10 +232,8 @@ async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
} }
if new_notifications > 0 { if new_notifications > 0 {
tracing::info!( tracing::info!("CVE monitor: created {new_notifications} new notification(s)");
"CVE monitor: created {new_notifications} new notification(s) for tenant '{tenant_id}'"
);
} else { } else {
tracing::info!("CVE monitor: no new CVEs found for tenant '{tenant_id}'"); tracing::info!("CVE monitor: no new CVEs found");
} }
} }
+9 -23
View File
@@ -14,30 +14,24 @@ type HmacSha256 = Hmac<Sha256>;
pub async fn handle_gitea_webhook( pub async fn handle_gitea_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>, Extension(agent): Extension<Arc<ComplianceAgent>>,
Path((tenant_id, repo_id)): Path<(String, String)>, Path(repo_id): Path<String>,
headers: HeaderMap, headers: HeaderMap,
body: Bytes, body: Bytes,
) -> StatusCode { ) -> StatusCode {
// Look up the repo in the tenant's database to get its webhook secret // Look up the repo to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) { let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid, Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND, Err(_) => return StatusCode::NOT_FOUND,
}; };
let db = match agent.db_pool.for_tenant_id(&tenant_id).await { let repo = match agent
Ok(db) => db, .db
Err(e) => {
tracing::warn!("Gitea webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories() .repositories()
.find_one(mongodb::bson::doc! { "_id": oid }) .find_one(mongodb::bson::doc! { "_id": oid })
.await .await
{ {
Ok(Some(repo)) => repo, Ok(Some(repo)) => repo,
_ => { _ => {
tracing::warn!("Gitea webhook: repo {repo_id} not found in tenant '{tenant_id}'"); tracing::warn!("Gitea webhook: repo {repo_id} not found");
return StatusCode::NOT_FOUND; return StatusCode::NOT_FOUND;
} }
}; };
@@ -72,21 +66,15 @@ pub async fn handle_gitea_webhook(
"push" => { "push" => {
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
let repo_id = repo_id.clone(); let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!( tracing::info!("Gitea push webhook: triggering scan for {repo_id}");
"Gitea push webhook: triggering scan for {repo_id} in tenant {tenant_id}" if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}"); tracing::error!("Webhook-triggered scan failed: {e}");
} }
}); });
StatusCode::OK StatusCode::OK
} }
"pull_request" => handle_pull_request(agent, &tenant_id, &repo_id, &payload).await, "pull_request" => handle_pull_request(agent, &repo_id, &payload).await,
_ => { _ => {
tracing::debug!("Gitea webhook: ignoring event '{event}'"); tracing::debug!("Gitea webhook: ignoring event '{event}'");
StatusCode::OK StatusCode::OK
@@ -96,7 +84,6 @@ pub async fn handle_gitea_webhook(
async fn handle_pull_request( async fn handle_pull_request(
agent: Arc<ComplianceAgent>, agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str, repo_id: &str,
payload: &serde_json::Value, payload: &serde_json::Value,
) -> StatusCode { ) -> StatusCode {
@@ -119,14 +106,13 @@ async fn handle_pull_request(
} }
let repo_id = repo_id.to_string(); let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string(); let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string(); let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!("Gitea PR webhook: reviewing PR #{pr_number} on {repo_id}"); tracing::info!("Gitea PR webhook: reviewing PR #{pr_number} on {repo_id}");
if let Err(e) = agent_clone if let Err(e) = agent_clone
.run_pr_review(&tenant_id, &repo_id, pr_number, &base_sha, &head_sha) .run_pr_review(&repo_id, pr_number, &base_sha, &head_sha)
.await .await
{ {
tracing::error!("PR review failed for #{pr_number}: {e}"); tracing::error!("PR review failed for #{pr_number}: {e}");
+9 -23
View File
@@ -14,30 +14,24 @@ type HmacSha256 = Hmac<Sha256>;
pub async fn handle_github_webhook( pub async fn handle_github_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>, Extension(agent): Extension<Arc<ComplianceAgent>>,
Path((tenant_id, repo_id)): Path<(String, String)>, Path(repo_id): Path<String>,
headers: HeaderMap, headers: HeaderMap,
body: Bytes, body: Bytes,
) -> StatusCode { ) -> StatusCode {
// Look up the repo in the tenant's database to get its webhook secret // Look up the repo to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) { let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid, Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND, Err(_) => return StatusCode::NOT_FOUND,
}; };
let db = match agent.db_pool.for_tenant_id(&tenant_id).await { let repo = match agent
Ok(db) => db, .db
Err(e) => {
tracing::warn!("GitHub webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories() .repositories()
.find_one(mongodb::bson::doc! { "_id": oid }) .find_one(mongodb::bson::doc! { "_id": oid })
.await .await
{ {
Ok(Some(repo)) => repo, Ok(Some(repo)) => repo,
_ => { _ => {
tracing::warn!("GitHub webhook: repo {repo_id} not found in tenant '{tenant_id}'"); tracing::warn!("GitHub webhook: repo {repo_id} not found");
return StatusCode::NOT_FOUND; return StatusCode::NOT_FOUND;
} }
}; };
@@ -72,21 +66,15 @@ pub async fn handle_github_webhook(
"push" => { "push" => {
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
let repo_id = repo_id.clone(); let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!( tracing::info!("GitHub push webhook: triggering scan for {repo_id}");
"GitHub push webhook: triggering scan for {repo_id} in tenant {tenant_id}" if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}"); tracing::error!("Webhook-triggered scan failed: {e}");
} }
}); });
StatusCode::OK StatusCode::OK
} }
"pull_request" => handle_pull_request(agent, &tenant_id, &repo_id, &payload).await, "pull_request" => handle_pull_request(agent, &repo_id, &payload).await,
_ => { _ => {
tracing::debug!("GitHub webhook: ignoring event '{event}'"); tracing::debug!("GitHub webhook: ignoring event '{event}'");
StatusCode::OK StatusCode::OK
@@ -96,7 +84,6 @@ pub async fn handle_github_webhook(
async fn handle_pull_request( async fn handle_pull_request(
agent: Arc<ComplianceAgent>, agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str, repo_id: &str,
payload: &serde_json::Value, payload: &serde_json::Value,
) -> StatusCode { ) -> StatusCode {
@@ -118,14 +105,13 @@ async fn handle_pull_request(
} }
let repo_id = repo_id.to_string(); let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string(); let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string(); let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!("GitHub PR webhook: reviewing PR #{pr_number} on {repo_id}"); tracing::info!("GitHub PR webhook: reviewing PR #{pr_number} on {repo_id}");
if let Err(e) = agent_clone if let Err(e) = agent_clone
.run_pr_review(&tenant_id, &repo_id, pr_number, &base_sha, &head_sha) .run_pr_review(&repo_id, pr_number, &base_sha, &head_sha)
.await .await
{ {
tracing::error!("PR review failed for #{pr_number}: {e}"); tracing::error!("PR review failed for #{pr_number}: {e}");
+9 -23
View File
@@ -10,30 +10,24 @@ use crate::agent::ComplianceAgent;
pub async fn handle_gitlab_webhook( pub async fn handle_gitlab_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>, Extension(agent): Extension<Arc<ComplianceAgent>>,
Path((tenant_id, repo_id)): Path<(String, String)>, Path(repo_id): Path<String>,
headers: HeaderMap, headers: HeaderMap,
body: Bytes, body: Bytes,
) -> StatusCode { ) -> StatusCode {
// Look up the repo in the tenant's database to get its webhook secret // Look up the repo to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) { let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid, Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND, Err(_) => return StatusCode::NOT_FOUND,
}; };
let db = match agent.db_pool.for_tenant_id(&tenant_id).await { let repo = match agent
Ok(db) => db, .db
Err(e) => {
tracing::warn!("GitLab webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories() .repositories()
.find_one(mongodb::bson::doc! { "_id": oid }) .find_one(mongodb::bson::doc! { "_id": oid })
.await .await
{ {
Ok(Some(repo)) => repo, Ok(Some(repo)) => repo,
_ => { _ => {
tracing::warn!("GitLab webhook: repo {repo_id} not found in tenant '{tenant_id}'"); tracing::warn!("GitLab webhook: repo {repo_id} not found");
return StatusCode::NOT_FOUND; return StatusCode::NOT_FOUND;
} }
}; };
@@ -65,21 +59,15 @@ pub async fn handle_gitlab_webhook(
"push" => { "push" => {
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
let repo_id = repo_id.clone(); let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!( tracing::info!("GitLab push webhook: triggering scan for {repo_id}");
"GitLab push webhook: triggering scan for {repo_id} in tenant {tenant_id}" if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}"); tracing::error!("Webhook-triggered scan failed: {e}");
} }
}); });
StatusCode::OK StatusCode::OK
} }
"merge_request" => handle_merge_request(agent, &tenant_id, &repo_id, &payload).await, "merge_request" => handle_merge_request(agent, &repo_id, &payload).await,
_ => { _ => {
tracing::debug!("GitLab webhook: ignoring event '{event_type}'"); tracing::debug!("GitLab webhook: ignoring event '{event_type}'");
StatusCode::OK StatusCode::OK
@@ -89,7 +77,6 @@ pub async fn handle_gitlab_webhook(
async fn handle_merge_request( async fn handle_merge_request(
agent: Arc<ComplianceAgent>, agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str, repo_id: &str,
payload: &serde_json::Value, payload: &serde_json::Value,
) -> StatusCode { ) -> StatusCode {
@@ -114,14 +101,13 @@ async fn handle_merge_request(
} }
let repo_id = repo_id.to_string(); let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string(); let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string(); let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone(); let agent_clone = (*agent).clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::info!("GitLab MR webhook: reviewing MR !{mr_iid} on {repo_id}"); tracing::info!("GitLab MR webhook: reviewing MR !{mr_iid} on {repo_id}");
if let Err(e) = agent_clone if let Err(e) = agent_clone
.run_pr_review(&tenant_id, &repo_id, mr_iid, &base_sha, &head_sha) .run_pr_review(&repo_id, mr_iid, &base_sha, &head_sha)
.await .await
{ {
tracing::error!("MR review failed for !{mr_iid}: {e}"); tracing::error!("MR review failed for !{mr_iid}: {e}");
+4 -8
View File
@@ -9,21 +9,17 @@ use crate::webhooks::{gitea, github, gitlab};
pub async fn start_webhook_server(agent: &ComplianceAgent) -> Result<(), AgentError> { pub async fn start_webhook_server(agent: &ComplianceAgent) -> Result<(), AgentError> {
let app = Router::new() let app = Router::new()
// Per-tenant per-repo webhook URLs: /webhook/{tenant_id}/{platform}/{repo_id} // Per-repo webhook URLs: /webhook/{platform}/{repo_id}
// The tenant_id is resolved from the URL path because webhooks
// arrive without a JWT — they're authenticated via per-repo HMAC,
// not via the tenant gate. The dashboard surfaces the full URL
// including the tenant_id when the repo is registered.
.route( .route(
"/webhook/{tenant_id}/github/{repo_id}", "/webhook/github/{repo_id}",
post(github::handle_github_webhook), post(github::handle_github_webhook),
) )
.route( .route(
"/webhook/{tenant_id}/gitlab/{repo_id}", "/webhook/gitlab/{repo_id}",
post(gitlab::handle_gitlab_webhook), post(gitlab::handle_gitlab_webhook),
) )
.route( .route(
"/webhook/{tenant_id}/gitea/{repo_id}", "/webhook/gitea/{repo_id}",
post(gitea::handle_gitea_webhook), post(gitea::handle_gitea_webhook),
) )
.layer(Extension(Arc::new(agent.clone()))); .layer(Extension(Arc::new(agent.clone())));
+8 -20
View File
@@ -7,7 +7,7 @@ use std::sync::Arc;
use compliance_agent::agent::ComplianceAgent; use compliance_agent::agent::ComplianceAgent;
use compliance_agent::api; use compliance_agent::api;
use compliance_agent::database::DatabasePool; use compliance_agent::database::Database;
use compliance_core::AgentConfig; use compliance_core::AgentConfig;
use secrecy::SecretString; use secrecy::SecretString;
@@ -28,9 +28,10 @@ impl TestServer {
// Unique database name per test run to avoid collisions // Unique database name per test run to avoid collisions
let db_name = format!("test_{}", uuid::Uuid::new_v4().simple()); let db_name = format!("test_{}", uuid::Uuid::new_v4().simple());
let db_pool = DatabasePool::connect(&mongodb_uri, &db_name) let db = Database::connect(&mongodb_uri, &db_name)
.await .await
.expect("Failed to build DatabasePool"); .expect("Failed to connect to MongoDB — is it running?");
db.ensure_indexes().await.expect("Failed to create indexes");
let config = AgentConfig { let config = AgentConfig {
mongodb_uri: mongodb_uri.clone(), mongodb_uri: mongodb_uri.clone(),
@@ -68,15 +69,11 @@ impl TestServer {
pentest_imap_password: None, pentest_imap_password: None,
}; };
let agent = ComplianceAgent::new(config, db_pool); let agent = ComplianceAgent::new(config, db);
// Build the router with the agent extension. After M7.2-B every // Build the router with the agent extension
// handler takes a TenantCtx extractor; without KC in the test
// harness, the dev-tenant injector mounts a synthetic context so
// tests run end-to-end against `<db_name>_dev`.
let app = api::routes::build_router() let app = api::routes::build_router()
.layer(axum::extract::Extension(Arc::new(agent))) .layer(axum::extract::Extension(Arc::new(agent)))
.layer(axum::middleware::from_fn(api::server::inject_dev_tenant))
.layer(tower_http::cors::CorsLayer::permissive()); .layer(tower_http::cors::CorsLayer::permissive());
// Bind to port 0 to get a random available port // Bind to port 0 to get a random available port
@@ -159,19 +156,10 @@ impl TestServer {
&self.db_name &self.db_name
} }
/// Drop every per-tenant database belonging to this test run. /// Drop the test database on cleanup
/// Post-M7.2-D the agent never opens a `db_name` directly —
/// data lives only in `<db_name>_<tenant>` per-tenant databases.
pub async fn cleanup(&self) { pub async fn cleanup(&self) {
if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await { if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await {
if let Ok(names) = client.list_database_names().await { client.database(&self.db_name).drop().await.ok();
let prefix = format!("{}_", self.db_name);
for name in names {
if name.starts_with(&prefix) {
client.database(&name).drop().await.ok();
}
}
}
} }
} }
} }
-298
View File
@@ -1,298 +0,0 @@
//! M7.2-A — `DatabasePool` isolation proof.
//!
//! Two `TenantContext`s, two databases, one client. Insert on A, query
//! on B → empty. Insert on B, query on A → only A's docs. Proves that
//! the per-tenant database split actually isolates at the driver level
//! and not at "we hope we filter."
//!
//! Requires MongoDB. Set `TEST_MONGODB_URI` to override the default
//! `mongodb://root:example@localhost:27017/?authSource=admin`.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use compliance_agent::database::DatabasePool;
use compliance_core::models::TrackedRepository;
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use mongodb::bson::doc;
fn ctx(tenant_id: &str, slug: &str) -> TenantContext {
TenantContext {
tenant_id: tenant_id.to_string(),
tenant_slug: slug.to_string(),
org_roles: vec![OrgRole::ItAdmin],
products: vec!["compliance-scanner".to_string()],
plan: "starter".to_string(),
status: TenantStatus::Active,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn fixture_repo(name: &str, git_url: &str) -> TrackedRepository {
TrackedRepository {
id: None,
name: name.to_string(),
git_url: git_url.to_string(),
default_branch: "main".to_string(),
local_path: None,
scan_schedule: None,
webhook_enabled: false,
webhook_secret: None,
tracker_type: None,
tracker_owner: None,
tracker_repo: None,
tracker_token: None,
auth_token: None,
auth_username: None,
last_scanned_commit: None,
findings_count: 0,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
}
}
#[tokio::test]
async fn pool_isolates_tenants_at_driver_level() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// Unique per run so parallel test invocations don't collide. Kept
// short because Mongo caps db names at 63 bytes (prefix + tenant_id).
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix)
.await
.expect("Failed to connect to MongoDB — is it running?");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
let globex = ctx("00000000-0000-0000-0000-0000globex000", "globex");
let acme_db = pool.for_tenant(&acme).await.expect("acme db");
let globex_db = pool.for_tenant(&globex).await.expect("globex db");
// Write distinct repos into each tenant's database.
acme_db
.repositories()
.insert_one(fixture_repo("acme-app", "git@example.com:acme/app.git"))
.await
.expect("insert acme");
globex_db
.repositories()
.insert_one(fixture_repo(
"globex-platform",
"git@example.com:globex/platform.git",
))
.await
.expect("insert globex");
// The point of the whole exercise: acme can ONLY see acme's repo
// and globex can ONLY see globex's, with no filter doc anywhere
// because the isolation is at the database handle, not in the query.
let acme_seen = collect(&acme_db).await;
let globex_seen = collect(&globex_db).await;
assert_eq!(acme_seen.len(), 1, "acme should see exactly its own repo");
assert_eq!(acme_seen[0].name, "acme-app");
assert_eq!(
globex_seen.len(),
1,
"globex should see exactly its own repo"
);
assert_eq!(globex_seen[0].name, "globex-platform");
// Sanity: the two databases really are different by name.
let acme_db_name = pool.tenant_db_name(&acme.tenant_id);
let globex_db_name = pool.tenant_db_name(&globex.tenant_id);
assert_ne!(acme_db_name, globex_db_name);
assert!(acme_db_name.starts_with(&prefix));
// Cleanup — drop both per-tenant databases.
pool.client()
.database(&acme_db_name)
.drop()
.await
.expect("drop acme");
pool.client()
.database(&globex_db_name)
.drop()
.await
.expect("drop globex");
}
#[tokio::test]
async fn for_tenant_is_idempotent_index_creation() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix).await.expect("connect");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
// Second call must not fail (ensure_indexes already ran, in-memory
// marker is set, Mongo's createIndex is idempotent by name anyway).
let _ = pool.for_tenant(&acme).await.expect("first call");
let _ = pool.for_tenant(&acme).await.expect("second call");
let _ = pool.for_tenant(&acme).await.expect("third call");
// Cleanup
let db_name = pool.tenant_db_name(&acme.tenant_id);
pool.client().database(&db_name).drop().await.expect("drop");
}
#[tokio::test]
async fn tenant_db_name_sanitizes_unsafe_characters() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_sanitize")
.await
.expect("connect");
// Mongo db names cannot contain `/ \ . " $ <space> NUL`. The pool
// must rewrite these without exploding on connect.
let funky = "te/n.a\\nt$id\" with spaces";
let name = pool.tenant_db_name(funky);
for c in ['/', '\\', '.', '"', '$', ' '] {
assert!(
!name.contains(c),
"sanitized db name still contains {c:?}: {name}"
);
}
}
#[tokio::test]
async fn admin_helpers_list_and_drop_tenant_dbs() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let prefix = format!("m72d_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix).await.expect("connect");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
let globex = ctx("00000000-0000-0000-0000-0000globex000", "globex");
// Provision two tenants and write a doc into each so the databases
// actually materialize on the cluster (Mongo lazily creates DBs).
let acme_db = pool.for_tenant(&acme).await.expect("acme db");
let globex_db = pool.for_tenant(&globex).await.expect("globex db");
acme_db
.repositories()
.insert_one(fixture_repo("acme-app", "git@example.com:acme/app.git"))
.await
.expect("insert acme");
globex_db
.repositories()
.insert_one(fixture_repo("globex-app", "git@example.com:globex/app.git"))
.await
.expect("insert globex");
// list_tenant_db_names sees both, filtered by prefix
let names = pool.list_tenant_db_names().await.expect("list tenants");
let acme_name = pool.tenant_db_name(&acme.tenant_id);
let globex_name = pool.tenant_db_name(&globex.tenant_id);
assert!(
names.contains(&acme_name),
"expected {acme_name} in {names:?}"
);
assert!(
names.contains(&globex_name),
"expected {globex_name} in {names:?}"
);
for name in &names {
assert!(name.starts_with(&format!("{prefix}_")));
}
// drop_tenant removes acme's DB
pool.drop_tenant(&acme.tenant_id)
.await
.expect("drop acme tenant");
let after = pool
.list_tenant_db_names()
.await
.expect("list tenants after drop");
assert!(
!after.contains(&acme_name),
"acme should be gone after drop, got {after:?}"
);
assert!(
after.contains(&globex_name),
"globex should still be present, got {after:?}"
);
// Cleanup remaining
pool.drop_tenant(&globex.tenant_id)
.await
.expect("drop globex tenant");
}
#[tokio::test]
async fn tenant_db_name_falls_back_to_hash_when_too_long() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_long")
.await
.expect("connect");
// 100-byte tenant_id would overflow the 63-byte db-name cap with
// any reasonable prefix. The pool must hash it down.
let huge = "x".repeat(100);
let name = pool.tenant_db_name(&huge);
assert!(name.len() <= 63, "hashed name should fit: {name}");
assert!(name.starts_with("m72a_long_"));
// The hash suffix is 32 hex chars (16-byte SHA-256 truncation).
let suffix = name.trim_start_matches("m72a_long_");
assert_eq!(
suffix.len(),
32,
"expected 32-hex suffix (16-byte hash), got {suffix:?}"
);
assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
// Stable: same input → same output.
assert_eq!(name, pool.tenant_db_name(&huge));
// Different inputs → different outputs (collision check on a tiny
// sample — full birthday-resistance is a proof not a test).
let huge2 = "y".repeat(100);
assert_ne!(pool.tenant_db_name(&huge), pool.tenant_db_name(&huge2));
}
#[tokio::test]
async fn connect_rejects_overlong_db_prefix() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// MAX_PREFIX_LEN is 30 (= 63 - 1 - 32). A 31-char prefix MUST be
// rejected at construction so the hash-fallback path can never
// produce an over-long db name at runtime.
let too_long = "a".repeat(31);
let err = DatabasePool::connect(&uri, &too_long).await.unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("max is 30") || msg.contains(&too_long),
"error should explain the cap: {msg}"
);
// Exactly 30 chars is the inclusive bound — must succeed.
let just_right = "a".repeat(30);
let _ = DatabasePool::connect(&uri, &just_right)
.await
.expect("30-char prefix should be accepted");
}
/// Short UUID slug for keeping test prefixes well under Mongo's 63-byte
/// db-name cap.
fn short_id() -> String {
uuid::Uuid::new_v4().simple().to_string()[..8].to_string()
}
/// Drain a `repositories` find cursor on the given tenant database.
async fn collect(db: &compliance_agent::database::Database) -> Vec<TrackedRepository> {
let mut cursor = db
.repositories()
.find(doc! {})
.await
.expect("find repositories");
let mut out = Vec::new();
while cursor.advance().await.expect("advance") {
out.push(cursor.deserialize_current().expect("deserialize"));
}
out
}
@@ -1,122 +0,0 @@
//! M7.1 — integration tests for `compliance_core::auth::require_tenant_status`.
//!
//! Exercises the middleware end-to-end through an Axum router so we
//! catch wiring bugs (extension propagation, method matching) that pure
//! unit tests would miss.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
middleware::{from_fn, Next},
response::Response,
routing::{get, post},
Router,
};
use compliance_core::{auth::require_tenant_status, TenantContext, TenantStatus};
use tower::ServiceExt;
fn ctx_with(status: TenantStatus) -> TenantContext {
TenantContext {
tenant_id: "t-1".to_string(),
tenant_slug: "acme".to_string(),
org_roles: vec![],
products: vec![],
plan: "starter".to_string(),
status,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn router_with_ctx(ctx: Option<TenantContext>) -> Router {
let injector = move |mut req: Request, next: Next| {
let ctx = ctx.clone();
async move {
if let Some(c) = ctx {
req.extensions_mut().insert(c);
}
next.run(req).await
}
};
Router::new()
.route("/r", get(|| async { "read" }))
.route("/w", post(|| async { "write" }))
.layer(from_fn(require_tenant_status))
.layer(from_fn(injector))
}
async fn call(router: Router, method: Method, path: &str) -> Response {
let req = Request::builder()
.method(method)
.uri(path)
.body(Body::empty())
.expect("request build");
router.oneshot(req).await.expect("oneshot")
}
#[tokio::test]
async fn active_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Active)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn trial_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Trial)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn demo_tenant_can_read_and_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Demo)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
#[tokio::test]
async fn frozen_tenant_can_read_but_not_write() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Frozen)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(
call(r, Method::POST, "/w").await.status(),
StatusCode::PAYMENT_REQUIRED
);
}
#[tokio::test]
async fn archived_tenant_is_gone_on_every_method() {
let r = router_with_ctx(Some(ctx_with(TenantStatus::Archived)));
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::GONE
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::GONE);
}
#[tokio::test]
async fn no_context_passes_through() {
let r = router_with_ctx(None);
assert_eq!(
call(r.clone(), Method::GET, "/r").await.status(),
StatusCode::OK
);
assert_eq!(call(r, Method::POST, "/w").await.status(), StatusCode::OK);
}
-69
View File
@@ -1,69 +0,0 @@
//! Per-tenant API tokens used by `compliance-mcp` to authenticate MCP
//! HTTP requests on behalf of LLM clients (Claude Desktop, Cursor,
//! ChatGPT, etc.) that can't run a Keycloak OIDC flow.
//!
//! Tokens are opaque strings of the form `mcpt_<44 url-safe random
//! chars>`. The raw value is shown to the user exactly once at
//! creation; the database only ever sees the SHA-256 hash. Lookups go
//! through the cross-tenant `<prefix>__admin.mcp_tokens` collection
//! and return the `tenant_id` the MCP server should route to.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Persisted token metadata. `token_hash` is the SHA-256 hex of the
/// raw token; the raw token itself is never stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToken {
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
pub id: Option<bson::oid::ObjectId>,
/// SHA-256 hex of the raw token. Unique index in the collection.
pub token_hash: String,
/// First 8 chars of the raw token — purely for UI display so users
/// can identify which token is which without re-issuing.
pub token_prefix: String,
/// Routes to `<db_prefix>_<tenant_id>` on MCP requests.
pub tenant_id: String,
/// User-given label, e.g. "Claude Desktop" or "Sharang's laptop".
pub name: String,
/// Keycloak `sub` of the user who created this token, for audit.
pub created_by: String,
#[serde(with = "super::serde_helpers::bson_datetime")]
pub created_at: DateTime<Utc>,
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
pub last_used_at: Option<DateTime<Utc>>,
/// Soft-delete flag. A revoked token doc stays around for audit
/// but never authenticates.
#[serde(default)]
pub revoked: bool,
}
/// Public projection of a token — never includes the hash.
/// Returned by `GET /api/v1/mcp-tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTokenView {
pub id: String,
pub name: String,
/// `mcpt_xxxx…` so the user can identify which row is which.
pub token_prefix: String,
pub created_by: String,
#[serde(with = "super::serde_helpers::bson_datetime")]
pub created_at: DateTime<Utc>,
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
pub last_used_at: Option<DateTime<Utc>>,
pub revoked: bool,
}
impl From<&McpToken> for McpTokenView {
fn from(t: &McpToken) -> Self {
Self {
id: t.id.map(|o| o.to_hex()).unwrap_or_default(),
name: t.name.clone(),
token_prefix: t.token_prefix.clone(),
created_by: t.created_by.clone(),
created_at: t.created_at,
last_used_at: t.last_used_at,
revoked: t.revoked,
}
}
}
-2
View File
@@ -7,7 +7,6 @@ pub mod finding;
pub mod graph; pub mod graph;
pub mod issue; pub mod issue;
pub mod mcp; pub mod mcp;
pub mod mcp_token;
pub mod notification; pub mod notification;
pub mod pentest; pub mod pentest;
pub mod repository; pub mod repository;
@@ -29,7 +28,6 @@ pub use graph::{
}; };
pub use issue::{IssueStatus, TrackerIssue, TrackerType}; pub use issue::{IssueStatus, TrackerIssue, TrackerType};
pub use mcp::{McpServerConfig, McpServerStatus, McpTransport}; pub use mcp::{McpServerConfig, McpServerStatus, McpTransport};
pub use mcp_token::{McpToken, McpTokenView};
pub use notification::{CveNotification, NotificationSeverity, NotificationStatus}; pub use notification::{CveNotification, NotificationSeverity, NotificationStatus};
pub use pentest::{ pub use pentest::{
AttackChainNode, AttackNodeStatus, AuthMode, CodeContextHint, Environment, IdentityProvider, AttackChainNode, AttackNodeStatus, AuthMode, CodeContextHint, Environment, IdentityProvider,
@@ -1,210 +0,0 @@
//! Authenticated HTTP client for talking to the compliance-agent.
//!
//! Every dashboard server function that hits `comp-dev.meghsakha.com/api/v1/*`
//! must go through here so the Keycloak access token from the user's
//! session is attached as `Authorization: Bearer <token>`. Without it
//! the agent's M7.1 `require_jwt_auth` middleware rejects with 401
//! "Missing authorization header".
//!
//! When Keycloak is not configured (dev convenience), the helper
//! returns an unauthenticated builder — matching the agent's
//! pass-through behavior in the same state.
//!
//! **Token refresh**: KC access tokens are short-lived (5 min default
//! in the certifai realm). Before attaching, we decode the JWT's `exp`
//! claim and proactively refresh via the stored refresh_token if the
//! access token is expired or about to expire. The session is updated
//! with the new pair. If refresh fails, we send the (stale) token
//! anyway — the agent's 401 will surface to the UI, which can prompt
//! re-login.
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use dioxus::prelude::ServerFnError;
use dioxus_fullstack::FullstackContext;
use reqwest::Method;
use super::auth::LOGGED_IN_USER_SESS_KEY;
use super::server_state::ServerState;
use super::user_state::UserStateInner;
/// Seconds before the JWT's `exp` time at which we consider it stale
/// enough to refresh. Covers clock skew + the round-trip to the agent
/// so the token doesn't expire mid-flight.
const REFRESH_SKEW_SECS: i64 = 30;
/// Build a `RequestBuilder` for `<agent_api_url><path>` with the
/// session's access token attached. `path` should include a leading
/// `/`, e.g. `"/api/v1/repositories"`.
pub async fn agent_request(
method: Method,
path: &str,
) -> Result<reqwest::RequestBuilder, ServerFnError> {
let state: ServerState = FullstackContext::extract().await?;
let url = format!("{}{}", state.agent_api_url, path);
let mut req = reqwest::Client::new().request(method, &url);
req = attach_token(req, &state).await?;
Ok(req)
}
/// Same as [`agent_request`] but for `GET`. Convenience for the common case.
pub async fn agent_get(path: &str) -> Result<reqwest::RequestBuilder, ServerFnError> {
agent_request(Method::GET, path).await
}
/// Attach the session's bearer token if Keycloak is configured AND the
/// session has a logged-in user. Refresh the token proactively if it's
/// expired or about to expire. Persists refreshed tokens back into the
/// session.
async fn attach_token(
req: reqwest::RequestBuilder,
state: &ServerState,
) -> Result<reqwest::RequestBuilder, ServerFnError> {
if state.keycloak.is_none() {
return Ok(req);
}
let session: tower_sessions::Session = FullstackContext::extract().await?;
let user: Option<UserStateInner> = session
.get(LOGGED_IN_USER_SESS_KEY)
.await
.map_err(|e| ServerFnError::new(format!("session read failed: {e}")))?;
let Some(mut user) = user else {
return Ok(req);
};
if token_needs_refresh(&user.access_token) {
tracing::debug!("Access token expired or near-expiring; refreshing");
match refresh_tokens(state, &user.refresh_token).await {
Ok((new_access, new_refresh)) => {
user.access_token = new_access;
if let Some(rt) = new_refresh {
user.refresh_token = rt;
}
if let Err(e) = session.insert(LOGGED_IN_USER_SESS_KEY, &user).await {
tracing::warn!("Failed to persist refreshed tokens: {e}");
}
}
Err(e) => {
tracing::warn!("Token refresh failed: {e}; sending current token anyway");
// Fall through — the agent will 401 and the UI will
// prompt re-login. Better than failing the request at
// the dashboard layer with no helpful UX cue.
}
}
}
Ok(req.bearer_auth(user.access_token))
}
/// Decode the JWT's payload (no signature verification — the agent
/// does that) and check the `exp` claim. Treats malformed tokens as
/// expired so the refresh path runs.
fn token_needs_refresh(jwt: &str) -> bool {
let Some(payload_b64) = jwt.split('.').nth(1) else {
return true;
};
let Ok(bytes) = URL_SAFE_NO_PAD.decode(payload_b64) else {
return true;
};
#[derive(serde::Deserialize)]
struct ExpClaim {
exp: i64,
}
let Ok(claims) = serde_json::from_slice::<ExpClaim>(&bytes) else {
return true;
};
let now = chrono::Utc::now().timestamp();
claims.exp - REFRESH_SKEW_SECS <= now
}
/// Exchange a refresh_token for a new access_token. Returns the new
/// access_token and (optionally) the new refresh_token KC issued.
/// KC may rotate refresh_tokens on use; we honor whatever it sends.
async fn refresh_tokens(
state: &ServerState,
refresh_token: &str,
) -> Result<(String, Option<String>), String> {
let kc = state
.keycloak
.ok_or_else(|| "Keycloak not configured".to_string())?;
if refresh_token.is_empty() {
return Err("no refresh_token in session".to_string());
}
#[derive(serde::Deserialize)]
struct TokenResp {
access_token: String,
refresh_token: Option<String>,
}
let resp = reqwest::Client::new()
.post(kc.token_endpoint())
.form(&[
("grant_type", "refresh_token"),
("client_id", kc.client_id.as_str()),
("refresh_token", refresh_token),
])
.send()
.await
.map_err(|e| format!("refresh request failed: {e}"))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("refresh rejected ({status}): {body}"));
}
let r: TokenResp = resp
.json()
.await
.map_err(|e| format!("refresh response parse failed: {e}"))?;
Ok((r.access_token, r.refresh_token))
}
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
/// Build a JWT-shaped string (header.payload.sig) with the given
/// payload. Signature is bogus — we never verify it locally.
fn make_jwt(payload: &serde_json::Value) -> String {
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(payload).unwrap());
format!("hdr.{payload_b64}.sig")
}
#[test]
fn token_needs_refresh_true_when_expired() {
let exp = chrono::Utc::now().timestamp() - 60;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_true_within_skew_window() {
// 10 seconds left; less than the 30s skew → must refresh.
let exp = chrono::Utc::now().timestamp() + 10;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_false_with_plenty_of_life() {
let exp = chrono::Utc::now().timestamp() + 600;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(!token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_true_on_malformed_jwt() {
assert!(token_needs_refresh(""));
assert!(token_needs_refresh("not.a.jwt"));
assert!(token_needs_refresh("only-one-segment"));
assert!(token_needs_refresh("hdr.not-base64!.sig"));
}
#[test]
fn token_needs_refresh_true_when_exp_missing() {
let jwt = make_jwt(&serde_json::json!({ "sub": "abc" }));
assert!(token_needs_refresh(&jwt));
}
}
+35 -26
View File
@@ -61,21 +61,23 @@ pub async fn send_chat_message(
message: String, message: String,
history: Vec<ChatHistoryMessage>, history: Vec<ChatHistoryMessage>,
) -> Result<ChatApiResponse, ServerFnError> { ) -> Result<ChatApiResponse, ServerFnError> {
// Chat uses a longer timeout because the LLM round-trip can be slow; let state: super::server_state::ServerState =
// agent_request doesn't expose a per-call timeout so we layer one on. dioxus_fullstack::FullstackContext::extract().await?;
let resp = super::agent_client::agent_request(
reqwest::Method::POST, let url = format!("{}/api/v1/chat/{repo_id}", state.agent_api_url);
&format!("/api/v1/chat/{repo_id}"), let client = reqwest::Client::builder()
) .timeout(std::time::Duration::from_secs(120))
.await? .build()
.timeout(std::time::Duration::from_secs(120)) .map_err(|e| ServerFnError::new(e.to_string()))?;
.json(&serde_json::json!({ let resp = client
"message": message, .post(&url)
"history": history, .json(&serde_json::json!({
})) "message": message,
.send() "history": history,
.await }))
.map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?; .send()
.await
.map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?;
let text = resp let text = resp
.text() .text()
@@ -89,14 +91,19 @@ pub async fn send_chat_message(
#[server] #[server]
pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnError> { pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/chat/{repo_id}/build-embeddings"),
) let url = format!(
.await? "{}/api/v1/chat/{repo_id}/build-embeddings",
.send() state.agent_api_url
.await );
.map_err(|e| ServerFnError::new(e.to_string()))?; let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -104,9 +111,11 @@ pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnErro
pub async fn fetch_embedding_status( pub async fn fetch_embedding_status(
repo_id: String, repo_id: String,
) -> Result<EmbeddingStatusResponse, ServerFnError> { ) -> Result<EmbeddingStatusResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/chat/{repo_id}/status")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send()
let url = format!("{}/api/v1/chat/{repo_id}/status", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: EmbeddingStatusResponse = resp let body: EmbeddingStatusResponse = resp
+34 -22
View File
@@ -26,9 +26,10 @@ pub struct DastFindingDetailResponse {
#[server] #[server]
pub async fn fetch_dast_targets() -> Result<DastTargetsResponse, ServerFnError> { pub async fn fetch_dast_targets() -> Result<DastTargetsResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/dast/targets") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/dast/targets", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastTargetsResponse = resp let body: DastTargetsResponse = resp
@@ -40,9 +41,10 @@ pub async fn fetch_dast_targets() -> Result<DastTargetsResponse, ServerFnError>
#[server] #[server]
pub async fn fetch_dast_scan_runs() -> Result<DastScanRunsResponse, ServerFnError> { pub async fn fetch_dast_scan_runs() -> Result<DastScanRunsResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/dast/scan-runs") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/dast/scan-runs", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastScanRunsResponse = resp let body: DastScanRunsResponse = resp
@@ -54,9 +56,10 @@ pub async fn fetch_dast_scan_runs() -> Result<DastScanRunsResponse, ServerFnErro
#[server] #[server]
pub async fn fetch_dast_findings() -> Result<DastFindingsResponse, ServerFnError> { pub async fn fetch_dast_findings() -> Result<DastFindingsResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/dast/findings") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/dast/findings", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingsResponse = resp let body: DastFindingsResponse = resp
@@ -70,9 +73,10 @@ pub async fn fetch_dast_findings() -> Result<DastFindingsResponse, ServerFnError
pub async fn fetch_dast_finding_detail( pub async fn fetch_dast_finding_detail(
id: String, id: String,
) -> Result<DastFindingDetailResponse, ServerFnError> { ) -> Result<DastFindingDetailResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/dast/findings/{id}")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/dast/findings/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingDetailResponse = resp let body: DastFindingDetailResponse = resp
@@ -84,8 +88,12 @@ pub async fn fetch_dast_finding_detail(
#[server] #[server]
pub async fn add_dast_target(name: String, base_url: String) -> Result<(), ServerFnError> { pub async fn add_dast_target(name: String, base_url: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/dast/targets") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/targets", state.agent_api_url);
let client = reqwest::Client::new();
client
.post(&url)
.json(&serde_json::json!({ .json(&serde_json::json!({
"name": name, "name": name,
"base_url": base_url, "base_url": base_url,
@@ -98,13 +106,17 @@ pub async fn add_dast_target(name: String, base_url: String) -> Result<(), Serve
#[server] #[server]
pub async fn trigger_dast_scan(target_id: String) -> Result<(), ServerFnError> { pub async fn trigger_dast_scan(target_id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/dast/targets/{target_id}/scan"), let url = format!(
) "{}/api/v1/dast/targets/{target_id}/scan",
.await? state.agent_api_url
.send() );
.await let client = reqwest::Client::new();
.map_err(|e| ServerFnError::new(e.to_string()))?; client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -24,35 +24,39 @@ pub struct FindingsQuery {
#[server] #[server]
pub async fn fetch_findings(query: FindingsQuery) -> Result<FindingsListResponse, ServerFnError> { pub async fn fetch_findings(query: FindingsQuery) -> Result<FindingsListResponse, ServerFnError> {
let mut path = format!("/api/v1/findings?page={}&limit=20", query.page); let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut url = format!(
"{}/api/v1/findings?page={}&limit=20",
state.agent_api_url, query.page
);
if !query.severity.is_empty() { if !query.severity.is_empty() {
path.push_str(&format!("&severity={}", query.severity)); url.push_str(&format!("&severity={}", query.severity));
} }
if !query.scan_type.is_empty() { if !query.scan_type.is_empty() {
path.push_str(&format!("&scan_type={}", query.scan_type)); url.push_str(&format!("&scan_type={}", query.scan_type));
} }
if !query.status.is_empty() { if !query.status.is_empty() {
path.push_str(&format!("&status={}", query.status)); url.push_str(&format!("&status={}", query.status));
} }
if !query.repo_id.is_empty() { if !query.repo_id.is_empty() {
path.push_str(&format!("&repo_id={}", query.repo_id)); url.push_str(&format!("&repo_id={}", query.repo_id));
} }
if !query.q.is_empty() { if !query.q.is_empty() {
path.push_str(&format!( url.push_str(&format!(
"&q={}", "&q={}",
url::form_urlencoded::byte_serialize(query.q.as_bytes()).collect::<String>() url::form_urlencoded::byte_serialize(query.q.as_bytes()).collect::<String>()
)); ));
} }
if !query.sort_by.is_empty() { if !query.sort_by.is_empty() {
path.push_str(&format!("&sort_by={}", query.sort_by)); url.push_str(&format!("&sort_by={}", query.sort_by));
} }
if !query.sort_order.is_empty() { if !query.sort_order.is_empty() {
path.push_str(&format!("&sort_order={}", query.sort_order)); url.push_str(&format!("&sort_order={}", query.sort_order));
} }
let resp = super::agent_client::agent_get(&path) let resp = reqwest::get(&url)
.await?
.send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: FindingsListResponse = resp let body: FindingsListResponse = resp
@@ -64,9 +68,11 @@ pub async fn fetch_findings(query: FindingsQuery) -> Result<FindingsListResponse
#[server] #[server]
pub async fn fetch_finding_detail(id: String) -> Result<Finding, ServerFnError> { pub async fn fetch_finding_detail(id: String) -> Result<Finding, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/findings/{id}")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/findings/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp let body: serde_json::Value = resp
@@ -80,15 +86,18 @@ pub async fn fetch_finding_detail(id: String) -> Result<Finding, ServerFnError>
#[server] #[server]
pub async fn update_finding_status(id: String, status: String) -> Result<(), ServerFnError> { pub async fn update_finding_status(id: String, status: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::PATCH, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/findings/{id}/status"), let url = format!("{}/api/v1/findings/{id}/status", state.agent_api_url);
)
.await? let client = reqwest::Client::new();
.json(&serde_json::json!({ "status": status })) client
.send() .patch(&url)
.await .json(&serde_json::json!({ "status": status }))
.map_err(|e| ServerFnError::new(e.to_string()))?; .send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -97,25 +106,34 @@ pub async fn bulk_update_finding_status(
ids: Vec<String>, ids: Vec<String>,
status: String, status: String,
) -> Result<(), ServerFnError> { ) -> Result<(), ServerFnError> {
super::agent_client::agent_request(reqwest::Method::PATCH, "/api/v1/findings/bulk-status") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/findings/bulk-status", state.agent_api_url);
let client = reqwest::Client::new();
client
.patch(&url)
.json(&serde_json::json!({ "ids": ids, "status": status })) .json(&serde_json::json!({ "ids": ids, "status": status }))
.send() .send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
#[server] #[server]
pub async fn update_finding_feedback(id: String, feedback: String) -> Result<(), ServerFnError> { pub async fn update_finding_feedback(id: String, feedback: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::PATCH, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/findings/{id}/feedback"), let url = format!("{}/api/v1/findings/{id}/feedback", state.agent_api_url);
)
.await? let client = reqwest::Client::new();
.json(&serde_json::json!({ "feedback": feedback })) client
.send() .patch(&url)
.await .json(&serde_json::json!({ "feedback": feedback }))
.map_err(|e| ServerFnError::new(e.to_string()))?; .send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -50,9 +50,10 @@ pub struct SearchResponse {
#[server] #[server]
pub async fn fetch_graph(repo_id: String) -> Result<GraphDataResponse, ServerFnError> { pub async fn fetch_graph(repo_id: String) -> Result<GraphDataResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/graph/{repo_id}", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: GraphDataResponse = resp let body: GraphDataResponse = resp
@@ -67,12 +68,15 @@ pub async fn fetch_impact(
repo_id: String, repo_id: String,
finding_id: String, finding_id: String,
) -> Result<ImpactResponse, ServerFnError> { ) -> Result<ImpactResponse, ServerFnError> {
let resp = let state: super::server_state::ServerState =
super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}/impact/{finding_id}")) dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!(
.send() "{}/api/v1/graph/{repo_id}/impact/{finding_id}",
.await state.agent_api_url
.map_err(|e| ServerFnError::new(e.to_string()))?; );
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: ImpactResponse = resp let body: ImpactResponse = resp
.json() .json()
.await .await
@@ -82,9 +86,10 @@ pub async fn fetch_impact(
#[server] #[server]
pub async fn fetch_communities(repo_id: String) -> Result<CommunitiesResponse, ServerFnError> { pub async fn fetch_communities(repo_id: String) -> Result<CommunitiesResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}/communities")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/graph/{repo_id}/communities", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: CommunitiesResponse = resp let body: CommunitiesResponse = resp
@@ -99,13 +104,15 @@ pub async fn fetch_file_content(
repo_id: String, repo_id: String,
file_path: String, file_path: String,
) -> Result<FileContentResponse, ServerFnError> { ) -> Result<FileContentResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!( let state: super::server_state::ServerState =
"/api/v1/graph/{repo_id}/file-content?path={file_path}" dioxus_fullstack::FullstackContext::extract().await?;
)) let url = format!(
.await? "{}/api/v1/graph/{repo_id}/file-content?path={file_path}",
.send() state.agent_api_url
.await );
.map_err(|e| ServerFnError::new(e.to_string()))?; let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: FileContentResponse = resp let body: FileContentResponse = resp
.json() .json()
.await .await
@@ -115,13 +122,15 @@ pub async fn fetch_file_content(
#[server] #[server]
pub async fn search_nodes(repo_id: String, query: String) -> Result<SearchResponse, ServerFnError> { pub async fn search_nodes(repo_id: String, query: String) -> Result<SearchResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!( let state: super::server_state::ServerState =
"/api/v1/graph/{repo_id}/search?q={query}&limit=50" dioxus_fullstack::FullstackContext::extract().await?;
)) let url = format!(
.await? "{}/api/v1/graph/{repo_id}/search?q={query}&limit=50",
.send() state.agent_api_url
.await );
.map_err(|e| ServerFnError::new(e.to_string()))?; let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: SearchResponse = resp let body: SearchResponse = resp
.json() .json()
.await .await
@@ -131,13 +140,14 @@ pub async fn search_nodes(repo_id: String, query: String) -> Result<SearchRespon
#[server] #[server]
pub async fn trigger_graph_build(repo_id: String) -> Result<(), ServerFnError> { pub async fn trigger_graph_build(repo_id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/graph/{repo_id}/build"), let url = format!("{}/api/v1/graph/{repo_id}/build", state.agent_api_url);
) let client = reqwest::Client::new();
.await? client
.send() .post(&url)
.await .send()
.map_err(|e| ServerFnError::new(e.to_string()))?; .await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -12,9 +12,11 @@ pub struct IssuesListResponse {
#[server] #[server]
pub async fn fetch_issues(page: u64) -> Result<IssuesListResponse, ServerFnError> { pub async fn fetch_issues(page: u64) -> Result<IssuesListResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/issues?page={page}&limit=20")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/issues?page={page}&limit=20", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: IssuesListResponse = resp let body: IssuesListResponse = resp
@@ -18,8 +18,6 @@ pub mod stats;
// Server-only modules // Server-only modules
#[cfg(feature = "server")] #[cfg(feature = "server")]
mod agent_client;
#[cfg(feature = "server")]
mod auth; mod auth;
#[cfg(feature = "server")] #[cfg(feature = "server")]
mod auth_middleware; mod auth_middleware;
@@ -32,9 +32,11 @@ pub struct NotificationCountResponse {
#[server] #[server]
pub async fn fetch_notification_count() -> Result<u64, ServerFnError> { pub async fn fetch_notification_count() -> Result<u64, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/notifications/count") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send()
let url = format!("{}/api/v1/notifications/count", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: NotificationCountResponse = resp let body: NotificationCountResponse = resp
@@ -46,9 +48,11 @@ pub async fn fetch_notification_count() -> Result<u64, ServerFnError> {
#[server] #[server]
pub async fn fetch_notifications() -> Result<NotificationListResponse, ServerFnError> { pub async fn fetch_notifications() -> Result<NotificationListResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/notifications?limit=20") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send()
let url = format!("{}/api/v1/notifications?limit=20", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: NotificationListResponse = resp let body: NotificationListResponse = resp
@@ -60,8 +64,12 @@ pub async fn fetch_notifications() -> Result<NotificationListResponse, ServerFnE
#[server] #[server]
pub async fn mark_all_notifications_read() -> Result<(), ServerFnError> { pub async fn mark_all_notifications_read() -> Result<(), ServerFnError> {
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/notifications/read-all") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/notifications/read-all", state.agent_api_url);
reqwest::Client::new()
.post(&url)
.send() .send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
@@ -70,13 +78,14 @@ pub async fn mark_all_notifications_read() -> Result<(), ServerFnError> {
#[server] #[server]
pub async fn dismiss_notification(id: String) -> Result<(), ServerFnError> { pub async fn dismiss_notification(id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::PATCH, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/notifications/{id}/dismiss"),
) let url = format!("{}/api/v1/notifications/{id}/dismiss", state.agent_api_url);
.await? reqwest::Client::new()
.send() .patch(&url)
.await .send()
.map_err(|e| ServerFnError::new(e.to_string()))?; .await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
+184 -145
View File
@@ -32,10 +32,12 @@ pub struct AttackChainResponse {
#[server] #[server]
pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerFnError> { pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
// Fetch sessions // Fetch sessions
let resp = super::agent_client::agent_get("/api/v1/pentest/sessions") let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
.await? let resp = reqwest::get(&url)
.send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let mut body: PentestSessionsResponse = resp let mut body: PentestSessionsResponse = resp
@@ -44,32 +46,31 @@ pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerF
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
// Fetch DAST targets to resolve target names // Fetch DAST targets to resolve target names
if let Ok(tresp_builder) = super::agent_client::agent_get("/api/v1/dast/targets").await { let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
if let Ok(tresp) = tresp_builder.send().await { if let Ok(tresp) = reqwest::get(&targets_url).await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await { if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
let targets = tbody.get("data").and_then(|v| v.as_array()); let targets = tbody.get("data").and_then(|v| v.as_array());
if let Some(targets) = targets { if let Some(targets) = targets {
// Build target_id -> name lookup // Build target_id -> name lookup
let target_map: std::collections::HashMap<String, String> = targets let target_map: std::collections::HashMap<String, String> = targets
.iter() .iter()
.filter_map(|t| { .filter_map(|t| {
let id = t.get("_id")?.get("$oid")?.as_str()?.to_string(); let id = t.get("_id")?.get("$oid")?.as_str()?.to_string();
let name = t.get("name")?.as_str()?.to_string(); let name = t.get("name")?.as_str()?.to_string();
Some((id, name)) Some((id, name))
}) })
.collect(); .collect();
// Enrich sessions with target_name // Enrich sessions with target_name
for session in body.data.iter_mut() { for session in body.data.iter_mut() {
if let Some(tid) = session.get("target_id").and_then(|v| v.as_str()) { if let Some(tid) = session.get("target_id").and_then(|v| v.as_str()) {
if let Some(name) = target_map.get(tid) { if let Some(name) = target_map.get(tid) {
session.as_object_mut().map(|obj| { session.as_object_mut().map(|obj| {
obj.insert( obj.insert(
"target_name".to_string(), "target_name".to_string(),
serde_json::Value::String(name.clone()), serde_json::Value::String(name.clone()),
) )
}); });
}
} }
} }
} }
@@ -82,9 +83,10 @@ pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerF
#[server] #[server]
pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse, ServerFnError> { pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{id}")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/pentest/sessions/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let mut body: PentestSessionResponse = resp let mut body: PentestSessionResponse = resp
@@ -94,27 +96,26 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
// Resolve target name from targets list // Resolve target name from targets list
if let Some(tid) = body.data.get("target_id").and_then(|v| v.as_str()) { if let Some(tid) = body.data.get("target_id").and_then(|v| v.as_str()) {
if let Ok(tresp_builder) = super::agent_client::agent_get("/api/v1/dast/targets").await { let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
if let Ok(tresp) = tresp_builder.send().await { if let Ok(tresp) = reqwest::get(&targets_url).await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await { if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) { if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
for t in targets { for t in targets {
let t_id = t let t_id = t
.get("_id") .get("_id")
.and_then(|v| v.get("$oid")) .and_then(|v| v.get("$oid"))
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or(""); .unwrap_or("");
if t_id == tid { if t_id == tid {
if let Some(name) = t.get("name").and_then(|v| v.as_str()) { if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
body.data.as_object_mut().map(|obj| { body.data.as_object_mut().map(|obj| {
obj.insert( obj.insert(
"target_name".to_string(), "target_name".to_string(),
serde_json::Value::String(name.to_string()), serde_json::Value::String(name.to_string()),
) )
}); });
}
break;
} }
break;
} }
} }
} }
@@ -129,12 +130,15 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
pub async fn fetch_pentest_messages( pub async fn fetch_pentest_messages(
session_id: String, session_id: String,
) -> Result<PentestMessagesResponse, ServerFnError> { ) -> Result<PentestMessagesResponse, ServerFnError> {
let resp = let state: super::server_state::ServerState =
super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{session_id}/messages")) dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!(
.send() "{}/api/v1/pentest/sessions/{session_id}/messages",
.await state.agent_api_url
.map_err(|e| ServerFnError::new(e.to_string()))?; );
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp let body: PentestMessagesResponse = resp
.json() .json()
.await .await
@@ -144,9 +148,10 @@ pub async fn fetch_pentest_messages(
#[server] #[server]
pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError> { pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/pentest/stats") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/pentest/stats", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestStatsResponse = resp let body: PentestStatsResponse = resp
@@ -158,13 +163,15 @@ pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError
#[server] #[server]
pub async fn fetch_attack_chain(session_id: String) -> Result<AttackChainResponse, ServerFnError> { pub async fn fetch_attack_chain(session_id: String) -> Result<AttackChainResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!( let state: super::server_state::ServerState =
"/api/v1/pentest/sessions/{session_id}/attack-chain" dioxus_fullstack::FullstackContext::extract().await?;
)) let url = format!(
.await? "{}/api/v1/pentest/sessions/{session_id}/attack-chain",
.send() state.agent_api_url
.await );
.map_err(|e| ServerFnError::new(e.to_string()))?; let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: AttackChainResponse = resp let body: AttackChainResponse = resp
.json() .json()
.await .await
@@ -178,17 +185,20 @@ pub async fn create_pentest_session(
strategy: String, strategy: String,
message: String, message: String,
) -> Result<PentestSessionResponse, ServerFnError> { ) -> Result<PentestSessionResponse, ServerFnError> {
let resp = let state: super::server_state::ServerState =
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/pentest/sessions") dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
.json(&serde_json::json!({ let client = reqwest::Client::new();
"target_id": target_id, let resp = client
"strategy": strategy, .post(&url)
"message": message, .json(&serde_json::json!({
})) "target_id": target_id,
.send() "strategy": strategy,
.await "message": message,
.map_err(|e| ServerFnError::new(e.to_string()))?; }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestSessionResponse = resp let body: PentestSessionResponse = resp
.json() .json()
.await .await
@@ -201,15 +211,18 @@ pub async fn create_pentest_session(
pub async fn create_pentest_session_wizard( pub async fn create_pentest_session_wizard(
config_json: String, config_json: String,
) -> Result<PentestSessionResponse, ServerFnError> { ) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let config: serde_json::Value = let config: serde_json::Value =
serde_json::from_str(&config_json).map_err(|e| ServerFnError::new(e.to_string()))?; serde_json::from_str(&config_json).map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = let client = reqwest::Client::new();
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/pentest/sessions") let resp = client
.await? .post(&url)
.json(&serde_json::json!({ "config": config })) .json(&serde_json::json!({ "config": config }))
.send() .send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!( return Err(ServerFnError::new(format!(
@@ -226,6 +239,8 @@ pub async fn create_pentest_session_wizard(
/// Look up a tracked repository by its git URL /// Look up a tracked repository by its git URL
#[server] #[server]
pub async fn lookup_repo_by_url(url: String) -> Result<serde_json::Value, ServerFnError> { pub async fn lookup_repo_by_url(url: String) -> Result<serde_json::Value, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let encoded_url: String = url let encoded_url: String = url
.bytes() .bytes()
.flat_map(|b| { .flat_map(|b| {
@@ -236,12 +251,13 @@ pub async fn lookup_repo_by_url(url: String) -> Result<serde_json::Value, Server
} }
}) })
.collect(); .collect();
let resp = let api_url = format!(
super::agent_client::agent_get(&format!("/api/v1/pentest/lookup-repo?url={encoded_url}")) "{}/api/v1/pentest/lookup-repo?url={}",
.await? state.agent_api_url, encoded_url
.send() );
.await let resp = reqwest::get(&api_url)
.map_err(|e| ServerFnError::new(e.to_string()))?; .await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp let body: serde_json::Value = resp
.json() .json()
.await .await
@@ -254,17 +270,21 @@ pub async fn send_pentest_message(
session_id: String, session_id: String,
message: String, message: String,
) -> Result<PentestMessagesResponse, ServerFnError> { ) -> Result<PentestMessagesResponse, ServerFnError> {
let resp = super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/pentest/sessions/{session_id}/chat"), let url = format!(
) "{}/api/v1/pentest/sessions/{session_id}/chat",
.await? state.agent_api_url
.json(&serde_json::json!({ );
"message": message, let client = reqwest::Client::new();
})) let resp = client
.send() .post(&url)
.await .json(&serde_json::json!({
.map_err(|e| ServerFnError::new(e.to_string()))?; "message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp let body: PentestMessagesResponse = resp
.json() .json()
.await .await
@@ -274,27 +294,35 @@ pub async fn send_pentest_message(
#[server] #[server]
pub async fn stop_pentest_session(session_id: String) -> Result<(), ServerFnError> { pub async fn stop_pentest_session(session_id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/pentest/sessions/{session_id}/stop"), let url = format!(
) "{}/api/v1/pentest/sessions/{session_id}/stop",
.await? state.agent_api_url
.send() );
.await let client = reqwest::Client::new();
.map_err(|e| ServerFnError::new(e.to_string()))?; client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
#[server] #[server]
pub async fn pause_pentest_session(session_id: String) -> Result<(), ServerFnError> { pub async fn pause_pentest_session(session_id: String) -> Result<(), ServerFnError> {
let resp = super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/pentest/sessions/{session_id}/pause"), let url = format!(
) "{}/api/v1/pentest/sessions/{session_id}/pause",
.await? state.agent_api_url
.send() );
.await let client = reqwest::Client::new();
.map_err(|e| ServerFnError::new(e.to_string()))?; let resp = client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Pause failed: {text}"))); return Err(ServerFnError::new(format!("Pause failed: {text}")));
@@ -304,14 +332,18 @@ pub async fn pause_pentest_session(session_id: String) -> Result<(), ServerFnErr
#[server] #[server]
pub async fn resume_pentest_session(session_id: String) -> Result<(), ServerFnError> { pub async fn resume_pentest_session(session_id: String) -> Result<(), ServerFnError> {
let resp = super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/pentest/sessions/{session_id}/resume"), let url = format!(
) "{}/api/v1/pentest/sessions/{session_id}/resume",
.await? state.agent_api_url
.send() );
.await let client = reqwest::Client::new();
.map_err(|e| ServerFnError::new(e.to_string()))?; let resp = client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Resume failed: {text}"))); return Err(ServerFnError::new(format!("Resume failed: {text}")));
@@ -323,12 +355,15 @@ pub async fn resume_pentest_session(session_id: String) -> Result<(), ServerFnEr
pub async fn fetch_pentest_findings( pub async fn fetch_pentest_findings(
session_id: String, session_id: String,
) -> Result<DastFindingsResponse, ServerFnError> { ) -> Result<DastFindingsResponse, ServerFnError> {
let resp = let state: super::server_state::ServerState =
super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{session_id}/findings")) dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!(
.send() "{}/api/v1/pentest/sessions/{session_id}/findings",
.await state.agent_api_url
.map_err(|e| ServerFnError::new(e.to_string()))?; );
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingsResponse = resp let body: DastFindingsResponse = resp
.json() .json()
.await .await
@@ -350,19 +385,23 @@ pub async fn export_pentest_report(
requester_name: String, requester_name: String,
requester_email: String, requester_email: String,
) -> Result<ExportReportResponse, ServerFnError> { ) -> Result<ExportReportResponse, ServerFnError> {
let resp = super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/pentest/sessions/{session_id}/export"), let url = format!(
) "{}/api/v1/pentest/sessions/{session_id}/export",
.await? state.agent_api_url
.json(&serde_json::json!({ );
"password": password, let client = reqwest::Client::new();
"requester_name": requester_name, let resp = client
"requester_email": requester_email, .post(&url)
})) .json(&serde_json::json!({
.send() "password": password,
.await "requester_name": requester_name,
.map_err(|e| ServerFnError::new(e.to_string()))?; "requester_email": requester_email,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Export failed: {text}"))); return Err(ServerFnError::new(format!("Export failed: {text}")));
@@ -12,10 +12,14 @@ pub struct RepositoryListResponse {
#[server] #[server]
pub async fn fetch_repositories(page: u64) -> Result<RepositoryListResponse, ServerFnError> { pub async fn fetch_repositories(page: u64) -> Result<RepositoryListResponse, ServerFnError> {
let path = format!("/api/v1/repositories?page={page}&limit=20"); let state: super::server_state::ServerState =
let resp = super::agent_client::agent_get(&path) dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!(
.send() "{}/api/v1/repositories?page={page}&limit=20",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: RepositoryListResponse = resp let body: RepositoryListResponse = resp
@@ -37,6 +41,10 @@ pub async fn add_repository(
tracker_repo: Option<String>, tracker_repo: Option<String>,
tracker_token: Option<String>, tracker_token: Option<String>,
) -> Result<(), ServerFnError> { ) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories", state.agent_api_url);
let mut body = serde_json::json!({ let mut body = serde_json::json!({
"name": name, "name": name,
"git_url": git_url, "git_url": git_url,
@@ -61,8 +69,9 @@ pub async fn add_repository(
body["tracker_token"] = serde_json::Value::String(tk); body["tracker_token"] = serde_json::Value::String(tk);
} }
let resp = super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/repositories") let client = reqwest::Client::new();
.await? let resp = client
.post(&url)
.json(&body) .json(&body)
.send() .send()
.await .await
@@ -91,6 +100,10 @@ pub async fn update_repository(
tracker_token: Option<String>, tracker_token: Option<String>,
scan_schedule: Option<String>, scan_schedule: Option<String>,
) -> Result<(), ServerFnError> { ) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories/{repo_id}", state.agent_api_url);
let mut body = serde_json::Map::new(); let mut body = serde_json::Map::new();
if let Some(v) = name.filter(|s| !s.is_empty()) { if let Some(v) = name.filter(|s| !s.is_empty()) {
body.insert("name".into(), serde_json::Value::String(v)); body.insert("name".into(), serde_json::Value::String(v));
@@ -120,15 +133,13 @@ pub async fn update_repository(
body.insert("scan_schedule".into(), serde_json::Value::String(v)); body.insert("scan_schedule".into(), serde_json::Value::String(v));
} }
let resp = super::agent_client::agent_request( let client = reqwest::Client::new();
reqwest::Method::PATCH, let resp = client
&format!("/api/v1/repositories/{repo_id}"), .patch(&url)
) .json(&body)
.await? .send()
.json(&body) .await
.send() .map_err(|e| ServerFnError::new(e.to_string()))?;
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default(); let text = resp.text().await.unwrap_or_default();
@@ -142,9 +153,11 @@ pub async fn update_repository(
#[server] #[server]
pub async fn fetch_ssh_public_key() -> Result<String, ServerFnError> { pub async fn fetch_ssh_public_key() -> Result<String, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/settings/ssh-public-key") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/settings/ssh-public-key", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
@@ -166,14 +179,16 @@ pub async fn fetch_ssh_public_key() -> Result<String, ServerFnError> {
#[server] #[server]
pub async fn delete_repository(repo_id: String) -> Result<(), ServerFnError> { pub async fn delete_repository(repo_id: String) -> Result<(), ServerFnError> {
let resp = super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::DELETE, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/repositories/{repo_id}"), let url = format!("{}/api/v1/repositories/{repo_id}", state.agent_api_url);
)
.await? let client = reqwest::Client::new();
.send() let resp = client
.await .delete(&url)
.map_err(|e| ServerFnError::new(e.to_string()))?; .send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default(); let body = resp.text().await.unwrap_or_default();
@@ -187,14 +202,16 @@ pub async fn delete_repository(repo_id: String) -> Result<(), ServerFnError> {
#[server] #[server]
pub async fn trigger_repo_scan(repo_id: String) -> Result<(), ServerFnError> { pub async fn trigger_repo_scan(repo_id: String) -> Result<(), ServerFnError> {
super::agent_client::agent_request( let state: super::server_state::ServerState =
reqwest::Method::POST, dioxus_fullstack::FullstackContext::extract().await?;
&format!("/api/v1/repositories/{repo_id}/scan"), let url = format!("{}/api/v1/repositories/{repo_id}/scan", state.agent_api_url);
)
.await? let client = reqwest::Client::new();
.send() client
.await .post(&url)
.map_err(|e| ServerFnError::new(e.to_string()))?; .send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(()) Ok(())
} }
@@ -207,12 +224,16 @@ pub struct WebhookConfigResponse {
#[server] #[server]
pub async fn fetch_webhook_config(repo_id: String) -> Result<WebhookConfigResponse, ServerFnError> { pub async fn fetch_webhook_config(repo_id: String) -> Result<WebhookConfigResponse, ServerFnError> {
let resp = let state: super::server_state::ServerState =
super::agent_client::agent_get(&format!("/api/v1/repositories/{repo_id}/webhook-config")) dioxus_fullstack::FullstackContext::extract().await?;
.await? let url = format!(
.send() "{}/api/v1/repositories/{repo_id}/webhook-config",
.await state.agent_api_url
.map_err(|e| ServerFnError::new(e.to_string()))?; );
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: WebhookConfigResponse = resp let body: WebhookConfigResponse = resp
.json() .json()
.await .await
@@ -223,9 +244,11 @@ pub async fn fetch_webhook_config(repo_id: String) -> Result<WebhookConfigRespon
/// Check if a repository has any running scans /// Check if a repository has any running scans
#[server] #[server]
pub async fn check_repo_scanning(repo_id: String) -> Result<bool, ServerFnError> { pub async fn check_repo_scanning(repo_id: String) -> Result<bool, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/scan-runs?page=1&limit=1") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/scan-runs?page=1&limit=1", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp let body: serde_json::Value = resp
+35 -20
View File
@@ -87,9 +87,11 @@ pub struct SbomFiltersResponse {
#[server] #[server]
pub async fn fetch_sbom_filters() -> Result<SbomFiltersResponse, ServerFnError> { pub async fn fetch_sbom_filters() -> Result<SbomFiltersResponse, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/sbom/filters") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send()
let url = format!("{}/api/v1/sbom/filters", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp let text = resp
@@ -110,6 +112,9 @@ pub async fn fetch_sbom_filtered(
license: Option<String>, license: Option<String>,
page: u64, page: u64,
) -> Result<SbomListResponse, ServerFnError> { ) -> Result<SbomListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut params = vec![format!("page={page}"), "limit=50".to_string()]; let mut params = vec![format!("page={page}"), "limit=50".to_string()];
if let Some(r) = &repo_id { if let Some(r) = &repo_id {
if !r.is_empty() { if !r.is_empty() {
@@ -135,10 +140,9 @@ pub async fn fetch_sbom_filtered(
} }
} }
let path = format!("/api/v1/sbom?{}", params.join("&")); let url = format!("{}/api/v1/sbom?{}", state.agent_api_url, params.join("&"));
let resp = super::agent_client::agent_get(&path)
.await? let resp = reqwest::get(&url)
.send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp let text = resp
@@ -152,10 +156,15 @@ pub async fn fetch_sbom_filtered(
#[server] #[server]
pub async fn fetch_sbom_export(repo_id: String, format: String) -> Result<String, ServerFnError> { pub async fn fetch_sbom_export(repo_id: String, format: String) -> Result<String, ServerFnError> {
let path = format!("/api/v1/sbom/export?repo_id={repo_id}&format={format}"); let state: super::server_state::ServerState =
let resp = super::agent_client::agent_get(&path) dioxus_fullstack::FullstackContext::extract().await?;
.await?
.send() let url = format!(
"{}/api/v1/sbom/export?repo_id={}&format={}",
state.agent_api_url, repo_id, format
);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp let text = resp
@@ -169,16 +178,17 @@ pub async fn fetch_sbom_export(repo_id: String, format: String) -> Result<String
pub async fn fetch_license_summary( pub async fn fetch_license_summary(
repo_id: Option<String>, repo_id: Option<String>,
) -> Result<LicenseSummaryResponse, ServerFnError> { ) -> Result<LicenseSummaryResponse, ServerFnError> {
let mut path = "/api/v1/sbom/licenses".to_string(); let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut url = format!("{}/api/v1/sbom/licenses", state.agent_api_url);
if let Some(r) = &repo_id { if let Some(r) = &repo_id {
if !r.is_empty() { if !r.is_empty() {
path = format!("{path}?repo_id={r}"); url = format!("{url}?repo_id={r}");
} }
} }
let resp = super::agent_client::agent_get(&path) let resp = reqwest::get(&url)
.await?
.send()
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp let text = resp
@@ -195,10 +205,15 @@ pub async fn fetch_sbom_diff(
repo_a: String, repo_a: String,
repo_b: String, repo_b: String,
) -> Result<SbomDiffResponse, ServerFnError> { ) -> Result<SbomDiffResponse, ServerFnError> {
let path = format!("/api/v1/sbom/diff?repo_a={repo_a}&repo_b={repo_b}"); let state: super::server_state::ServerState =
let resp = super::agent_client::agent_get(&path) dioxus_fullstack::FullstackContext::extract().await?;
.await?
.send() let url = format!(
"{}/api/v1/sbom/diff?repo_a={}&repo_b={}",
state.agent_api_url, repo_a, repo_b
);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp let text = resp
@@ -12,9 +12,14 @@ pub struct ScansListResponse {
#[server] #[server]
pub async fn fetch_scan_runs(page: u64) -> Result<ScansListResponse, ServerFnError> { pub async fn fetch_scan_runs(page: u64) -> Result<ScansListResponse, ServerFnError> {
let resp = super::agent_client::agent_get(&format!("/api/v1/scan-runs?page={page}&limit=20")) let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!(
"{}/api/v1/scan-runs?page={page}&limit=20",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: ScansListResponse = resp let body: ScansListResponse = resp
@@ -16,9 +16,11 @@ pub struct OverviewStats {
#[server] #[server]
pub async fn fetch_overview_stats() -> Result<OverviewStats, ServerFnError> { pub async fn fetch_overview_stats() -> Result<OverviewStats, ServerFnError> {
let resp = super::agent_client::agent_get("/api/v1/stats/overview") let state: super::server_state::ServerState =
.await? dioxus_fullstack::FullstackContext::extract().await?;
.send() let url = format!("{}/api/v1/stats/overview", state.agent_api_url);
let resp = reqwest::get(&url)
.await .await
.map_err(|e| ServerFnError::new(e.to_string()))?; .map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp let body: serde_json::Value = resp
+1 -4
View File
@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
compliance-core = { workspace = true, features = ["mongodb", "axum"] } compliance-core = { workspace = true, features = ["mongodb"] }
rmcp = { version = "0.16", features = ["server", "macros", "transport-io", "transport-streamable-http-server"] } rmcp = { version = "0.16", features = ["server", "macros", "transport-io", "transport-streamable-http-server"] }
tokio = { workspace = true } tokio = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
@@ -19,6 +19,3 @@ bson = { version = "2", features = ["chrono-0_4"] }
schemars = "1.0" schemars = "1.0"
axum = "0.8" axum = "0.8"
tower-http = { version = "0.6", features = ["cors"] } tower-http = { version = "0.6", features = ["cors"] }
sha2 = { workspace = true }
hex = { workspace = true }
dashmap = { workspace = true }
-129
View File
@@ -1,129 +0,0 @@
//! Bearer-token authentication for incoming MCP HTTP requests.
//!
//! LLM clients (Claude Desktop / Cursor / ChatGPT / etc.) can't run
//! Keycloak OIDC, so the MCP server uses opaque static tokens minted
//! per-tenant via the agent's `POST /api/v1/mcp-tokens` endpoint.
//!
//! Flow per request:
//! 1. Extract `Authorization: Bearer <token>`. Missing → 401.
//! 2. SHA-256 hash the token.
//! 3. Look up the hash in `<prefix>__admin.mcp_tokens`. Missing or
//! revoked → 401.
//! 4. Fire-and-forget update of `last_used_at` so the dashboard can
//! show staleness without blocking the handler.
//! 5. Stash the tenant_id in [`TENANT_ID`] (a `tokio::task_local`) so
//! the MCP tool handlers can read it without modifying rmcp's
//! handler signatures.
//!
//! The `task_local` is scoped around the inner service call via
//! [`bearer_auth`], so every handler invoked downstream sees the
//! tenant_id without us having to thread it through the macro-
//! generated tool router.
use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use mongodb::bson::doc;
use sha2::{Digest, Sha256};
use crate::database::DatabasePool;
tokio::task_local! {
/// Tenant id resolved from the bearer for this request. Set by
/// [`bearer_auth`] before the inner service runs; read by the
/// MCP tool handlers via [`current_tenant_id`].
pub static TENANT_ID: String;
}
/// Mongo collection name in `<prefix>__admin`.
const COLLECTION: &str = "mcp_tokens";
/// Returns the tenant_id set by the auth middleware. `None` outside a
/// request scope (e.g. unit tests that bypass the middleware).
pub fn current_tenant_id() -> Option<String> {
TENANT_ID.try_with(|s| s.clone()).ok()
}
/// Axum middleware: validate bearer → set [`TENANT_ID`] → call inner.
pub async fn bearer_auth(
State(pool): State<DatabasePool>,
request: Request,
next: Next,
) -> Response {
let Some(token) = extract_bearer(&request) else {
return (StatusCode::UNAUTHORIZED, "Missing bearer token").into_response();
};
if !token.starts_with("mcpt_") {
return (StatusCode::UNAUTHORIZED, "Invalid token format").into_response();
}
let token_hash = sha256_hex(&token);
let col = pool.admin_db().collection::<TokenLookup>(COLLECTION);
let found = match col
.find_one(doc! { "token_hash": &token_hash, "revoked": false })
.await
{
Ok(Some(t)) => t,
Ok(None) => {
return (StatusCode::UNAUTHORIZED, "Invalid or revoked token").into_response();
}
Err(e) => {
tracing::error!("MCP token lookup failed: {e}");
return (StatusCode::INTERNAL_SERVER_ERROR, "Token lookup error").into_response();
}
};
// Fire-and-forget last_used_at update — never block the handler.
let col2 = pool.admin_db().collection::<TokenLookup>(COLLECTION);
let hash_for_update = token_hash.clone();
tokio::spawn(async move {
let _ = col2
.update_one(
doc! { "token_hash": &hash_for_update },
doc! { "$set": { "last_used_at": mongodb::bson::DateTime::now() } },
)
.await;
});
let tenant_id = found.tenant_id;
let inner = next.run(request);
TENANT_ID.scope(tenant_id, inner).await
}
/// Bare-bones projection — we don't need the whole `McpToken` here,
/// just enough to route and confirm validity.
#[derive(serde::Deserialize)]
struct TokenLookup {
tenant_id: String,
}
fn extract_bearer(req: &Request<Body>) -> Option<String> {
req.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
fn sha256_hex(s: &str) -> String {
let mut h = Sha256::new();
h.update(s.as_bytes());
hex::encode(h.finalize())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sha256_known_value() {
// python -c 'import hashlib; print(hashlib.sha256(b"mcpt_known").hexdigest())'
assert_eq!(
sha256_hex("mcpt_known"),
"27cf6cf678a44244106863c1c031be8e57b84c2b3019d742f755f8e7afa75dfd"
);
}
}
+7 -115
View File
@@ -1,127 +1,19 @@
//! Per-tenant Mongo broker for the MCP server. use mongodb::{Client, Collection};
//!
//! Mirror of the agent's `compliance_agent::database::DatabasePool` —
//! duplicated here rather than lifted into `compliance-core` to keep
//! this PR focused. If a third consumer ever needs it, lift then.
//!
//! Bearer tokens (validated by the auth middleware) carry a tenant_id
//! and the handler resolves the per-tenant database via
//! [`DatabasePool::for_tenant_id`]. The admin database
//! (`<db_prefix>__admin`) holds the cross-tenant `mcp_tokens`
//! collection that the middleware queries on every request.
use std::sync::Arc;
use dashmap::DashMap;
use mongodb::{bson::doc, Client, Collection};
use sha2::{Digest, Sha256};
use compliance_core::models::*; use compliance_core::models::*;
/// 63-byte Mongo db-name cap; same invariant as the agent's pool.
const MAX_DB_NAME_LEN: usize = 63;
/// 16-byte SHA-256 truncation, hex-encoded → 32 chars.
const HASH_HEX_LEN: usize = 32;
const MAX_PREFIX_LEN: usize = MAX_DB_NAME_LEN - 1 - HASH_HEX_LEN;
#[derive(Clone, Debug)]
pub struct DatabasePool {
client: Client,
db_prefix: String,
/// Tenants we've handed out a [`Database`] for. The MCP server
/// doesn't ensure indexes (the agent owns that side of the
/// schema), so the marker exists only to satisfy the parallel
/// shape — current code never reads it.
#[allow(dead_code)]
seen: Arc<DashMap<String, ()>>,
}
#[derive(Debug, thiserror::Error)]
pub enum DbError {
#[error("db_prefix '{prefix}' is {len} chars; max is {max} so the hash-fallback DB name fits Mongo's 63-byte cap")]
PrefixTooLong {
prefix: String,
len: usize,
max: usize,
},
#[error(transparent)]
Mongo(#[from] mongodb::error::Error),
}
impl DatabasePool {
pub async fn connect(uri: &str, db_prefix: &str) -> Result<Self, DbError> {
if db_prefix.len() > MAX_PREFIX_LEN {
return Err(DbError::PrefixTooLong {
prefix: db_prefix.to_string(),
len: db_prefix.len(),
max: MAX_PREFIX_LEN,
});
}
let client = Client::with_uri_str(uri).await?;
client
.database("admin")
.run_command(doc! { "ping": 1 })
.await?;
tracing::info!(
"MCP MongoDB cluster reachable; per-tenant pool ready (db prefix '{db_prefix}')"
);
Ok(Self {
client,
db_prefix: db_prefix.to_string(),
seen: Arc::new(DashMap::new()),
})
}
/// Read-only handle to the tenant's database. No indexes are
/// ensured here — the agent owns writes, MCP only reads.
pub fn for_tenant_id(&self, tenant_id: &str) -> Database {
let db_name = self.tenant_db_name(tenant_id);
self.seen.insert(tenant_id.to_string(), ());
Database::new(self.client.database(&db_name))
}
/// Cross-tenant admin DB — holds the `mcp_tokens` collection that
/// the auth middleware queries to map bearer → tenant_id.
pub fn admin_db(&self) -> mongodb::Database {
self.client.database(&format!("{}__admin", self.db_prefix))
}
pub fn tenant_db_name(&self, tenant_id: &str) -> String {
let sanitized = sanitize_tenant_id(tenant_id);
let natural = format!("{}_{}", self.db_prefix, sanitized);
if natural.len() <= MAX_DB_NAME_LEN {
natural
} else {
let mut h = Sha256::new();
h.update(tenant_id.as_bytes());
let digest = h.finalize();
let suffix = hex::encode(&digest[..HASH_HEX_LEN / 2]);
format!("{}_{}", self.db_prefix, suffix)
}
}
}
fn sanitize_tenant_id(tenant_id: &str) -> String {
tenant_id
.chars()
.map(|c| match c {
'/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_',
c => c,
})
.collect()
}
/// Typed accessors for the MCP-readable collections in a tenant DB.
/// Matches the agent's `Database` shape but only exposes what the MCP
/// tool handlers actually need.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Database { pub struct Database {
inner: mongodb::Database, inner: mongodb::Database,
} }
impl Database { impl Database {
pub(crate) fn new(inner: mongodb::Database) -> Self { pub async fn connect(uri: &str, db_name: &str) -> Result<Self, mongodb::error::Error> {
Self { inner } let client = Client::with_uri_str(uri).await?;
let db = client.database(db_name);
db.run_command(mongodb::bson::doc! { "ping": 1 }).await?;
tracing::info!("MCP server connected to MongoDB '{db_name}'");
Ok(Self { inner: db })
} }
pub fn findings(&self) -> Collection<Finding> { pub fn findings(&self) -> Collection<Finding> {
+10 -35
View File
@@ -1,11 +1,10 @@
mod auth;
mod database; mod database;
mod server; mod server;
mod tools; mod tools;
use std::sync::Arc; use std::sync::Arc;
use database::DatabasePool; use database::Database;
use rmcp::transport::{ use rmcp::transport::{
streamable_http_server::session::local::LocalSessionManager, StreamableHttpServerConfig, streamable_http_server::session::local::LocalSessionManager, StreamableHttpServerConfig,
StreamableHttpService, StreamableHttpService,
@@ -25,60 +24,36 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mongo_uri = let mongo_uri =
std::env::var("MONGODB_URI").unwrap_or_else(|_| "mongodb://localhost:27017".to_string()); std::env::var("MONGODB_URI").unwrap_or_else(|_| "mongodb://localhost:27017".to_string());
// MONGODB_DATABASE is reused as the per-tenant DB-name prefix — let db_name =
// same convention as the agent so `<prefix>__admin.mcp_tokens`
// and `<prefix>_<tenant_id>` line up across services.
let db_prefix =
std::env::var("MONGODB_DATABASE").unwrap_or_else(|_| "compliance_scanner".to_string()); std::env::var("MONGODB_DATABASE").unwrap_or_else(|_| "compliance_scanner".to_string());
let pool = DatabasePool::connect(&mongo_uri, &db_prefix).await?; let db = Database::connect(&mongo_uri, &db_name).await?;
// HTTP transport: bind a small axum router with bearer-auth in // If MCP_PORT is set, run as Streamable HTTP server; otherwise use stdio.
// front of the rmcp service. `/health` stays public for orca's
// container probe.
if let Ok(port_str) = std::env::var("MCP_PORT") { if let Ok(port_str) = std::env::var("MCP_PORT") {
let port: u16 = port_str.parse()?; let port: u16 = port_str.parse()?;
tracing::info!("Starting MCP server on HTTP port {port}"); tracing::info!("Starting MCP server on HTTP port {port}");
let pool_for_factory = pool.clone(); let db_clone = db.clone();
let service = StreamableHttpService::new( let service = StreamableHttpService::new(
move || Ok(ComplianceMcpServer::new(pool_for_factory.clone())), move || Ok(ComplianceMcpServer::new(db_clone.clone())),
Arc::new(LocalSessionManager::default()), Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(), StreamableHttpServerConfig::default(),
); );
let router = axum::Router::new() let router = axum::Router::new()
.route("/health", axum::routing::get(|| async { "ok" })) .route("/health", axum::routing::get(|| async { "ok" }))
.nest_service( .nest_service("/mcp", service);
"/mcp",
axum::Router::new().fallback_service(service).layer(
axum::middleware::from_fn_with_state(pool.clone(), auth::bearer_auth),
),
);
let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?; let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)).await?;
tracing::info!("MCP HTTP server listening on 0.0.0.0:{port}"); tracing::info!("MCP HTTP server listening on 0.0.0.0:{port}");
axum::serve(listener, router).await?; axum::serve(listener, router).await?;
} else { } else {
// stdio transport — used when run as a local MCP server next
// to the LLM client. There's no HTTP layer to do bearer auth,
// so we synthesize a tenant_id from STDIO_TENANT_ID for local
// development. NEVER use this in production.
tracing::info!("Starting MCP server on stdio"); tracing::info!("Starting MCP server on stdio");
let synth_tenant = std::env::var("STDIO_TENANT_ID").unwrap_or_else(|_| "dev".to_string()); let server = ComplianceMcpServer::new(db);
tracing::warn!(
tenant_id = %synth_tenant,
"stdio transport — using synthetic tenant id; DO NOT use in production"
);
let server = ComplianceMcpServer::new(pool);
let transport = rmcp::transport::stdio(); let transport = rmcp::transport::stdio();
use rmcp::ServiceExt; use rmcp::ServiceExt;
auth::TENANT_ID let handle = server.serve(transport).await?;
.scope(synth_tenant, async { handle.waiting().await?;
let handle = server.serve(transport).await?;
handle.waiting().await?;
Ok::<_, Box<dyn std::error::Error>>(())
})
.await?;
} }
Ok(()) Ok(())
+17 -46
View File
@@ -2,37 +2,20 @@ use rmcp::{
handler::server::wrapper::Parameters, model::*, tool, tool_handler, tool_router, ServerHandler, handler::server::wrapper::Parameters, model::*, tool, tool_handler, tool_router, ServerHandler,
}; };
use crate::auth::current_tenant_id; use crate::database::Database;
use crate::database::{Database, DatabasePool};
use crate::tools::{dast, findings, pentest, sbom}; use crate::tools::{dast, findings, pentest, sbom};
pub struct ComplianceMcpServer { pub struct ComplianceMcpServer {
pool: DatabasePool, db: Database,
#[allow(dead_code)] #[allow(dead_code)]
tool_router: rmcp::handler::server::router::tool::ToolRouter<Self>, tool_router: rmcp::handler::server::router::tool::ToolRouter<Self>,
} }
impl ComplianceMcpServer {
/// Resolve the per-tenant `Database` from the bearer-set
/// `task_local`. Every tool handler calls this; missing context
/// surfaces as `internal_error` because it means the auth
/// middleware was misconfigured (handler ran without scope).
fn tenant_db(&self) -> Result<Database, rmcp::ErrorData> {
let tenant_id = current_tenant_id().ok_or_else(|| {
rmcp::ErrorData::internal_error(
"no tenant context — bearer middleware not in chain".to_string(),
None,
)
})?;
Ok(self.pool.for_tenant_id(&tenant_id))
}
}
#[tool_router] #[tool_router]
impl ComplianceMcpServer { impl ComplianceMcpServer {
pub fn new(pool: DatabasePool) -> Self { pub fn new(db: Database) -> Self {
Self { Self {
pool, db,
tool_router: Self::tool_router(), tool_router: Self::tool_router(),
} }
} }
@@ -46,8 +29,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<findings::ListFindingsParams>, Parameters(params): Parameters<findings::ListFindingsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; findings::list_findings(&self.db, params).await
findings::list_findings(&db, params).await
} }
#[tool(description = "Get a single finding by its ID")] #[tool(description = "Get a single finding by its ID")]
@@ -55,8 +37,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<findings::GetFindingParams>, Parameters(params): Parameters<findings::GetFindingParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; findings::get_finding(&self.db, params).await
findings::get_finding(&db, params).await
} }
#[tool(description = "Get a summary of findings counts grouped by severity and status")] #[tool(description = "Get a summary of findings counts grouped by severity and status")]
@@ -64,8 +45,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<findings::FindingsSummaryParams>, Parameters(params): Parameters<findings::FindingsSummaryParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; findings::findings_summary(&self.db, params).await
findings::findings_summary(&db, params).await
} }
// ── SBOM ────────────────────────────────────────────── // ── SBOM ──────────────────────────────────────────────
@@ -77,8 +57,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<sbom::ListSbomPackagesParams>, Parameters(params): Parameters<sbom::ListSbomPackagesParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; sbom::list_sbom_packages(&self.db, params).await
sbom::list_sbom_packages(&db, params).await
} }
#[tool( #[tool(
@@ -88,8 +67,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<sbom::SbomVulnReportParams>, Parameters(params): Parameters<sbom::SbomVulnReportParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; sbom::sbom_vuln_report(&self.db, params).await
sbom::sbom_vuln_report(&db, params).await
} }
// ── DAST ────────────────────────────────────────────── // ── DAST ──────────────────────────────────────────────
@@ -101,8 +79,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<dast::ListDastFindingsParams>, Parameters(params): Parameters<dast::ListDastFindingsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; dast::list_dast_findings(&self.db, params).await
dast::list_dast_findings(&db, params).await
} }
#[tool(description = "Get a summary of recent DAST scan runs and finding counts")] #[tool(description = "Get a summary of recent DAST scan runs and finding counts")]
@@ -110,8 +87,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<dast::DastScanSummaryParams>, Parameters(params): Parameters<dast::DastScanSummaryParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; dast::dast_scan_summary(&self.db, params).await
dast::dast_scan_summary(&db, params).await
} }
// ── Pentest ───────────────────────────────────────────── // ── Pentest ─────────────────────────────────────────────
@@ -123,8 +99,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<pentest::ListPentestSessionsParams>, Parameters(params): Parameters<pentest::ListPentestSessionsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; pentest::list_pentest_sessions(&self.db, params).await
pentest::list_pentest_sessions(&db, params).await
} }
#[tool(description = "Get a single AI pentest session by its ID")] #[tool(description = "Get a single AI pentest session by its ID")]
@@ -132,8 +107,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<pentest::GetPentestSessionParams>, Parameters(params): Parameters<pentest::GetPentestSessionParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; pentest::get_pentest_session(&self.db, params).await
pentest::get_pentest_session(&db, params).await
} }
#[tool( #[tool(
@@ -143,8 +117,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<pentest::GetAttackChainParams>, Parameters(params): Parameters<pentest::GetAttackChainParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; pentest::get_attack_chain(&self.db, params).await
pentest::get_attack_chain(&db, params).await
} }
#[tool(description = "Get chat messages from a pentest session")] #[tool(description = "Get chat messages from a pentest session")]
@@ -152,8 +125,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<pentest::GetPentestMessagesParams>, Parameters(params): Parameters<pentest::GetPentestMessagesParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; pentest::get_pentest_messages(&self.db, params).await
pentest::get_pentest_messages(&db, params).await
} }
#[tool( #[tool(
@@ -163,8 +135,7 @@ impl ComplianceMcpServer {
&self, &self,
Parameters(params): Parameters<pentest::PentestStatsParams>, Parameters(params): Parameters<pentest::PentestStatsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> { ) -> Result<CallToolResult, rmcp::ErrorData> {
let db = self.tenant_db()?; pentest::pentest_stats(&self.db, params).await
pentest::pentest_stats(&db, params).await
} }
} }
@@ -178,7 +149,7 @@ impl ServerHandler for ComplianceMcpServer {
.build(), .build(),
server_info: Implementation::from_build_env(), server_info: Implementation::from_build_env(),
instructions: Some( instructions: Some(
"Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions for your tenant." "Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions."
.to_string(), .to_string(),
), ),
} }