Compare commits

..

5 Commits

Author SHA1 Message Date
Sharang Parnerkar bec47f8c7d feat(dashboard): proactively refresh expired Keycloak tokens
CI / Check (pull_request) Successful in 8m7s
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped
The dashboard stored a refresh_token in the session at login (auth.rs)
but never used it. Once the access_token's 5-minute lifespan ran out,
every subsequent agent call failed with 401 ExpiredSignature. The UI
showed "unable to load X" until the user logged out and back in.

Fix: before attaching the bearer, decode the JWT's `exp` claim and
proactively refresh via the stored refresh_token if the token is
expired or within REFRESH_SKEW_SECS (30s) of expiry. Updates the
session with the new access_token (and rotated refresh_token if KC
sends one). Refresh failures fall through with the stale token so the
agent's 401 surfaces to the UI rather than failing the request at the
dashboard layer.

Why "proactive" instead of "retry on 401"
- Saves a wasted round-trip on every agent call once the token has
  aged past 5 min.
- Doesn't require cloning RequestBuilder bodies for retry.
- Same end state — fresh token reaches the agent.

Test plan
- cargo test -p compliance-dashboard --features server
  --no-default-features infrastructure::agent_client::tests — 5 pass:
    * expired JWT → refresh
    * near-expiry within skew window → refresh
    * fresh JWT → no refresh
    * malformed/empty JWT → refresh (defensive)
    * JWT without exp claim → refresh (defensive)
- Manual after deploy: dashboard works past the 5-min token lifespan
  without manual re-login.

Note
- The refresh code addresses the ExpiredSignature failure mode. The
  separate "JWT is missing tenant_id claim" 401 is a Keycloak realm
  config issue (the user logging in lacks the M7.1 attributes that
  the protocol mappers consume) and is fixed by realm/attribute
  config, not by this PR.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-17 21:38:06 +02:00
sharang 56482911b8 fix(dashboard): attach Keycloak token on agent API calls (#90)
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 6s
CI / Deploy Agent (push) Successful in 4m8s
CI / Deploy Dashboard (push) Successful in 4m58s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Has been skipped
2026-06-17 18:35:59 +00:00
sharang 183234f9af feat(m7.1): wire compliance-agent to compliance-core auth + status gate (#85)
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 5s
CI / Deploy Agent (push) Successful in 8m38s
CI / Deploy Dashboard (push) Successful in 7m30s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 1m55s
2026-06-17 09:36:52 +00:00
sharang dbadff0aac fix(m7.1): JWKS refresh-on-failure in auth middleware (#84)
CI / Check (push) Has been skipped
CI / Detect Changes (push) Successful in 3s
CI / Deploy Agent (push) Successful in 11m44s
CI / Deploy Dashboard (push) Successful in 13m1s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 1m53s
2026-06-04 14:46:14 +00:00
sharang 116293519d M7.1 smoke harness: lift auth to compliance-core + compliance-smoke service (#83)
CI / Check (push) Has been cancelled
CI / Detect Changes (push) Has been cancelled
CI / Deploy Agent (push) Has been cancelled
CI / Deploy Dashboard (push) Has been cancelled
CI / Deploy Docs (push) Has been cancelled
CI / Deploy MCP (push) Has been cancelled
2026-06-04 14:38:35 +00:00
52 changed files with 2001 additions and 793 deletions
Generated
+18
View File
@@ -701,19 +701,23 @@ dependencies = [
name = "compliance-core"
version = "0.1.0"
dependencies = [
"axum",
"bson",
"chrono",
"hex",
"jsonwebtoken",
"mongodb",
"opentelemetry",
"opentelemetry-appender-tracing",
"opentelemetry-otlp",
"opentelemetry_sdk",
"reqwest",
"secrecy",
"serde",
"serde_json",
"sha2",
"thiserror 2.0.18",
"tokio",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
@@ -827,6 +831,20 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "compliance-smoke"
version = "0.1.0"
dependencies = [
"axum",
"compliance-core",
"reqwest",
"serde",
"serde_json",
"tokio",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "console_error_panic_hook"
version = "0.1.7"
+1
View File
@@ -6,6 +6,7 @@ members = [
"compliance-graph",
"compliance-dast",
"compliance-mcp",
"compliance-smoke",
]
resolver = "2"
+3 -3
View File
@@ -7,7 +7,7 @@ edition = "2021"
workspace = true
[dependencies]
compliance-core = { workspace = true, features = ["mongodb", "telemetry"] }
compliance-core = { workspace = true, features = ["mongodb", "telemetry", "axum"] }
compliance-graph = { path = "../compliance-graph" }
compliance-dast = { path = "../compliance-dast" }
serde = { workspace = true }
@@ -44,7 +44,8 @@ dashmap = { workspace = true }
tokio-stream = { workspace = true }
[dev-dependencies]
compliance-core = { workspace = true, features = ["mongodb"] }
compliance-core = { workspace = true, features = ["mongodb", "axum"] }
tower = { version = "0.5", features = ["util"] }
reqwest = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
@@ -52,5 +53,4 @@ mongodb = { workspace = true }
uuid = { workspace = true }
secrecy = { workspace = true }
axum = "0.8"
tower = { version = "0.5", features = ["util"] }
tower-http = { version = "0.6", features = ["cors"] }
+16 -18
View File
@@ -6,7 +6,7 @@ use tokio::sync::{broadcast, watch, Semaphore};
use compliance_core::models::pentest::PentestEvent;
use compliance_core::AgentConfig;
use crate::database::Database;
use crate::database::DatabasePool;
use crate::llm::LlmClient;
use crate::pipeline::orchestrator::PipelineOrchestrator;
@@ -16,7 +16,10 @@ const DEFAULT_MAX_CONCURRENT_SESSIONS: usize = 5;
#[derive(Clone)]
pub struct ComplianceAgent {
pub config: AgentConfig,
pub db: Database,
/// Per-tenant Mongo broker. Every code path must obtain a
/// tenant-scoped [`crate::database::Database`] from this pool —
/// there is no single shared database any more.
pub db_pool: DatabasePool,
pub llm: Arc<LlmClient>,
pub http: reqwest::Client,
/// Per-session broadcast senders for SSE streaming.
@@ -28,7 +31,7 @@ pub struct ComplianceAgent {
}
impl ComplianceAgent {
pub fn new(config: AgentConfig, db: Database) -> Self {
pub fn new(config: AgentConfig, db_pool: DatabasePool) -> Self {
let llm = Arc::new(LlmClient::new(
config.litellm_url.clone(),
config.litellm_api_key.clone(),
@@ -42,7 +45,7 @@ impl ComplianceAgent {
.unwrap_or_default();
Self {
config,
db,
db_pool,
llm,
http,
session_streams: Arc::new(DashMap::new()),
@@ -53,28 +56,27 @@ impl ComplianceAgent {
pub async fn run_scan(
&self,
tenant_id: &str,
repo_id: &str,
trigger: compliance_core::models::ScanTrigger,
) -> Result<(), crate::error::AgentError> {
let orchestrator = PipelineOrchestrator::new(
self.config.clone(),
self.db.clone(),
self.llm.clone(),
self.http.clone(),
);
let db = self.db_pool.for_tenant_id(tenant_id).await?;
let orchestrator =
PipelineOrchestrator::new(self.config.clone(), db, self.llm.clone(), self.http.clone());
orchestrator.run(repo_id, trigger).await
}
/// Run a PR review: scan the diff and post review comments.
pub async fn run_pr_review(
&self,
tenant_id: &str,
repo_id: &str,
pr_number: u64,
base_sha: &str,
head_sha: &str,
) -> Result<(), crate::error::AgentError> {
let repo = self
.db
let db = self.db_pool.for_tenant_id(tenant_id).await?;
let repo = db
.repositories()
.find_one(mongodb::bson::doc! {
"_id": mongodb::bson::oid::ObjectId::parse_str(repo_id)
@@ -85,12 +87,8 @@ impl ComplianceAgent {
crate::error::AgentError::Other(format!("Repository {repo_id} not found"))
})?;
let orchestrator = PipelineOrchestrator::new(
self.config.clone(),
self.db.clone(),
self.llm.clone(),
self.http.clone(),
);
let orchestrator =
PipelineOrchestrator::new(self.config.clone(), db, self.llm.clone(), self.http.clone());
orchestrator
.run_pr_review(&repo, repo_id, pr_number, base_sha, head_sha)
.await
+30 -26
View File
@@ -7,11 +7,13 @@ use mongodb::bson::doc;
use compliance_core::models::chat::{ChatRequest, ChatResponse, SourceReference};
use compliance_core::models::embedding::EmbeddingBuildRun;
use compliance_core::tenant_ctx::TenantCtx;
use compliance_graph::graph::embedding_store::EmbeddingStore;
use crate::agent::ComplianceAgent;
use crate::rag::pipeline::RagPipeline;
use super::dto::tenant_db;
use super::ApiResponse;
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -20,10 +22,12 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn chat(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Json(req): Json<ChatRequest>,
) -> Result<Json<ApiResponse<ChatResponse>>, StatusCode> {
let pipeline = RagPipeline::new(agent.llm.clone(), agent.db.inner());
let db = tenant_db(&agent, &tenant).await?;
let pipeline = RagPipeline::new(agent.llm.clone(), db.inner());
// Step 1: Embed the user's message
let query_vectors = agent
@@ -133,12 +137,15 @@ pub async fn chat(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn build_embeddings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
// Resolve the tenant DB up front so we can move it into the spawn;
// the JWT/dev context isn't available inside detached tasks.
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone();
tokio::spawn(async move {
let repo = match agent_clone
.db
let repo = match db
.repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await
@@ -151,8 +158,7 @@ pub async fn build_embeddings(
};
// Get latest graph build
let build = match agent_clone
.db
let build = match db
.graph_builds()
.find_one(doc! { "repo_id": &repo_id })
.sort(doc! { "started_at": -1 })
@@ -171,26 +177,22 @@ pub async fn build_embeddings(
.unwrap_or_else(|| "unknown".to_string());
// Get nodes
let nodes: Vec<compliance_core::models::graph::CodeNode> = match agent_clone
.db
.graph_nodes()
.find(doc! { "repo_id": &repo_id })
.await
{
Ok(cursor) => {
use futures_util::StreamExt;
let mut items = Vec::new();
let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
let nodes: Vec<compliance_core::models::graph::CodeNode> =
match db.graph_nodes().find(doc! { "repo_id": &repo_id }).await {
Ok(cursor) => {
use futures_util::StreamExt;
let mut items = Vec::new();
let mut cursor = cursor;
while let Some(Ok(item)) = cursor.next().await {
items.push(item);
}
items
}
items
}
Err(e) => {
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
return;
}
};
Err(e) => {
tracing::error!("[{repo_id}] Failed to fetch nodes: {e}");
return;
}
};
let creds = crate::pipeline::git::RepoCredentials {
ssh_key_path: Some(agent_clone.config.ssh_key_path.clone()),
@@ -207,7 +209,7 @@ pub async fn build_embeddings(
}
};
let pipeline = RagPipeline::new(agent_clone.llm.clone(), agent_clone.db.inner());
let pipeline = RagPipeline::new(agent_clone.llm.clone(), db.inner());
match pipeline
.build_embeddings(&repo_id, &repo_path, &graph_build_id, &nodes)
.await
@@ -234,9 +236,11 @@ pub async fn build_embeddings(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn embedding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Option<EmbeddingBuildRun>>>, StatusCode> {
let store = EmbeddingStore::new(agent.db.inner());
let db = tenant_db(&agent, &tenant).await?;
let store = EmbeddingStore::new(db.inner());
let build = store.get_latest_build(&repo_id).await.map_err(|e| {
tracing::error!("Failed to get embedding status: {e}");
StatusCode::INTERNAL_SERVER_ERROR
+20 -11
View File
@@ -7,9 +7,11 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::{DastFinding, DastScanRun, DastTarget, DastTargetType};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -45,9 +47,11 @@ fn default_rate_limit() -> u32 {
#[tracing::instrument(skip_all)]
pub async fn list_targets(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastTarget>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_targets()
@@ -80,6 +84,7 @@ pub async fn list_targets(
#[tracing::instrument(skip_all)]
pub async fn add_target(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddTargetRequest>,
) -> Result<Json<ApiResponse<DastTarget>>, StatusCode> {
let mut target = DastTarget::new(req.name, req.base_url, req.target_type);
@@ -89,9 +94,8 @@ pub async fn add_target(
target.rate_limit = req.rate_limit;
target.allow_destructive = req.allow_destructive;
agent
.db
.dast_targets()
let db = tenant_db(&agent, &tenant).await?;
db.dast_targets()
.insert_one(&target)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
@@ -107,19 +111,19 @@ pub async fn add_target(
#[tracing::instrument(skip_all, fields(target_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let target = agent
.db
let target = db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let db = agent.db.clone();
tokio::spawn(async move {
let orchestrator = compliance_dast::DastOrchestrator::new(100);
match orchestrator.run_scan(&target, Vec::new()).await {
@@ -147,9 +151,11 @@ pub async fn trigger_scan(
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastScanRun>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_scan_runs()
@@ -183,9 +189,11 @@ pub async fn list_scan_runs(
#[tracing::instrument(skip_all)]
pub async fn list_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.dast_findings()
@@ -219,12 +227,13 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<DastFinding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let finding = agent
.db
let finding = db
.dast_findings()
.find_one(doc! { "_id": oid })
.await
+21
View File
@@ -180,6 +180,27 @@ pub struct SbomVersionDiff {
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
/// Resolve a tenant-scoped [`Database`] from the request's
/// [`TenantContext`] (inserted by the M7.1 JWT middleware, or by the
/// dev fallback in unsecured environments). The pool ensures the
/// tenant's indexes idempotently.
///
/// Returns 500 on the rare path where Mongo refuses the database
/// handle — the M7.1 auth/status middleware already rejects every
/// other failure mode with 4xx before we get here.
pub(crate) async fn tenant_db(
agent: &crate::agent::ComplianceAgent,
tenant: &compliance_core::tenant_ctx::TenantCtx,
) -> Result<crate::database::Database, axum::http::StatusCode> {
agent.db_pool.for_tenant(&tenant.0).await.map_err(|e| {
tracing::error!(
tenant_id = %tenant.0.tenant_id,
"Failed to acquire tenant database: {e}"
);
axum::http::StatusCode::INTERNAL_SERVER_ERROR
})
}
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
+16 -11
View File
@@ -5,13 +5,16 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::Finding;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
@@ -81,11 +84,12 @@ pub async fn list_findings(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let finding = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let finding = db
.findings()
.find_one(doc! { "_id": oid })
.await
@@ -102,14 +106,14 @@ pub async fn get_finding(
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
agent
.db
.findings()
db.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
@@ -123,6 +127,7 @@ pub async fn update_finding_status(
#[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req
@@ -135,8 +140,8 @@ pub async fn bulk_update_finding_status(
return Err(StatusCode::BAD_REQUEST);
}
let result = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let result = db
.findings()
.update_many(
doc! { "_id": { "$in": oids } },
@@ -153,14 +158,14 @@ pub async fn bulk_update_finding_status(
#[tracing::instrument(skip_all)]
pub async fn update_finding_feedback(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
agent
.db
.findings()
db.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
+24 -10
View File
@@ -7,9 +7,11 @@ use mongodb::bson::doc;
use serde::{Deserialize, Serialize};
use compliance_core::models::graph::{CodeEdge, CodeNode, GraphBuildRun, ImpactAnalysis};
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::dto::tenant_db;
use super::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -36,9 +38,11 @@ fn default_search_limit() -> usize {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_graph(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<GraphData>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
// Get latest build
let build: Option<GraphBuildRun> = db
@@ -98,9 +102,11 @@ pub async fn get_graph(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_nodes(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -123,9 +129,11 @@ pub async fn get_nodes(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_communities(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<ApiResponse<Vec<CommunityInfo>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let filter = doc! { "repo_id": &repo_id };
let nodes: Vec<CodeNode> = match db.graph_nodes().find(filter).await {
@@ -176,9 +184,11 @@ pub struct CommunityInfo {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, finding_id = %finding_id))]
pub async fn get_impact(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path((repo_id, finding_id)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Option<ImpactAnalysis>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let filter = doc! { "repo_id": &repo_id, "finding_id": &finding_id };
let impact = db
@@ -198,10 +208,12 @@ pub async fn get_impact(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, query = %params.q))]
pub async fn search_symbols(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Query(params): Query<SearchParams>,
) -> Result<Json<ApiResponse<Vec<CodeNode>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
// Simple text search on qualified_name and name fields
let filter = doc! {
@@ -234,10 +246,12 @@ pub async fn search_symbols(
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn get_file_content(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
Query(params): Query<FileContentParams>,
) -> Result<Json<ApiResponse<FileContent>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
// Look up the repository to get repo name
let repo = db
@@ -296,12 +310,13 @@ pub struct FileContent {
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub async fn trigger_build(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(repo_id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let agent_clone = (*agent).clone();
tokio::spawn(async move {
let repo = match agent_clone
.db
let repo = match db
.repositories()
.find_one(doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(&repo_id).ok() })
.await
@@ -333,8 +348,7 @@ pub async fn trigger_build(
match engine.build_graph(&repo_path, &repo_id, &graph_build_id) {
Ok((code_graph, build_run)) => {
let store =
compliance_graph::graph::persistence::GraphStore::new(agent_clone.db.inner());
let store = compliance_graph::graph::persistence::GraphStore::new(db.inner());
let _ = store.delete_repo_graph(&repo_id).await;
let _ = store
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
+7 -2
View File
@@ -3,6 +3,7 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> {
@@ -10,8 +11,12 @@ pub async fn health() -> Json<serde_json::Value> {
}
#[tracing::instrument(skip_all)]
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
let db = &agent.db;
pub async fn stats_overview(
axum::extract::Extension(agent): AgentExt,
tenant: TenantCtx,
) -> ApiResult<OverviewStats> {
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let total_repositories = db
.repositories()
+4 -1
View File
@@ -4,13 +4,16 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::TrackerIssue;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_issues(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.tracker_issues()
@@ -5,15 +5,18 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::notification::CveNotification;
use compliance_core::tenant_ctx::TenantCtx;
use super::dto::{AgentExt, ApiResponse};
use super::dto::{tenant_db, AgentExt, ApiResponse};
/// GET /api/v1/notifications — List CVE notifications (newest first)
#[tracing::instrument(skip_all)]
pub async fn list_notifications(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Query(params): axum::extract::Query<NotificationFilter>,
) -> Result<Json<ApiResponse<Vec<CveNotification>>>, StatusCode> {
let db = tenant_db(&agent, &tenant).await?;
let mut filter = doc! {};
// Filter by status (default: show new + read, exclude dismissed)
@@ -41,15 +44,13 @@ pub async fn list_notifications(
let limit = params.limit.unwrap_or(50).min(200);
let skip = (page - 1) * limit as u64;
let total = agent
.db
let total = db
.cve_notifications()
.count_documents(filter.clone())
.await
.unwrap_or(0);
let notifications: Vec<CveNotification> = match agent
.db
let notifications: Vec<CveNotification> = match db
.cve_notifications()
.find(filter)
.sort(doc! { "created_at": -1 })
@@ -83,9 +84,10 @@ pub async fn list_notifications(
#[tracing::instrument(skip_all)]
pub async fn notification_count(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let count = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let count = db
.cve_notifications()
.count_documents(doc! { "status": "new" })
.await
@@ -98,12 +100,13 @@ pub async fn notification_count(
#[tracing::instrument(skip_all, fields(id = %id))]
pub async fn mark_read(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = agent
.db
let result = db
.cve_notifications()
.update_one(
doc! { "_id": oid },
@@ -125,12 +128,13 @@ pub async fn mark_read(
#[tracing::instrument(skip_all, fields(id = %id))]
pub async fn dismiss_notification(
Extension(agent): AgentExt,
tenant: TenantCtx,
axum::extract::Path(id): axum::extract::Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let result = agent
.db
let result = db
.cve_notifications()
.update_one(
doc! { "_id": oid },
@@ -149,9 +153,10 @@ pub async fn dismiss_notification(
#[tracing::instrument(skip_all)]
pub async fn mark_all_read(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let result = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let result = db
.cve_notifications()
.update_many(
doc! { "status": "new" },
@@ -13,10 +13,11 @@ use compliance_core::models::dast::DastFinding;
use compliance_core::models::finding::Finding;
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
use super::super::dto::{collect_cursor_async, tenant_db};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -35,11 +36,15 @@ pub struct ExportBody {
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if body.password.len() < 8 {
return Err((
@@ -49,8 +54,7 @@ pub async fn export_session_report(
}
// Fetch session
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -64,9 +68,7 @@ pub async fn export_session_report(
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
db.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
@@ -84,8 +86,7 @@ pub async fn export_session_report(
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
let nodes: Vec<AttackChainNode> = match db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -96,8 +97,7 @@ pub async fn export_session_report(
};
// Fetch DAST findings for this session, then deduplicate
let raw_findings: Vec<DastFinding> = match agent
.db
let raw_findings: Vec<DastFinding> = match db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
@@ -122,8 +122,7 @@ pub async fn export_session_report(
.or_else(|| target.as_ref().and_then(|t| t.repo_id.clone()));
let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id {
let sast: Vec<Finding> = match agent
.db
let sast: Vec<Finding> = match db
.findings()
.find(doc! {
"repo_id": rid,
@@ -143,8 +142,7 @@ pub async fn export_session_report(
Err(_) => Vec::new(),
};
let sbom: Vec<SbomEntry> = match agent
.db
let sbom: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! {
"repo_id": rid,
@@ -164,8 +162,7 @@ pub async fn export_session_report(
};
// Build code context from graph nodes
let code_ctx: Vec<CodeContextHint> = match agent
.db
let code_ctx: Vec<CodeContextHint> = match db
.graph_nodes()
.find(doc! { "repo_id": rid, "is_entry_point": true })
.limit(50)
@@ -7,11 +7,12 @@ use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -43,6 +44,7 @@ pub struct LookupRepoQuery {
#[tracing::instrument(skip_all)]
pub async fn create_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
// Try to acquire a concurrency permit
@@ -57,6 +59,10 @@ pub async fn create_session(
)
})?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
if let Some(ref config) = req.config {
// ── Wizard path ──────────────────────────────────────────────
if !config.disclaimer_accepted {
@@ -67,8 +73,7 @@ pub async fn create_session(
}
// Look up or auto-create DastTarget by app_url
let target = match agent
.db
let target = match db
.dast_targets()
.find_one(doc! { "base_url": &config.app_url })
.await
@@ -87,7 +92,7 @@ pub async fn create_session(
}
t.allow_destructive = config.allow_destructive;
t.excluded_paths = config.scope_exclusions.clone();
let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| {
let res = db.dast_targets().insert_one(&t).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create target: {e}"),
@@ -110,8 +115,7 @@ pub async fn create_session(
// Resolve repo_id from git_repo_url if provided
if let Some(ref git_url) = config.git_repo_url {
if let Ok(Some(repo)) = agent
.db
if let Ok(Some(repo)) = db
.repositories()
.find_one(doc! { "git_url": git_url })
.await
@@ -120,8 +124,7 @@ pub async fn create_session(
}
}
let insert_result = agent
.db
let insert_result = db
.pentest_sessions()
.insert_one(&session)
.await
@@ -212,8 +215,7 @@ pub async fn create_session(
// Persist encrypted credentials to DB
if session_for_task.config.is_some() {
if let Some(sid) = session.id {
let _ = agent
.db
let _ = db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
@@ -245,12 +247,13 @@ pub async fn create_session(
});
let llm = agent.llm.clone();
let db = agent.db.clone();
let db_for_orchestrator = db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
let agent_ref = agent.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
let orchestrator =
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message)
.await;
@@ -292,8 +295,7 @@ pub async fn create_session(
)
})?;
let target = agent
.db
let target = db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
@@ -310,8 +312,7 @@ pub async fn create_session(
let mut session = PentestSession::new(target_id, strategy);
session.repo_id = target.repo_id.clone();
let insert_result = agent
.db
let insert_result = db
.pentest_sessions()
.insert_one(&session)
.await
@@ -338,12 +339,13 @@ pub async fn create_session(
});
let llm = agent.llm.clone();
let db = agent.db.clone();
let db_for_orchestrator = db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
let agent_ref = agent.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx));
let orchestrator =
PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, Some(pause_rx));
orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message)
.await;
@@ -373,10 +375,11 @@ fn parse_strategy(s: &str) -> PentestStrategy {
#[tracing::instrument(skip_all)]
pub async fn lookup_repo(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<LookupRepoQuery>,
) -> Result<Json<ApiResponse<serde_json::Value>>, StatusCode> {
let repo = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let repo = db
.repositories()
.find_one(doc! { "git_url": &params.url })
.await
@@ -402,9 +405,11 @@ pub async fn lookup_repo(
#[tracing::instrument(skip_all)]
pub async fn list_sessions(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.pentest_sessions()
@@ -438,12 +443,13 @@ pub async fn list_sessions(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut session = agent
.db
let mut session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -471,15 +477,18 @@ pub async fn get_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
// Verify session exists and is running
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -506,8 +515,7 @@ pub async fn send_message(
)
})?;
let target = agent
.db
let target = db
.dast_targets()
.find_one(doc! { "_id": target_oid })
.await
@@ -527,13 +535,13 @@ pub async fn send_message(
// Store user message
let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let _ = db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session
let llm = agent.llm.clone();
let db = agent.db.clone();
let db_for_orchestrator = db.clone();
let message = req.message.clone();
// Use existing broadcast sender if available, otherwise create a new one
@@ -548,7 +556,7 @@ pub async fn send_message(
.unwrap_or_else(|| agent.register_session_stream(&session_id));
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None);
let orchestrator = PentestOrchestrator::new(llm, db_for_orchestrator, event_tx, None);
orchestrator
.run_session_guarded(&session, &target, &message)
.await;
@@ -565,13 +573,16 @@ pub async fn send_message(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -590,9 +601,7 @@ pub async fn stop_session(
));
}
agent
.db
.pentest_sessions()
db.pentest_sessions()
.update_one(
doc! { "_id": oid },
doc! { "$set": {
@@ -612,8 +621,7 @@ pub async fn stop_session(
// Clean up session resources
agent.cleanup_session(&id);
let updated = agent
.db
let updated = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -641,13 +649,16 @@ pub async fn stop_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn pause_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -684,13 +695,16 @@ pub async fn pause_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn resume_session(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<serde_json::Value>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
let db = tenant_db(&agent, &tenant)
.await
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -727,12 +741,13 @@ pub async fn resume_session(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let nodes = match agent
.db
let nodes = match db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -757,21 +772,21 @@ pub async fn get_attack_chain(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
let total = db
.pentest_messages()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let messages = match agent
.db
let messages = match db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
@@ -797,21 +812,21 @@ pub async fn get_messages(
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
let total = db
.dast_findings()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let findings = match agent
.db
let findings = match db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 })
@@ -6,10 +6,11 @@ use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, ApiResponse};
use super::super::dto::{collect_cursor_async, tenant_db, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -17,8 +18,10 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let running_sessions = db
.pentest_sessions()
@@ -11,10 +11,11 @@ use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::StreamExt;
use compliance_core::models::pentest::*;
use compliance_core::tenant_ctx::TenantCtx;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
use super::super::dto::{collect_cursor_async, tenant_db};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -25,13 +26,14 @@ type AgentExt = Extension<Arc<ComplianceAgent>>;
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
// Verify session exists
let _session = agent
.db
let _session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
@@ -43,8 +45,7 @@ pub async fn session_stream(
let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new();
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
let messages: Vec<PentestMessage> = match db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
@@ -56,8 +57,7 @@ pub async fn session_stream(
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
let nodes: Vec<AttackChainNode> = match db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
@@ -94,8 +94,7 @@ pub async fn session_stream(
}
// Add current session status event
let session = agent
.db
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
+28 -17
View File
@@ -5,13 +5,16 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::*;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_repositories(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.repositories()
@@ -43,6 +46,7 @@ pub async fn list_repositories(
#[tracing::instrument(skip_all)]
pub async fn add_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
// Validate repository access before saving
@@ -69,17 +73,15 @@ pub async fn add_repository(
repo.tracker_token = req.tracker_token;
repo.scan_schedule = req.scan_schedule;
agent
.db
.repositories()
.insert_one(&repo)
let db = tenant_db(&agent, &tenant)
.await
.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
.map_err(|s| (s, "failed to acquire tenant database".to_string()))?;
db.repositories().insert_one(&repo).await.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: repo,
@@ -91,10 +93,12 @@ pub async fn add_repository(
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = tenant_db(&agent, &tenant).await?;
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
@@ -126,8 +130,7 @@ pub async fn update_repository(
set_doc.insert("scan_schedule", schedule);
}
let result = agent
.db
let result = db
.repositories()
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
.await
@@ -155,11 +158,16 @@ pub async fn get_ssh_public_key(
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone();
let tenant_id = tenant.0.tenant_id.clone();
tokio::spawn(async move {
if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
if let Err(e) = agent_clone
.run_scan(&tenant_id, &id, ScanTrigger::Manual)
.await
{
tracing::error!("Manual scan failed for {id}: {e}");
}
});
@@ -170,11 +178,12 @@ pub async fn trigger_scan(
/// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let repo = agent
.db
let db = tenant_db(&agent, &tenant).await?;
let repo = db
.repositories()
.find_one(doc! { "_id": oid })
.await
@@ -196,10 +205,12 @@ pub async fn get_webhook_config(
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository(
Extension(agent): AgentExt,
tenant: TenantCtx,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
// Delete the repository
let result = db
+16 -5
View File
@@ -6,6 +6,7 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::SbomEntry;
use compliance_core::tenant_ctx::TenantCtx;
const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0",
@@ -29,8 +30,10 @@ const COPYLEFT_LICENSES: &[&str] = &[
#[tracing::instrument(skip_all)]
pub async fn sbom_filters(
Extension(agent): AgentExt,
tenant: TenantCtx,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let managers: Vec<String> = db
.sbom_entries()
@@ -61,9 +64,11 @@ pub async fn sbom_filters(
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
@@ -120,9 +125,11 @@ pub async fn list_sbom(
#[tracing::instrument(skip_all)]
pub async fn export_sbom(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let entries: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_id })
@@ -236,9 +243,11 @@ pub async fn export_sbom(
#[tracing::instrument(skip_all)]
pub async fn license_summary(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let mut query = doc! {};
if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id);
@@ -285,9 +294,11 @@ pub async fn license_summary(
#[tracing::instrument(skip_all)]
pub async fn sbom_diff(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let entries_a: Vec<SbomEntry> = match db
.sbom_entries()
+4 -1
View File
@@ -4,13 +4,16 @@ use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
use compliance_core::tenant_ctx::TenantCtx;
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
tenant: TenantCtx,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> {
let db = &agent.db;
let db = tenant_db(&agent, &tenant).await?;
let db = &db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
-3
View File
@@ -1,8 +1,5 @@
pub mod auth_middleware;
pub mod handlers;
pub mod routes;
pub mod server;
pub mod tenant_ctx;
pub use server::start_api_server;
pub use tenant_ctx::TenantCtx;
+49 -6
View File
@@ -1,17 +1,54 @@
use std::sync::Arc;
use axum::extract::Request;
use axum::http::HeaderValue;
use axum::middleware::Next;
use axum::response::Response;
use axum::{middleware, Extension};
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tower_http::set_header::SetResponseHeaderLayer;
use tower_http::trace::TraceLayer;
use compliance_core::auth::{require_jwt_auth, require_tenant_status, JwksState};
use compliance_core::{TenantContext, TenantStatus};
use crate::agent::ComplianceAgent;
use crate::api::auth_middleware::{require_jwt_auth, require_tenant_status, JwksState};
use crate::api::routes;
use crate::error::AgentError;
/// Synthetic tenant id used when Keycloak isn't configured (local dev,
/// `cargo run` against a bare Mongo). Lets the handler stack stay
/// uniformly tenant-scoped without the operator having to spin up KC
/// just to poke at the API. Override via `DEV_TENANT_ID`.
const DEFAULT_DEV_TENANT_ID: &str = "dev";
/// Inject a synthetic [`TenantContext`] for any request that lacks one.
/// Only mounted when Keycloak is NOT configured; with KC, the real
/// `require_jwt_auth` middleware owns this and we never reach here
/// without a context.
///
/// Public so the integration-test harness can mount it without
/// duplicating the synthetic-context shape.
pub async fn inject_dev_tenant(mut request: Request, next: Next) -> Response {
if request.extensions().get::<TenantContext>().is_none() {
let tenant_id =
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
let ctx = TenantContext {
tenant_slug: tenant_id.clone(),
tenant_id,
org_roles: vec![],
products: vec![],
plan: "dev".to_string(),
status: TenantStatus::Active,
user_id: "dev-user".to_string(),
user_name: None,
};
request.extensions_mut().insert(ctx);
}
next.run(request).await
}
pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), AgentError> {
let mut app = routes::build_router()
.layer(Extension(Arc::new(agent.clone())))
@@ -44,16 +81,22 @@ pub async fn start_api_server(agent: ComplianceAgent, port: u16) -> Result<(), A
jwks_url,
};
tracing::info!("Keycloak JWT auth enabled for realm '{kc_realm}'");
// Layers execute outermost-first. The Extension must run before
// require_jwt_auth so that middleware can read JwksState from
// request extensions, and the status gate must run after the
// JWT auth so TenantContext is in extensions.
// Layers execute outermost-first. Extension(jwks_state) must run
// before require_jwt_auth so the middleware can read it; the
// status gate runs after JWT so TenantContext is in extensions.
app = app
.layer(middleware::from_fn(require_tenant_status))
.layer(middleware::from_fn(require_jwt_auth))
.layer(Extension(jwks_state));
} else {
tracing::warn!("Keycloak not configured - API endpoints are unprotected");
let tenant_id =
std::env::var("DEV_TENANT_ID").unwrap_or_else(|_| DEFAULT_DEV_TENANT_ID.to_string());
tracing::warn!(
tenant_id = %tenant_id,
"Keycloak not configured — running unauthenticated against the dev tenant. \
DO NOT use in any environment with real customer data."
);
app = app.layer(middleware::from_fn(inject_dev_tenant));
}
let addr = format!("0.0.0.0:{port}");
+192
View File
@@ -1,11 +1,197 @@
use std::sync::Arc;
use dashmap::DashMap;
use mongodb::bson::doc;
use mongodb::options::IndexOptions;
use mongodb::{Client, Collection, IndexModel};
use sha2::{Digest, Sha256};
use compliance_core::models::*;
use compliance_core::TenantContext;
use crate::error::AgentError;
/// Mongo enforces a 63-byte cap on database names (older clusters: 64
/// on Linux, 63 on Windows; we target the conservative limit).
const MAX_DB_NAME_LEN: usize = 63;
/// Hex length of the SHA-256 truncation used for the hash fallback
/// tenant DB name (16 bytes → 32 hex chars). 16 bytes gives ~2^64
/// birthday-collision resistance — at our 10s-100s tenant scale this
/// is effectively impossible to hit.
const HASH_HEX_LEN: usize = 32;
/// Largest `db_prefix` that still guarantees the hash-fallback name
/// fits in the 63-byte cap: `prefix + "_" + 32 hex chars`.
const MAX_PREFIX_LEN: usize = MAX_DB_NAME_LEN - 1 - HASH_HEX_LEN;
/// Per-tenant Mongo connection broker (M7.2 isolation model).
///
/// Holds one [`Client`] and hands out [`Database`] handles physically
/// scoped to `<db_prefix>_<tenant_id>`. The driver is the isolation
/// boundary — a handle for tenant A cannot see tenant B's documents
/// because it is connected to a different database, not because of an
/// application-level filter.
///
/// Index creation runs idempotently the first time each tenant is seen
/// in the process's lifetime. Mongo's `createIndex` is itself idempotent
/// by index name; the in-memory `ensured` set just skips the round-trip.
#[derive(Clone, Debug)]
pub struct DatabasePool {
client: Client,
db_prefix: String,
ensured: Arc<DashMap<String, ()>>,
}
impl DatabasePool {
/// Connect to the cluster and prepare to hand out tenant databases
/// named `<db_prefix>_<tenant_id>`.
///
/// Validates `db_prefix.len() <= MAX_PREFIX_LEN` so the
/// hash-fallback path is provably within Mongo's 63-byte db-name
/// cap. Refuses to construct a pool that could ever produce an
/// over-long name.
pub async fn connect(uri: &str, db_prefix: &str) -> Result<Self, AgentError> {
if db_prefix.len() > MAX_PREFIX_LEN {
return Err(AgentError::Other(format!(
"db_prefix '{db_prefix}' is {} chars; max is {MAX_PREFIX_LEN} so the \
hash-fallback tenant DB name fits Mongo's {MAX_DB_NAME_LEN}-byte cap",
db_prefix.len()
)));
}
let client = Client::with_uri_str(uri).await?;
client
.database("admin")
.run_command(doc! { "ping": 1 })
.await?;
tracing::info!(
"MongoDB cluster reachable; per-tenant pool ready (db prefix '{db_prefix}')"
);
Ok(Self {
client,
db_prefix: db_prefix.to_string(),
ensured: Arc::new(DashMap::new()),
})
}
/// Return a [`Database`] scoped to this tenant. Ensures indexes on
/// first call per tenant (per process). Cheap on the hot path —
/// subsequent calls skip the round-trip.
pub async fn for_tenant(&self, ctx: &TenantContext) -> Result<Database, AgentError> {
self.for_tenant_id(&ctx.tenant_id).await
}
/// Like [`Self::for_tenant`] but accepts a bare tenant_id.
/// For background paths (scheduler, webhooks, pipeline orchestrators)
/// that don't have a full [`TenantContext`] but know which tenant
/// they're operating on (typically resolved from a URL path, a job
/// argument, or the registry).
pub async fn for_tenant_id(&self, tenant_id: &str) -> Result<Database, AgentError> {
let db_name = self.tenant_db_name(tenant_id);
let db = Database::from_database(self.client.database(&db_name));
// `DashMap::insert` returns the previous value; `None` means we
// were the first writer for this tenant_id and own the
// index-ensure work.
if self.ensured.insert(tenant_id.to_string(), ()).is_none() {
if let Err(e) = db.ensure_indexes().await {
// Roll the marker back so the next request retries.
self.ensured.remove(tenant_id);
return Err(e);
}
tracing::debug!(
tenant_id = %tenant_id,
db_name = %db_name,
"Indexes ensured for tenant database"
);
}
Ok(db)
}
/// Compute the Mongo database name for a tenant. Public for tests
/// and tenant offboarding (`pool.client().database(name).drop()`).
///
/// Format: `<prefix>_<sanitized_tenant_id>` if it fits the 63-byte
/// cap, else `<prefix>_<sha256-16-byte-hex-of-tenant_id>`. The
/// `db_prefix` length invariant established at [`Self::connect`]
/// guarantees the hash-fallback name always fits — no runtime
/// assertion needed.
///
/// Collision resistance: the hash fallback is a 16-byte SHA-256
/// truncation, which gives ~2^64 birthday-collision resistance. At
/// our 10s100s tenant scale the probability of two tenant_ids
/// colliding is effectively zero. (8-byte truncation would have
/// been ~2^32 — too close for comfort on a regulated product.)
pub fn tenant_db_name(&self, tenant_id: &str) -> String {
let sanitized = sanitize_tenant_id(tenant_id);
let natural = format!("{}_{}", self.db_prefix, sanitized);
if natural.len() <= MAX_DB_NAME_LEN {
natural
} else {
let mut hasher = Sha256::new();
hasher.update(tenant_id.as_bytes());
let digest = hasher.finalize();
let suffix = hex::encode(&digest[..HASH_HEX_LEN / 2]);
format!("{}_{}", self.db_prefix, suffix)
}
}
/// Raw client handle. Reserved for cross-tenant admin flows that
/// must opt in explicitly (tenant listing, drop-on-offboard).
pub fn client(&self) -> &Client {
&self.client
}
/// List every Mongo database currently belonging to this pool,
/// identified by the `<db_prefix>_` prefix. The result is the raw
/// database names — opening one for offboarding/cleanup goes
/// through [`Self::client`].
///
/// Note: hashed-fallback names (very long tenant_ids) lose the
/// original tenant_id at the cluster level — we know a database
/// exists for *some* tenant but not which one. In practice
/// tenant_ids are UUIDs (36 chars) and never hit the fallback,
/// so this is a theoretical concern, not an operational one.
pub async fn list_tenant_db_names(&self) -> Result<Vec<String>, AgentError> {
let prefix = format!("{}_", self.db_prefix);
let names = self.client.list_database_names().await?;
Ok(names
.into_iter()
.filter(|n| n.starts_with(&prefix))
.collect())
}
/// Drop the database for a specific tenant. Used by GDPR delete
/// and tenant offboarding. Idempotent — dropping a non-existent
/// database is a no-op at the driver level.
///
/// Also evicts the tenant from the in-memory `ensured` set so a
/// later re-provision triggers fresh `ensure_indexes`.
pub async fn drop_tenant(&self, tenant_id: &str) -> Result<(), AgentError> {
let db_name = self.tenant_db_name(tenant_id);
self.client.database(&db_name).drop().await?;
self.ensured.remove(tenant_id);
tracing::info!(
tenant_id = %tenant_id,
db_name = %db_name,
"Dropped tenant database"
);
Ok(())
}
}
/// Mongo database names disallow `/`, `\`, `.`, `"`, `$`, ` `, and NUL.
/// breakpilot-dev tenant_ids are UUIDs so this is belt-and-braces, but
/// it lets the pool tolerate any future tenant_id shape without surprise.
fn sanitize_tenant_id(tenant_id: &str) -> String {
tenant_id
.chars()
.map(|c| match c {
'/' | '\\' | '.' | '"' | '$' | ' ' | '\0' => '_',
c => c,
})
.collect()
}
#[derive(Clone, Debug)]
pub struct Database {
inner: mongodb::Database,
@@ -20,6 +206,12 @@ impl Database {
Ok(Self { inner: db })
}
/// Wrap an already-resolved Mongo database. Used by [`DatabasePool`]
/// to hand out tenant-scoped handles without a fresh client per tenant.
pub(crate) fn from_database(inner: mongodb::Database) -> Self {
Self { inner }
}
pub async fn ensure_indexes(&self) -> Result<(), AgentError> {
// repositories: unique git_url
self.repositories()
+6 -3
View File
@@ -25,10 +25,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
tracing::info!("Connecting to MongoDB...");
let db = database::Database::connect(&config.mongodb_uri, &config.mongodb_database).await?;
db.ensure_indexes().await?;
// Per-tenant pool only — the agent has no shared "default" database
// after M7.2-D. `mongodb_database` is now the db-name prefix used
// for tenant databases (`<prefix>_<tenant_id>`).
let db_pool =
database::DatabasePool::connect(&config.mongodb_uri, &config.mongodb_database).await?;
let agent = agent::ComplianceAgent::new(config.clone(), db.clone());
let agent = agent::ComplianceAgent::new(config.clone(), db_pool);
tracing::info!("Starting scheduler...");
let scheduler_agent = agent.clone();
+77 -22
View File
@@ -4,8 +4,14 @@ use tokio_cron_scheduler::{Job, JobScheduler};
use compliance_core::models::ScanTrigger;
use crate::agent::ComplianceAgent;
use crate::database::Database;
use crate::error::AgentError;
/// Default tenant the scheduler runs against when `SCHEDULER_TENANT_IDS`
/// isn't set. Matches the dev-injector default so a bare `cargo run` has
/// the scheduler scanning whatever lives in `<prefix>_dev`.
const DEFAULT_SCHEDULER_TENANT_ID: &str = "dev";
pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError> {
let sched = JobScheduler::new()
.await
@@ -18,7 +24,9 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
let agent = scan_agent.clone();
Box::pin(async move {
tracing::info!("Scheduled scan triggered");
scan_all_repos(&agent).await;
for tenant_id in scheduler_tenants() {
scan_all_repos(&agent, &tenant_id).await;
}
})
})
.map_err(|e| AgentError::Scheduler(format!("Failed to create scan job: {e}")))?;
@@ -34,7 +42,9 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
let agent = cve_agent.clone();
Box::pin(async move {
tracing::info!("CVE monitor triggered");
monitor_cves(&agent).await;
for tenant_id in scheduler_tenants() {
monitor_cves(&agent, &tenant_id).await;
}
})
})
.map_err(|e| AgentError::Scheduler(format!("Failed to create CVE monitor job: {e}")))?;
@@ -48,8 +58,9 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
.await
.map_err(|e| AgentError::Scheduler(format!("Failed to start scheduler: {e}")))?;
let tenants = scheduler_tenants();
tracing::info!(
"Scheduler started: scans='{}', CVE monitor='{}'",
"Scheduler started: scans='{}', CVE monitor='{}', tenants={tenants:?}",
agent.config.scan_schedule,
agent.config.cve_monitor_schedule,
);
@@ -60,13 +71,47 @@ pub async fn start_scheduler(agent: &ComplianceAgent) -> Result<(), AgentError>
}
}
async fn scan_all_repos(agent: &ComplianceAgent) {
/// Tenants the scheduler iterates each tick. From `SCHEDULER_TENANT_IDS`
/// (comma-separated), or `DEFAULT_SCHEDULER_TENANT_ID` if unset. M7.2-D
/// will replace this with a pull from the tenant-registry.
fn scheduler_tenants() -> Vec<String> {
std::env::var("SCHEDULER_TENANT_IDS")
.ok()
.map(|s| {
s.split(',')
.map(str::trim)
.filter(|s| !s.is_empty())
.map(String::from)
.collect::<Vec<_>>()
})
.filter(|v| !v.is_empty())
.unwrap_or_else(|| vec![DEFAULT_SCHEDULER_TENANT_ID.to_string()])
}
/// Resolve the per-tenant database. Logs and returns `None` on failure
/// so the loop in the caller can continue with other tenants.
async fn tenant_db(agent: &ComplianceAgent, tenant_id: &str) -> Option<Database> {
match agent.db_pool.for_tenant_id(tenant_id).await {
Ok(db) => Some(db),
Err(e) => {
tracing::error!("Scheduler: cannot open tenant database '{tenant_id}': {e}");
None
}
}
}
async fn scan_all_repos(agent: &ComplianceAgent, tenant_id: &str) {
use futures_util::StreamExt;
let cursor = match agent.db.repositories().find(doc! {}).await {
let db = match tenant_db(agent, tenant_id).await {
Some(db) => db,
None => return,
};
let cursor = match db.repositories().find(doc! {}).await {
Ok(c) => c,
Err(e) => {
tracing::error!("Failed to list repos for scheduled scan: {e}");
tracing::error!("Failed to list repos for tenant '{tenant_id}': {e}");
return;
}
};
@@ -75,33 +120,44 @@ async fn scan_all_repos(agent: &ComplianceAgent) {
for repo in repos {
let repo_id = repo.id.map(|id| id.to_hex()).unwrap_or_default();
if let Err(e) = agent.run_scan(&repo_id, ScanTrigger::Scheduled).await {
tracing::error!("Scheduled scan failed for {}: {e}", repo.name);
if let Err(e) = agent
.run_scan(tenant_id, &repo_id, ScanTrigger::Scheduled)
.await
{
tracing::error!(
"Scheduled scan failed for {} (tenant '{tenant_id}'): {e}",
repo.name
);
}
}
}
async fn monitor_cves(agent: &ComplianceAgent) {
async fn monitor_cves(agent: &ComplianceAgent, tenant_id: &str) {
use compliance_core::models::notification::{parse_severity, CveNotification};
use compliance_core::models::SbomEntry;
use futures_util::StreamExt;
let db = match tenant_db(agent, tenant_id).await {
Some(db) => db,
None => return,
};
// Fetch all SBOM entries grouped by repo
let cursor = match agent.db.sbom_entries().find(doc! {}).await {
let cursor = match db.sbom_entries().find(doc! {}).await {
Ok(c) => c,
Err(e) => {
tracing::error!("CVE monitor: failed to list SBOM entries: {e}");
tracing::error!("CVE monitor: failed to list SBOM entries for '{tenant_id}': {e}");
return;
}
};
let entries: Vec<SbomEntry> = cursor.filter_map(|r| async { r.ok() }).collect().await;
if entries.is_empty() {
tracing::debug!("CVE monitor: no SBOM entries, skipping");
tracing::debug!("CVE monitor: no SBOM entries for tenant '{tenant_id}', skipping");
return;
}
tracing::info!(
"CVE monitor: checking {} dependencies for new CVEs",
"CVE monitor: checking {} dependencies for new CVEs (tenant '{tenant_id}')",
entries.len()
);
@@ -112,7 +168,7 @@ async fn monitor_cves(agent: &ComplianceAgent) {
std::collections::HashMap::new();
for rid in &repo_ids {
if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(rid) {
if let Ok(Some(repo)) = agent.db.repositories().find_one(doc! { "_id": oid }).await {
if let Ok(Some(repo)) = db.repositories().find_one(doc! { "_id": oid }).await {
repo_names.insert(rid.clone(), repo.name.clone());
}
}
@@ -160,8 +216,7 @@ async fn monitor_cves(agent: &ComplianceAgent) {
for alert in &alerts {
let filter = doc! { "cve_id": &alert.cve_id, "repo_id": &alert.repo_id };
let update = doc! { "$setOnInsert": mongodb::bson::to_bson(alert).unwrap_or_default() };
let _ = agent
.db
let _ = db
.cve_alerts()
.update_one(filter, update)
.upsert(true)
@@ -174,8 +229,7 @@ async fn monitor_cves(agent: &ComplianceAgent) {
continue;
}
if let Some(entry_id) = &entry.id {
let _ = agent
.db
let _ = db
.sbom_entries()
.update_one(
doc! { "_id": entry_id },
@@ -213,8 +267,7 @@ async fn monitor_cves(agent: &ComplianceAgent) {
let update = doc! {
"$setOnInsert": mongodb::bson::to_bson(&notification).unwrap_or_default()
};
match agent
.db
match db
.cve_notifications()
.update_one(filter, update)
.upsert(true)
@@ -232,8 +285,10 @@ async fn monitor_cves(agent: &ComplianceAgent) {
}
if new_notifications > 0 {
tracing::info!("CVE monitor: created {new_notifications} new notification(s)");
tracing::info!(
"CVE monitor: created {new_notifications} new notification(s) for tenant '{tenant_id}'"
);
} else {
tracing::info!("CVE monitor: no new CVEs found");
tracing::info!("CVE monitor: no new CVEs found for tenant '{tenant_id}'");
}
}
+23 -9
View File
@@ -14,24 +14,30 @@ type HmacSha256 = Hmac<Sha256>;
pub async fn handle_gitea_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>,
Path(repo_id): Path<String>,
Path((tenant_id, repo_id)): Path<(String, String)>,
headers: HeaderMap,
body: Bytes,
) -> StatusCode {
// Look up the repo to get its webhook secret
// Look up the repo in the tenant's database to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND,
};
let repo = match agent
.db
let db = match agent.db_pool.for_tenant_id(&tenant_id).await {
Ok(db) => db,
Err(e) => {
tracing::warn!("Gitea webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories()
.find_one(mongodb::bson::doc! { "_id": oid })
.await
{
Ok(Some(repo)) => repo,
_ => {
tracing::warn!("Gitea webhook: repo {repo_id} not found");
tracing::warn!("Gitea webhook: repo {repo_id} not found in tenant '{tenant_id}'");
return StatusCode::NOT_FOUND;
}
};
@@ -66,15 +72,21 @@ pub async fn handle_gitea_webhook(
"push" => {
let agent_clone = (*agent).clone();
let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move {
tracing::info!("Gitea push webhook: triggering scan for {repo_id}");
if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
tracing::info!(
"Gitea push webhook: triggering scan for {repo_id} in tenant {tenant_id}"
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}");
}
});
StatusCode::OK
}
"pull_request" => handle_pull_request(agent, &repo_id, &payload).await,
"pull_request" => handle_pull_request(agent, &tenant_id, &repo_id, &payload).await,
_ => {
tracing::debug!("Gitea webhook: ignoring event '{event}'");
StatusCode::OK
@@ -84,6 +96,7 @@ pub async fn handle_gitea_webhook(
async fn handle_pull_request(
agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str,
payload: &serde_json::Value,
) -> StatusCode {
@@ -106,13 +119,14 @@ async fn handle_pull_request(
}
let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone();
tokio::spawn(async move {
tracing::info!("Gitea PR webhook: reviewing PR #{pr_number} on {repo_id}");
if let Err(e) = agent_clone
.run_pr_review(&repo_id, pr_number, &base_sha, &head_sha)
.run_pr_review(&tenant_id, &repo_id, pr_number, &base_sha, &head_sha)
.await
{
tracing::error!("PR review failed for #{pr_number}: {e}");
+23 -9
View File
@@ -14,24 +14,30 @@ type HmacSha256 = Hmac<Sha256>;
pub async fn handle_github_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>,
Path(repo_id): Path<String>,
Path((tenant_id, repo_id)): Path<(String, String)>,
headers: HeaderMap,
body: Bytes,
) -> StatusCode {
// Look up the repo to get its webhook secret
// Look up the repo in the tenant's database to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND,
};
let repo = match agent
.db
let db = match agent.db_pool.for_tenant_id(&tenant_id).await {
Ok(db) => db,
Err(e) => {
tracing::warn!("GitHub webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories()
.find_one(mongodb::bson::doc! { "_id": oid })
.await
{
Ok(Some(repo)) => repo,
_ => {
tracing::warn!("GitHub webhook: repo {repo_id} not found");
tracing::warn!("GitHub webhook: repo {repo_id} not found in tenant '{tenant_id}'");
return StatusCode::NOT_FOUND;
}
};
@@ -66,15 +72,21 @@ pub async fn handle_github_webhook(
"push" => {
let agent_clone = (*agent).clone();
let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move {
tracing::info!("GitHub push webhook: triggering scan for {repo_id}");
if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
tracing::info!(
"GitHub push webhook: triggering scan for {repo_id} in tenant {tenant_id}"
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}");
}
});
StatusCode::OK
}
"pull_request" => handle_pull_request(agent, &repo_id, &payload).await,
"pull_request" => handle_pull_request(agent, &tenant_id, &repo_id, &payload).await,
_ => {
tracing::debug!("GitHub webhook: ignoring event '{event}'");
StatusCode::OK
@@ -84,6 +96,7 @@ pub async fn handle_github_webhook(
async fn handle_pull_request(
agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str,
payload: &serde_json::Value,
) -> StatusCode {
@@ -105,13 +118,14 @@ async fn handle_pull_request(
}
let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone();
tokio::spawn(async move {
tracing::info!("GitHub PR webhook: reviewing PR #{pr_number} on {repo_id}");
if let Err(e) = agent_clone
.run_pr_review(&repo_id, pr_number, &base_sha, &head_sha)
.run_pr_review(&tenant_id, &repo_id, pr_number, &base_sha, &head_sha)
.await
{
tracing::error!("PR review failed for #{pr_number}: {e}");
+23 -9
View File
@@ -10,24 +10,30 @@ use crate::agent::ComplianceAgent;
pub async fn handle_gitlab_webhook(
Extension(agent): Extension<Arc<ComplianceAgent>>,
Path(repo_id): Path<String>,
Path((tenant_id, repo_id)): Path<(String, String)>,
headers: HeaderMap,
body: Bytes,
) -> StatusCode {
// Look up the repo to get its webhook secret
// Look up the repo in the tenant's database to get its webhook secret
let oid = match mongodb::bson::oid::ObjectId::parse_str(&repo_id) {
Ok(oid) => oid,
Err(_) => return StatusCode::NOT_FOUND,
};
let repo = match agent
.db
let db = match agent.db_pool.for_tenant_id(&tenant_id).await {
Ok(db) => db,
Err(e) => {
tracing::warn!("GitLab webhook: cannot open tenant database '{tenant_id}': {e}");
return StatusCode::NOT_FOUND;
}
};
let repo = match db
.repositories()
.find_one(mongodb::bson::doc! { "_id": oid })
.await
{
Ok(Some(repo)) => repo,
_ => {
tracing::warn!("GitLab webhook: repo {repo_id} not found");
tracing::warn!("GitLab webhook: repo {repo_id} not found in tenant '{tenant_id}'");
return StatusCode::NOT_FOUND;
}
};
@@ -59,15 +65,21 @@ pub async fn handle_gitlab_webhook(
"push" => {
let agent_clone = (*agent).clone();
let repo_id = repo_id.clone();
let tenant_id = tenant_id.clone();
tokio::spawn(async move {
tracing::info!("GitLab push webhook: triggering scan for {repo_id}");
if let Err(e) = agent_clone.run_scan(&repo_id, ScanTrigger::Webhook).await {
tracing::info!(
"GitLab push webhook: triggering scan for {repo_id} in tenant {tenant_id}"
);
if let Err(e) = agent_clone
.run_scan(&tenant_id, &repo_id, ScanTrigger::Webhook)
.await
{
tracing::error!("Webhook-triggered scan failed: {e}");
}
});
StatusCode::OK
}
"merge_request" => handle_merge_request(agent, &repo_id, &payload).await,
"merge_request" => handle_merge_request(agent, &tenant_id, &repo_id, &payload).await,
_ => {
tracing::debug!("GitLab webhook: ignoring event '{event_type}'");
StatusCode::OK
@@ -77,6 +89,7 @@ pub async fn handle_gitlab_webhook(
async fn handle_merge_request(
agent: Arc<ComplianceAgent>,
tenant_id: &str,
repo_id: &str,
payload: &serde_json::Value,
) -> StatusCode {
@@ -101,13 +114,14 @@ async fn handle_merge_request(
}
let repo_id = repo_id.to_string();
let tenant_id = tenant_id.to_string();
let head_sha = head_sha.to_string();
let base_sha = base_sha.to_string();
let agent_clone = (*agent).clone();
tokio::spawn(async move {
tracing::info!("GitLab MR webhook: reviewing MR !{mr_iid} on {repo_id}");
if let Err(e) = agent_clone
.run_pr_review(&repo_id, mr_iid, &base_sha, &head_sha)
.run_pr_review(&tenant_id, &repo_id, mr_iid, &base_sha, &head_sha)
.await
{
tracing::error!("MR review failed for !{mr_iid}: {e}");
+8 -4
View File
@@ -9,17 +9,21 @@ use crate::webhooks::{gitea, github, gitlab};
pub async fn start_webhook_server(agent: &ComplianceAgent) -> Result<(), AgentError> {
let app = Router::new()
// Per-repo webhook URLs: /webhook/{platform}/{repo_id}
// Per-tenant per-repo webhook URLs: /webhook/{tenant_id}/{platform}/{repo_id}
// The tenant_id is resolved from the URL path because webhooks
// arrive without a JWT — they're authenticated via per-repo HMAC,
// not via the tenant gate. The dashboard surfaces the full URL
// including the tenant_id when the repo is registered.
.route(
"/webhook/github/{repo_id}",
"/webhook/{tenant_id}/github/{repo_id}",
post(github::handle_github_webhook),
)
.route(
"/webhook/gitlab/{repo_id}",
"/webhook/{tenant_id}/gitlab/{repo_id}",
post(gitlab::handle_gitlab_webhook),
)
.route(
"/webhook/gitea/{repo_id}",
"/webhook/{tenant_id}/gitea/{repo_id}",
post(gitea::handle_gitea_webhook),
)
.layer(Extension(Arc::new(agent.clone())));
+20 -8
View File
@@ -7,7 +7,7 @@ use std::sync::Arc;
use compliance_agent::agent::ComplianceAgent;
use compliance_agent::api;
use compliance_agent::database::Database;
use compliance_agent::database::DatabasePool;
use compliance_core::AgentConfig;
use secrecy::SecretString;
@@ -28,10 +28,9 @@ impl TestServer {
// Unique database name per test run to avoid collisions
let db_name = format!("test_{}", uuid::Uuid::new_v4().simple());
let db = Database::connect(&mongodb_uri, &db_name)
let db_pool = DatabasePool::connect(&mongodb_uri, &db_name)
.await
.expect("Failed to connect to MongoDB — is it running?");
db.ensure_indexes().await.expect("Failed to create indexes");
.expect("Failed to build DatabasePool");
let config = AgentConfig {
mongodb_uri: mongodb_uri.clone(),
@@ -69,11 +68,15 @@ impl TestServer {
pentest_imap_password: None,
};
let agent = ComplianceAgent::new(config, db);
let agent = ComplianceAgent::new(config, db_pool);
// Build the router with the agent extension
// Build the router with the agent extension. After M7.2-B every
// handler takes a TenantCtx extractor; without KC in the test
// harness, the dev-tenant injector mounts a synthetic context so
// tests run end-to-end against `<db_name>_dev`.
let app = api::routes::build_router()
.layer(axum::extract::Extension(Arc::new(agent)))
.layer(axum::middleware::from_fn(api::server::inject_dev_tenant))
.layer(tower_http::cors::CorsLayer::permissive());
// Bind to port 0 to get a random available port
@@ -156,10 +159,19 @@ impl TestServer {
&self.db_name
}
/// Drop the test database on cleanup
/// Drop every per-tenant database belonging to this test run.
/// Post-M7.2-D the agent never opens a `db_name` directly —
/// data lives only in `<db_name>_<tenant>` per-tenant databases.
pub async fn cleanup(&self) {
if let Ok(client) = mongodb::Client::with_uri_str(&self.mongodb_uri).await {
client.database(&self.db_name).drop().await.ok();
if let Ok(names) = client.list_database_names().await {
let prefix = format!("{}_", self.db_name);
for name in names {
if name.starts_with(&prefix) {
client.database(&name).drop().await.ok();
}
}
}
}
}
}
+298
View File
@@ -0,0 +1,298 @@
//! M7.2-A — `DatabasePool` isolation proof.
//!
//! Two `TenantContext`s, two databases, one client. Insert on A, query
//! on B → empty. Insert on B, query on A → only A's docs. Proves that
//! the per-tenant database split actually isolates at the driver level
//! and not at "we hope we filter."
//!
//! Requires MongoDB. Set `TEST_MONGODB_URI` to override the default
//! `mongodb://root:example@localhost:27017/?authSource=admin`.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use compliance_agent::database::DatabasePool;
use compliance_core::models::TrackedRepository;
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use mongodb::bson::doc;
fn ctx(tenant_id: &str, slug: &str) -> TenantContext {
TenantContext {
tenant_id: tenant_id.to_string(),
tenant_slug: slug.to_string(),
org_roles: vec![OrgRole::ItAdmin],
products: vec!["compliance-scanner".to_string()],
plan: "starter".to_string(),
status: TenantStatus::Active,
user_id: "u-1".to_string(),
user_name: None,
}
}
fn fixture_repo(name: &str, git_url: &str) -> TrackedRepository {
TrackedRepository {
id: None,
name: name.to_string(),
git_url: git_url.to_string(),
default_branch: "main".to_string(),
local_path: None,
scan_schedule: None,
webhook_enabled: false,
webhook_secret: None,
tracker_type: None,
tracker_owner: None,
tracker_repo: None,
tracker_token: None,
auth_token: None,
auth_username: None,
last_scanned_commit: None,
findings_count: 0,
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
}
}
#[tokio::test]
async fn pool_isolates_tenants_at_driver_level() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// Unique per run so parallel test invocations don't collide. Kept
// short because Mongo caps db names at 63 bytes (prefix + tenant_id).
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix)
.await
.expect("Failed to connect to MongoDB — is it running?");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
let globex = ctx("00000000-0000-0000-0000-0000globex000", "globex");
let acme_db = pool.for_tenant(&acme).await.expect("acme db");
let globex_db = pool.for_tenant(&globex).await.expect("globex db");
// Write distinct repos into each tenant's database.
acme_db
.repositories()
.insert_one(fixture_repo("acme-app", "git@example.com:acme/app.git"))
.await
.expect("insert acme");
globex_db
.repositories()
.insert_one(fixture_repo(
"globex-platform",
"git@example.com:globex/platform.git",
))
.await
.expect("insert globex");
// The point of the whole exercise: acme can ONLY see acme's repo
// and globex can ONLY see globex's, with no filter doc anywhere
// because the isolation is at the database handle, not in the query.
let acme_seen = collect(&acme_db).await;
let globex_seen = collect(&globex_db).await;
assert_eq!(acme_seen.len(), 1, "acme should see exactly its own repo");
assert_eq!(acme_seen[0].name, "acme-app");
assert_eq!(
globex_seen.len(),
1,
"globex should see exactly its own repo"
);
assert_eq!(globex_seen[0].name, "globex-platform");
// Sanity: the two databases really are different by name.
let acme_db_name = pool.tenant_db_name(&acme.tenant_id);
let globex_db_name = pool.tenant_db_name(&globex.tenant_id);
assert_ne!(acme_db_name, globex_db_name);
assert!(acme_db_name.starts_with(&prefix));
// Cleanup — drop both per-tenant databases.
pool.client()
.database(&acme_db_name)
.drop()
.await
.expect("drop acme");
pool.client()
.database(&globex_db_name)
.drop()
.await
.expect("drop globex");
}
#[tokio::test]
async fn for_tenant_is_idempotent_index_creation() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let prefix = format!("m72a_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix).await.expect("connect");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
// Second call must not fail (ensure_indexes already ran, in-memory
// marker is set, Mongo's createIndex is idempotent by name anyway).
let _ = pool.for_tenant(&acme).await.expect("first call");
let _ = pool.for_tenant(&acme).await.expect("second call");
let _ = pool.for_tenant(&acme).await.expect("third call");
// Cleanup
let db_name = pool.tenant_db_name(&acme.tenant_id);
pool.client().database(&db_name).drop().await.expect("drop");
}
#[tokio::test]
async fn tenant_db_name_sanitizes_unsafe_characters() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_sanitize")
.await
.expect("connect");
// Mongo db names cannot contain `/ \ . " $ <space> NUL`. The pool
// must rewrite these without exploding on connect.
let funky = "te/n.a\\nt$id\" with spaces";
let name = pool.tenant_db_name(funky);
for c in ['/', '\\', '.', '"', '$', ' '] {
assert!(
!name.contains(c),
"sanitized db name still contains {c:?}: {name}"
);
}
}
#[tokio::test]
async fn admin_helpers_list_and_drop_tenant_dbs() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let prefix = format!("m72d_{}", short_id());
let pool = DatabasePool::connect(&uri, &prefix).await.expect("connect");
let acme = ctx("00000000-0000-0000-0000-00000000acme", "acme");
let globex = ctx("00000000-0000-0000-0000-0000globex000", "globex");
// Provision two tenants and write a doc into each so the databases
// actually materialize on the cluster (Mongo lazily creates DBs).
let acme_db = pool.for_tenant(&acme).await.expect("acme db");
let globex_db = pool.for_tenant(&globex).await.expect("globex db");
acme_db
.repositories()
.insert_one(fixture_repo("acme-app", "git@example.com:acme/app.git"))
.await
.expect("insert acme");
globex_db
.repositories()
.insert_one(fixture_repo("globex-app", "git@example.com:globex/app.git"))
.await
.expect("insert globex");
// list_tenant_db_names sees both, filtered by prefix
let names = pool.list_tenant_db_names().await.expect("list tenants");
let acme_name = pool.tenant_db_name(&acme.tenant_id);
let globex_name = pool.tenant_db_name(&globex.tenant_id);
assert!(
names.contains(&acme_name),
"expected {acme_name} in {names:?}"
);
assert!(
names.contains(&globex_name),
"expected {globex_name} in {names:?}"
);
for name in &names {
assert!(name.starts_with(&format!("{prefix}_")));
}
// drop_tenant removes acme's DB
pool.drop_tenant(&acme.tenant_id)
.await
.expect("drop acme tenant");
let after = pool
.list_tenant_db_names()
.await
.expect("list tenants after drop");
assert!(
!after.contains(&acme_name),
"acme should be gone after drop, got {after:?}"
);
assert!(
after.contains(&globex_name),
"globex should still be present, got {after:?}"
);
// Cleanup remaining
pool.drop_tenant(&globex.tenant_id)
.await
.expect("drop globex tenant");
}
#[tokio::test]
async fn tenant_db_name_falls_back_to_hash_when_too_long() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
let pool = DatabasePool::connect(&uri, "m72a_long")
.await
.expect("connect");
// 100-byte tenant_id would overflow the 63-byte db-name cap with
// any reasonable prefix. The pool must hash it down.
let huge = "x".repeat(100);
let name = pool.tenant_db_name(&huge);
assert!(name.len() <= 63, "hashed name should fit: {name}");
assert!(name.starts_with("m72a_long_"));
// The hash suffix is 32 hex chars (16-byte SHA-256 truncation).
let suffix = name.trim_start_matches("m72a_long_");
assert_eq!(
suffix.len(),
32,
"expected 32-hex suffix (16-byte hash), got {suffix:?}"
);
assert!(suffix.chars().all(|c| c.is_ascii_hexdigit()));
// Stable: same input → same output.
assert_eq!(name, pool.tenant_db_name(&huge));
// Different inputs → different outputs (collision check on a tiny
// sample — full birthday-resistance is a proof not a test).
let huge2 = "y".repeat(100);
assert_ne!(pool.tenant_db_name(&huge), pool.tenant_db_name(&huge2));
}
#[tokio::test]
async fn connect_rejects_overlong_db_prefix() {
let uri = std::env::var("TEST_MONGODB_URI")
.unwrap_or_else(|_| "mongodb://root:example@localhost:27017/?authSource=admin".into());
// MAX_PREFIX_LEN is 30 (= 63 - 1 - 32). A 31-char prefix MUST be
// rejected at construction so the hash-fallback path can never
// produce an over-long db name at runtime.
let too_long = "a".repeat(31);
let err = DatabasePool::connect(&uri, &too_long).await.unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("max is 30") || msg.contains(&too_long),
"error should explain the cap: {msg}"
);
// Exactly 30 chars is the inclusive bound — must succeed.
let just_right = "a".repeat(30);
let _ = DatabasePool::connect(&uri, &just_right)
.await
.expect("30-char prefix should be accepted");
}
/// Short UUID slug for keeping test prefixes well under Mongo's 63-byte
/// db-name cap.
fn short_id() -> String {
uuid::Uuid::new_v4().simple().to_string()[..8].to_string()
}
/// Drain a `repositories` find cursor on the given tenant database.
async fn collect(db: &compliance_agent::database::Database) -> Vec<TrackedRepository> {
let mut cursor = db
.repositories()
.find(doc! {})
.await
.expect("find repositories");
let mut out = Vec::new();
while cursor.advance().await.expect("advance") {
out.push(cursor.deserialize_current().expect("deserialize"));
}
out
}
@@ -1,4 +1,4 @@
//! M7.1 — integration tests for `require_tenant_status`.
//! M7.1 — integration tests for `compliance_core::auth::require_tenant_status`.
//!
//! Exercises the middleware end-to-end through an Axum router so we
//! catch wiring bugs (extension propagation, method matching) that pure
@@ -15,8 +15,7 @@ use axum::{
routing::{get, post},
Router,
};
use compliance_agent::api::auth_middleware::require_tenant_status;
use compliance_core::{TenantContext, TenantStatus};
use compliance_core::{auth::require_tenant_status, TenantContext, TenantStatus};
use tower::ServiceExt;
fn ctx_with(status: TenantStatus) -> TenantContext {
+13
View File
@@ -18,6 +18,15 @@ telemetry = [
"dep:tracing-subscriber",
"dep:tracing",
]
# Pulls in the M7.1 Axum middleware + extractor. Consumers that don't
# embed an HTTP server (e.g. the wasm dashboard frontend) leave it off.
axum = [
"dep:axum",
"dep:jsonwebtoken",
"dep:reqwest",
"dep:tokio",
"dep:tracing",
]
[dependencies]
serde = { workspace = true }
@@ -37,3 +46,7 @@ opentelemetry-appender-tracing = { version = "0.29", optional = true }
tracing-opentelemetry = { version = "0.30", optional = true }
tracing-subscriber = { workspace = true, optional = true }
tracing = { workspace = true, optional = true }
axum = { version = "0.8", optional = true }
jsonwebtoken = { version = "9", optional = true }
reqwest = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
@@ -1,9 +1,9 @@
//! M7.1 — JWT validation + tenant context propagation.
//!
//! `require_jwt_auth` validates a Bearer JWT against Keycloak's JWKS and
//! attaches a `TenantContext` to the request extensions. Downstream
//! middleware (`require_tenant_status`) and Axum extractors (`TenantCtx`)
//! read it from there.
//! attaches a [`TenantContext`] to the request extensions. Downstream
//! middleware ([`require_tenant_status`]) and Axum extractors
//! ([`crate::tenant_ctx::TenantCtx`]) read it from there.
//!
//! Skipped paths:
//! * `/api/v1/health` — Kubernetes liveness; never authenticated.
@@ -23,12 +23,13 @@ use axum::{
middleware::Next,
response::{IntoResponse, Response},
};
use compliance_core::{OrgRole, TenantContext, TenantStatus};
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, DecodingKey, Validation};
use reqwest::StatusCode;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::{OrgRole, TenantContext, TenantStatus};
/// Cached JWKS from Keycloak for token validation.
#[derive(Clone)]
pub struct JwksState {
@@ -147,27 +148,83 @@ async fn validate_token(token: &str, state: &JwksState) -> Result<TenantContext,
let kid = header
.kid
.clone()
.ok_or_else(|| "JWT missing kid header".to_string())?;
let jwks = fetch_or_get_jwks(state).await?;
// First try against whatever's currently cached. If the kid isn't
// there or the signature doesn't verify, the cached JWKS is most
// likely stale (KC rotated keys) — refresh once and retry before
// giving up. Without this every key rotation produces a silent 401
// storm that only goes away when the agent restarts.
let jwks = fetch_or_get_jwks(state, false).await?;
match try_validate(token, &header, &kid, &jwks) {
Ok(ctx) => Ok(ctx),
Err(ValidationError::Permanent(e)) => Err(e),
Err(ValidationError::Stale(reason)) => {
tracing::info!(
kid = %kid,
reason = %reason,
"JWKS appears stale — forcing refresh and retrying"
);
let jwks = fetch_or_get_jwks(state, true).await?;
try_validate(token, &header, &kid, &jwks).map_err(|e| match e {
ValidationError::Stale(s) | ValidationError::Permanent(s) => s,
})
}
}
}
let jwk = jwks
#[derive(Debug)]
enum ValidationError {
/// Refresh-eligible: cached JWKS may be stale.
Stale(String),
/// Refusing the token regardless of JWKS freshness.
Permanent(String),
}
fn try_validate(
token: &str,
header: &jsonwebtoken::Header,
kid: &str,
jwks: &JwkSet,
) -> Result<TenantContext, ValidationError> {
let jwk = match jwks
.keys
.iter()
.find(|k| k.common.key_id.as_deref() == Some(&kid))
.ok_or_else(|| "no matching key found in JWKS".to_string())?;
.find(|k| k.common.key_id.as_deref() == Some(kid))
{
Some(j) => j,
None => {
return Err(ValidationError::Stale(
"no matching key found in JWKS".to_string(),
))
}
};
let decoding_key =
DecodingKey::from_jwk(jwk).map_err(|e| format!("failed to create decoding key: {e}"))?;
let decoding_key = DecodingKey::from_jwk(jwk)
.map_err(|e| ValidationError::Permanent(format!("failed to create decoding key: {e}")))?;
let mut validation = Validation::new(header.alg);
validation.validate_exp = true;
validation.validate_aud = false;
let data = decode::<Claims>(token, &decoding_key, &validation)
.map_err(|e| format!("token validation failed: {e}"))?;
let data = match decode::<Claims>(token, &decoding_key, &validation) {
Ok(d) => d,
Err(e) => {
// Signature mismatch is the other refresh-eligible failure:
// the matching kid is present but the key bytes don't match.
// Everything else (expired, malformed, etc.) is permanent.
return Err(
if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::InvalidSignature) {
ValidationError::Stale(format!("token validation failed: {e}"))
} else {
ValidationError::Permanent(format!("token validation failed: {e}"))
},
);
}
};
claims_to_context(data.claims)
claims_to_context(data.claims).map_err(ValidationError::Permanent)
}
/// Map the decoded JWT payload into the platform-wide `TenantContext`.
@@ -197,14 +254,25 @@ fn claims_to_context(c: Claims) -> Result<TenantContext, String> {
})
}
async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
{
async fn fetch_or_get_jwks(state: &JwksState, force: bool) -> Result<JwkSet, String> {
if !force {
let cached = state.jwks.read().await;
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
// Hold the write lock across the fetch so concurrent refreshers
// don't all hammer Keycloak when keys rotate. If another writer
// already populated a fresh JWKS while we were waiting (and we
// weren't asked to force), use theirs.
let mut cached = state.jwks.write().await;
if !force {
if let Some(ref jwks) = *cached {
return Ok(jwks.clone());
}
}
let resp = reqwest::get(&state.jwks_url)
.await
.map_err(|e| format!("failed to fetch JWKS: {e}"))?;
@@ -214,7 +282,6 @@ async fn fetch_or_get_jwks(state: &JwksState) -> Result<JwkSet, String> {
.await
.map_err(|e| format!("failed to parse JWKS: {e}"))?;
let mut cached = state.jwks.write().await;
*cached = Some(jwks.clone());
Ok(jwks)
@@ -292,6 +359,24 @@ mod tests {
);
}
#[test]
fn try_validate_returns_stale_when_kid_missing_from_jwks() {
// Empty JWKS — the kid we ask for can't possibly match. The error
// must classify as Stale so the caller refreshes JWKS and retries.
let jwks = JwkSet { keys: vec![] };
let header = jsonwebtoken::Header {
alg: jsonwebtoken::Algorithm::RS256,
kid: Some("kid-rotated-out".to_string()),
..Default::default()
};
let err = try_validate("ignored.token.value", &header, "kid-rotated-out", &jwks)
.expect_err("should fail");
match err {
ValidationError::Stale(s) => assert!(s.contains("no matching key")),
ValidationError::Permanent(s) => panic!("must be Stale, got Permanent: {s}"),
}
}
#[test]
fn is_write_detects_methods() {
assert!(!is_write(&Method::GET));
+5
View File
@@ -7,6 +7,11 @@ pub mod telemetry;
pub mod tenant;
pub mod traits;
#[cfg(feature = "axum")]
pub mod auth;
#[cfg(feature = "axum")]
pub mod tenant_ctx;
pub use config::{AgentConfig, DashboardConfig};
pub use error::CoreError;
pub use tenant::{OrgRole, TenantContext, TenantStatus};
+6 -6
View File
@@ -1,9 +1,9 @@
//! Tenant context propagated through every authenticated request.
//!
//! This module is the M7.1 single source of truth for "who is this request
//! for". Claims come from a Keycloak-issued JWT and land here via
//! `compliance-agent`'s `require_jwt_auth` middleware. Handlers reach into
//! the request extensions with the `TenantCtx` Axum extractor.
//! M7.1 single source of truth for "who is this request for". Claims come
//! from a Keycloak-issued JWT and land here via [`crate::auth::require_jwt_auth`]
//! (enabled with the `axum` feature). Handlers reach into the request
//! extensions with the [`crate::tenant_ctx::TenantCtx`] extractor.
//!
//! The shape mirrors the JWT claim names the breakpilot-platform realm
//! emits (see `platform/orca-platform/dev/keycloak/realm-export.json`).
@@ -87,8 +87,8 @@ impl OrgRole {
}
}
/// Everything `compliance-agent` knows about the requesting tenant at the
/// moment a request lands. Cheap to clone (every field is owned + small).
/// Everything we know about the requesting tenant at the moment a request
/// lands. Cheap to clone (every field is owned + small).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TenantContext {
/// `tenants.id` from the platform's tenant-registry (UUID).
@@ -9,17 +9,18 @@
//! }
//! ```
//!
//! The middleware (`require_jwt_auth`) is responsible for inserting the
//! context into the request extensions. If it's missing on a route that
//! uses this extractor, that's a bug in the wiring — we return 401 so the
//! caller sees an auth failure rather than a 500.
//! The middleware ([`crate::auth::require_jwt_auth`]) is responsible for
//! inserting the context into the request extensions. If it's missing on
//! a route that uses this extractor, that's a bug in the wiring — we
//! return 401 so the caller sees an auth failure rather than a 500.
use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
};
use compliance_core::TenantContext;
use crate::TenantContext;
#[derive(Debug, Clone)]
pub struct TenantCtx(pub TenantContext);
@@ -57,8 +58,8 @@ where
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
use super::*;
use crate::TenantStatus;
use axum::http::Request;
use compliance_core::TenantStatus;
fn ctx() -> TenantContext {
TenantContext {
@@ -0,0 +1,210 @@
//! Authenticated HTTP client for talking to the compliance-agent.
//!
//! Every dashboard server function that hits `comp-dev.meghsakha.com/api/v1/*`
//! must go through here so the Keycloak access token from the user's
//! session is attached as `Authorization: Bearer <token>`. Without it
//! the agent's M7.1 `require_jwt_auth` middleware rejects with 401
//! "Missing authorization header".
//!
//! When Keycloak is not configured (dev convenience), the helper
//! returns an unauthenticated builder — matching the agent's
//! pass-through behavior in the same state.
//!
//! **Token refresh**: KC access tokens are short-lived (5 min default
//! in the certifai realm). Before attaching, we decode the JWT's `exp`
//! claim and proactively refresh via the stored refresh_token if the
//! access token is expired or about to expire. The session is updated
//! with the new pair. If refresh fails, we send the (stale) token
//! anyway — the agent's 401 will surface to the UI, which can prompt
//! re-login.
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use dioxus::prelude::ServerFnError;
use dioxus_fullstack::FullstackContext;
use reqwest::Method;
use super::auth::LOGGED_IN_USER_SESS_KEY;
use super::server_state::ServerState;
use super::user_state::UserStateInner;
/// Seconds before the JWT's `exp` time at which we consider it stale
/// enough to refresh. Covers clock skew + the round-trip to the agent
/// so the token doesn't expire mid-flight.
const REFRESH_SKEW_SECS: i64 = 30;
/// Build a `RequestBuilder` for `<agent_api_url><path>` with the
/// session's access token attached. `path` should include a leading
/// `/`, e.g. `"/api/v1/repositories"`.
pub async fn agent_request(
method: Method,
path: &str,
) -> Result<reqwest::RequestBuilder, ServerFnError> {
let state: ServerState = FullstackContext::extract().await?;
let url = format!("{}{}", state.agent_api_url, path);
let mut req = reqwest::Client::new().request(method, &url);
req = attach_token(req, &state).await?;
Ok(req)
}
/// Same as [`agent_request`] but for `GET`. Convenience for the common case.
pub async fn agent_get(path: &str) -> Result<reqwest::RequestBuilder, ServerFnError> {
agent_request(Method::GET, path).await
}
/// Attach the session's bearer token if Keycloak is configured AND the
/// session has a logged-in user. Refresh the token proactively if it's
/// expired or about to expire. Persists refreshed tokens back into the
/// session.
async fn attach_token(
req: reqwest::RequestBuilder,
state: &ServerState,
) -> Result<reqwest::RequestBuilder, ServerFnError> {
if state.keycloak.is_none() {
return Ok(req);
}
let session: tower_sessions::Session = FullstackContext::extract().await?;
let user: Option<UserStateInner> = session
.get(LOGGED_IN_USER_SESS_KEY)
.await
.map_err(|e| ServerFnError::new(format!("session read failed: {e}")))?;
let Some(mut user) = user else {
return Ok(req);
};
if token_needs_refresh(&user.access_token) {
tracing::debug!("Access token expired or near-expiring; refreshing");
match refresh_tokens(state, &user.refresh_token).await {
Ok((new_access, new_refresh)) => {
user.access_token = new_access;
if let Some(rt) = new_refresh {
user.refresh_token = rt;
}
if let Err(e) = session.insert(LOGGED_IN_USER_SESS_KEY, &user).await {
tracing::warn!("Failed to persist refreshed tokens: {e}");
}
}
Err(e) => {
tracing::warn!("Token refresh failed: {e}; sending current token anyway");
// Fall through — the agent will 401 and the UI will
// prompt re-login. Better than failing the request at
// the dashboard layer with no helpful UX cue.
}
}
}
Ok(req.bearer_auth(user.access_token))
}
/// Decode the JWT's payload (no signature verification — the agent
/// does that) and check the `exp` claim. Treats malformed tokens as
/// expired so the refresh path runs.
fn token_needs_refresh(jwt: &str) -> bool {
let Some(payload_b64) = jwt.split('.').nth(1) else {
return true;
};
let Ok(bytes) = URL_SAFE_NO_PAD.decode(payload_b64) else {
return true;
};
#[derive(serde::Deserialize)]
struct ExpClaim {
exp: i64,
}
let Ok(claims) = serde_json::from_slice::<ExpClaim>(&bytes) else {
return true;
};
let now = chrono::Utc::now().timestamp();
claims.exp - REFRESH_SKEW_SECS <= now
}
/// Exchange a refresh_token for a new access_token. Returns the new
/// access_token and (optionally) the new refresh_token KC issued.
/// KC may rotate refresh_tokens on use; we honor whatever it sends.
async fn refresh_tokens(
state: &ServerState,
refresh_token: &str,
) -> Result<(String, Option<String>), String> {
let kc = state
.keycloak
.ok_or_else(|| "Keycloak not configured".to_string())?;
if refresh_token.is_empty() {
return Err("no refresh_token in session".to_string());
}
#[derive(serde::Deserialize)]
struct TokenResp {
access_token: String,
refresh_token: Option<String>,
}
let resp = reqwest::Client::new()
.post(kc.token_endpoint())
.form(&[
("grant_type", "refresh_token"),
("client_id", kc.client_id.as_str()),
("refresh_token", refresh_token),
])
.send()
.await
.map_err(|e| format!("refresh request failed: {e}"))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(format!("refresh rejected ({status}): {body}"));
}
let r: TokenResp = resp
.json()
.await
.map_err(|e| format!("refresh response parse failed: {e}"))?;
Ok((r.access_token, r.refresh_token))
}
#[cfg(test)]
mod tests {
use super::*;
use base64::Engine;
/// Build a JWT-shaped string (header.payload.sig) with the given
/// payload. Signature is bogus — we never verify it locally.
fn make_jwt(payload: &serde_json::Value) -> String {
let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(payload).unwrap());
format!("hdr.{payload_b64}.sig")
}
#[test]
fn token_needs_refresh_true_when_expired() {
let exp = chrono::Utc::now().timestamp() - 60;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_true_within_skew_window() {
// 10 seconds left; less than the 30s skew → must refresh.
let exp = chrono::Utc::now().timestamp() + 10;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_false_with_plenty_of_life() {
let exp = chrono::Utc::now().timestamp() + 600;
let jwt = make_jwt(&serde_json::json!({ "exp": exp }));
assert!(!token_needs_refresh(&jwt));
}
#[test]
fn token_needs_refresh_true_on_malformed_jwt() {
assert!(token_needs_refresh(""));
assert!(token_needs_refresh("not.a.jwt"));
assert!(token_needs_refresh("only-one-segment"));
assert!(token_needs_refresh("hdr.not-base64!.sig"));
}
#[test]
fn token_needs_refresh_true_when_exp_missing() {
let jwt = make_jwt(&serde_json::json!({ "sub": "abc" }));
assert!(token_needs_refresh(&jwt));
}
}
+26 -35
View File
@@ -61,23 +61,21 @@ pub async fn send_chat_message(
message: String,
history: Vec<ChatHistoryMessage>,
) -> Result<ChatApiResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/chat/{repo_id}", state.agent_api_url);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = client
.post(&url)
.json(&serde_json::json!({
"message": message,
"history": history,
}))
.send()
.await
.map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?;
// Chat uses a longer timeout because the LLM round-trip can be slow;
// agent_request doesn't expose a per-call timeout so we layer one on.
let resp = super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/chat/{repo_id}"),
)
.await?
.timeout(std::time::Duration::from_secs(120))
.json(&serde_json::json!({
"message": message,
"history": history,
}))
.send()
.await
.map_err(|e| ServerFnError::new(format!("Request failed: {e}")))?;
let text = resp
.text()
@@ -91,19 +89,14 @@ pub async fn send_chat_message(
#[server]
pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/chat/{repo_id}/build-embeddings",
state.agent_api_url
);
let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/chat/{repo_id}/build-embeddings"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -111,11 +104,9 @@ pub async fn trigger_embedding_build(repo_id: String) -> Result<(), ServerFnErro
pub async fn fetch_embedding_status(
repo_id: String,
) -> Result<EmbeddingStatusResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/chat/{repo_id}/status", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/chat/{repo_id}/status"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: EmbeddingStatusResponse = resp
+22 -34
View File
@@ -26,10 +26,9 @@ pub struct DastFindingDetailResponse {
#[server]
pub async fn fetch_dast_targets() -> Result<DastTargetsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/targets", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/dast/targets")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastTargetsResponse = resp
@@ -41,10 +40,9 @@ pub async fn fetch_dast_targets() -> Result<DastTargetsResponse, ServerFnError>
#[server]
pub async fn fetch_dast_scan_runs() -> Result<DastScanRunsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/scan-runs", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/dast/scan-runs")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastScanRunsResponse = resp
@@ -56,10 +54,9 @@ pub async fn fetch_dast_scan_runs() -> Result<DastScanRunsResponse, ServerFnErro
#[server]
pub async fn fetch_dast_findings() -> Result<DastFindingsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/findings", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/dast/findings")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingsResponse = resp
@@ -73,10 +70,9 @@ pub async fn fetch_dast_findings() -> Result<DastFindingsResponse, ServerFnError
pub async fn fetch_dast_finding_detail(
id: String,
) -> Result<DastFindingDetailResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/findings/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/dast/findings/{id}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingDetailResponse = resp
@@ -88,12 +84,8 @@ pub async fn fetch_dast_finding_detail(
#[server]
pub async fn add_dast_target(name: String, base_url: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/dast/targets", state.agent_api_url);
let client = reqwest::Client::new();
client
.post(&url)
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/dast/targets")
.await?
.json(&serde_json::json!({
"name": name,
"base_url": base_url,
@@ -106,17 +98,13 @@ pub async fn add_dast_target(name: String, base_url: String) -> Result<(), Serve
#[server]
pub async fn trigger_dast_scan(target_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/dast/targets/{target_id}/scan",
state.agent_api_url
);
let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/dast/targets/{target_id}/scan"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -24,39 +24,35 @@ pub struct FindingsQuery {
#[server]
pub async fn fetch_findings(query: FindingsQuery) -> Result<FindingsListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut url = format!(
"{}/api/v1/findings?page={}&limit=20",
state.agent_api_url, query.page
);
let mut path = format!("/api/v1/findings?page={}&limit=20", query.page);
if !query.severity.is_empty() {
url.push_str(&format!("&severity={}", query.severity));
path.push_str(&format!("&severity={}", query.severity));
}
if !query.scan_type.is_empty() {
url.push_str(&format!("&scan_type={}", query.scan_type));
path.push_str(&format!("&scan_type={}", query.scan_type));
}
if !query.status.is_empty() {
url.push_str(&format!("&status={}", query.status));
path.push_str(&format!("&status={}", query.status));
}
if !query.repo_id.is_empty() {
url.push_str(&format!("&repo_id={}", query.repo_id));
path.push_str(&format!("&repo_id={}", query.repo_id));
}
if !query.q.is_empty() {
url.push_str(&format!(
path.push_str(&format!(
"&q={}",
url::form_urlencoded::byte_serialize(query.q.as_bytes()).collect::<String>()
));
}
if !query.sort_by.is_empty() {
url.push_str(&format!("&sort_by={}", query.sort_by));
path.push_str(&format!("&sort_by={}", query.sort_by));
}
if !query.sort_order.is_empty() {
url.push_str(&format!("&sort_order={}", query.sort_order));
path.push_str(&format!("&sort_order={}", query.sort_order));
}
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: FindingsListResponse = resp
@@ -68,11 +64,9 @@ pub async fn fetch_findings(query: FindingsQuery) -> Result<FindingsListResponse
#[server]
pub async fn fetch_finding_detail(id: String) -> Result<Finding, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/findings/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/findings/{id}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp
@@ -86,18 +80,15 @@ pub async fn fetch_finding_detail(id: String) -> Result<Finding, ServerFnError>
#[server]
pub async fn update_finding_status(id: String, status: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/findings/{id}/status", state.agent_api_url);
let client = reqwest::Client::new();
client
.patch(&url)
.json(&serde_json::json!({ "status": status }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::PATCH,
&format!("/api/v1/findings/{id}/status"),
)
.await?
.json(&serde_json::json!({ "status": status }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -106,34 +97,25 @@ pub async fn bulk_update_finding_status(
ids: Vec<String>,
status: String,
) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/findings/bulk-status", state.agent_api_url);
let client = reqwest::Client::new();
client
.patch(&url)
super::agent_client::agent_request(reqwest::Method::PATCH, "/api/v1/findings/bulk-status")
.await?
.json(&serde_json::json!({ "ids": ids, "status": status }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
#[server]
pub async fn update_finding_feedback(id: String, feedback: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/findings/{id}/feedback", state.agent_api_url);
let client = reqwest::Client::new();
client
.patch(&url)
.json(&serde_json::json!({ "feedback": feedback }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::PATCH,
&format!("/api/v1/findings/{id}/feedback"),
)
.await?
.json(&serde_json::json!({ "feedback": feedback }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -50,10 +50,9 @@ pub struct SearchResponse {
#[server]
pub async fn fetch_graph(repo_id: String) -> Result<GraphDataResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/graph/{repo_id}", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: GraphDataResponse = resp
@@ -68,15 +67,12 @@ pub async fn fetch_impact(
repo_id: String,
finding_id: String,
) -> Result<ImpactResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/graph/{repo_id}/impact/{finding_id}",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}/impact/{finding_id}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: ImpactResponse = resp
.json()
.await
@@ -86,10 +82,9 @@ pub async fn fetch_impact(
#[server]
pub async fn fetch_communities(repo_id: String) -> Result<CommunitiesResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/graph/{repo_id}/communities", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/graph/{repo_id}/communities"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: CommunitiesResponse = resp
@@ -104,15 +99,13 @@ pub async fn fetch_file_content(
repo_id: String,
file_path: String,
) -> Result<FileContentResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/graph/{repo_id}/file-content?path={file_path}",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_get(&format!(
"/api/v1/graph/{repo_id}/file-content?path={file_path}"
))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: FileContentResponse = resp
.json()
.await
@@ -122,15 +115,13 @@ pub async fn fetch_file_content(
#[server]
pub async fn search_nodes(repo_id: String, query: String) -> Result<SearchResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/graph/{repo_id}/search?q={query}&limit=50",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_get(&format!(
"/api/v1/graph/{repo_id}/search?q={query}&limit=50"
))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: SearchResponse = resp
.json()
.await
@@ -140,14 +131,13 @@ pub async fn search_nodes(repo_id: String, query: String) -> Result<SearchRespon
#[server]
pub async fn trigger_graph_build(repo_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/graph/{repo_id}/build", state.agent_api_url);
let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/graph/{repo_id}/build"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -12,11 +12,9 @@ pub struct IssuesListResponse {
#[server]
pub async fn fetch_issues(page: u64) -> Result<IssuesListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/issues?page={page}&limit=20", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/issues?page={page}&limit=20"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: IssuesListResponse = resp
@@ -18,6 +18,8 @@ pub mod stats;
// Server-only modules
#[cfg(feature = "server")]
mod agent_client;
#[cfg(feature = "server")]
mod auth;
#[cfg(feature = "server")]
mod auth_middleware;
@@ -32,11 +32,9 @@ pub struct NotificationCountResponse {
#[server]
pub async fn fetch_notification_count() -> Result<u64, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/notifications/count", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/notifications/count")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: NotificationCountResponse = resp
@@ -48,11 +46,9 @@ pub async fn fetch_notification_count() -> Result<u64, ServerFnError> {
#[server]
pub async fn fetch_notifications() -> Result<NotificationListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/notifications?limit=20", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/notifications?limit=20")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: NotificationListResponse = resp
@@ -64,12 +60,8 @@ pub async fn fetch_notifications() -> Result<NotificationListResponse, ServerFnE
#[server]
pub async fn mark_all_notifications_read() -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/notifications/read-all", state.agent_api_url);
reqwest::Client::new()
.post(&url)
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/notifications/read-all")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
@@ -78,14 +70,13 @@ pub async fn mark_all_notifications_read() -> Result<(), ServerFnError> {
#[server]
pub async fn dismiss_notification(id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/notifications/{id}/dismiss", state.agent_api_url);
reqwest::Client::new()
.patch(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::PATCH,
&format!("/api/v1/notifications/{id}/dismiss"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
+145 -184
View File
@@ -32,12 +32,10 @@ pub struct AttackChainResponse {
#[server]
pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
// Fetch sessions
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/pentest/sessions")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let mut body: PentestSessionsResponse = resp
@@ -46,31 +44,32 @@ pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerF
.map_err(|e| ServerFnError::new(e.to_string()))?;
// Fetch DAST targets to resolve target names
let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
if let Ok(tresp) = reqwest::get(&targets_url).await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
let targets = tbody.get("data").and_then(|v| v.as_array());
if let Some(targets) = targets {
// Build target_id -> name lookup
let target_map: std::collections::HashMap<String, String> = targets
.iter()
.filter_map(|t| {
let id = t.get("_id")?.get("$oid")?.as_str()?.to_string();
let name = t.get("name")?.as_str()?.to_string();
Some((id, name))
})
.collect();
if let Ok(tresp_builder) = super::agent_client::agent_get("/api/v1/dast/targets").await {
if let Ok(tresp) = tresp_builder.send().await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
let targets = tbody.get("data").and_then(|v| v.as_array());
if let Some(targets) = targets {
// Build target_id -> name lookup
let target_map: std::collections::HashMap<String, String> = targets
.iter()
.filter_map(|t| {
let id = t.get("_id")?.get("$oid")?.as_str()?.to_string();
let name = t.get("name")?.as_str()?.to_string();
Some((id, name))
})
.collect();
// Enrich sessions with target_name
for session in body.data.iter_mut() {
if let Some(tid) = session.get("target_id").and_then(|v| v.as_str()) {
if let Some(name) = target_map.get(tid) {
session.as_object_mut().map(|obj| {
obj.insert(
"target_name".to_string(),
serde_json::Value::String(name.clone()),
)
});
// Enrich sessions with target_name
for session in body.data.iter_mut() {
if let Some(tid) = session.get("target_id").and_then(|v| v.as_str()) {
if let Some(name) = target_map.get(tid) {
session.as_object_mut().map(|obj| {
obj.insert(
"target_name".to_string(),
serde_json::Value::String(name.clone()),
)
});
}
}
}
}
@@ -83,10 +82,9 @@ pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerF
#[server]
pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{id}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let mut body: PentestSessionResponse = resp
@@ -96,26 +94,27 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
// Resolve target name from targets list
if let Some(tid) = body.data.get("target_id").and_then(|v| v.as_str()) {
let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
if let Ok(tresp) = reqwest::get(&targets_url).await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
for t in targets {
let t_id = t
.get("_id")
.and_then(|v| v.get("$oid"))
.and_then(|v| v.as_str())
.unwrap_or("");
if t_id == tid {
if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
body.data.as_object_mut().map(|obj| {
obj.insert(
"target_name".to_string(),
serde_json::Value::String(name.to_string()),
)
});
if let Ok(tresp_builder) = super::agent_client::agent_get("/api/v1/dast/targets").await {
if let Ok(tresp) = tresp_builder.send().await {
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
for t in targets {
let t_id = t
.get("_id")
.and_then(|v| v.get("$oid"))
.and_then(|v| v.as_str())
.unwrap_or("");
if t_id == tid {
if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
body.data.as_object_mut().map(|obj| {
obj.insert(
"target_name".to_string(),
serde_json::Value::String(name.to_string()),
)
});
}
break;
}
break;
}
}
}
@@ -130,15 +129,12 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
pub async fn fetch_pentest_messages(
session_id: String,
) -> Result<PentestMessagesResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/messages",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{session_id}/messages"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp
.json()
.await
@@ -148,10 +144,9 @@ pub async fn fetch_pentest_messages(
#[server]
pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/stats", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/pentest/stats")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestStatsResponse = resp
@@ -163,15 +158,13 @@ pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError
#[server]
pub async fn fetch_attack_chain(session_id: String) -> Result<AttackChainResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/attack-chain",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_get(&format!(
"/api/v1/pentest/sessions/{session_id}/attack-chain"
))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: AttackChainResponse = resp
.json()
.await
@@ -185,20 +178,17 @@ pub async fn create_pentest_session(
strategy: String,
message: String,
) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({
"target_id": target_id,
"strategy": strategy,
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/pentest/sessions")
.await?
.json(&serde_json::json!({
"target_id": target_id,
"strategy": strategy,
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestSessionResponse = resp
.json()
.await
@@ -211,18 +201,15 @@ pub async fn create_pentest_session(
pub async fn create_pentest_session_wizard(
config_json: String,
) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let config: serde_json::Value =
serde_json::from_str(&config_json).map_err(|e| ServerFnError::new(e.to_string()))?;
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({ "config": config }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/pentest/sessions")
.await?
.json(&serde_json::json!({ "config": config }))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!(
@@ -239,8 +226,6 @@ pub async fn create_pentest_session_wizard(
/// Look up a tracked repository by its git URL
#[server]
pub async fn lookup_repo_by_url(url: String) -> Result<serde_json::Value, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let encoded_url: String = url
.bytes()
.flat_map(|b| {
@@ -251,13 +236,12 @@ pub async fn lookup_repo_by_url(url: String) -> Result<serde_json::Value, Server
}
})
.collect();
let api_url = format!(
"{}/api/v1/pentest/lookup-repo?url={}",
state.agent_api_url, encoded_url
);
let resp = reqwest::get(&api_url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_get(&format!("/api/v1/pentest/lookup-repo?url={encoded_url}"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp
.json()
.await
@@ -270,21 +254,17 @@ pub async fn send_pentest_message(
session_id: String,
message: String,
) -> Result<PentestMessagesResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/chat",
state.agent_api_url
);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/pentest/sessions/{session_id}/chat"),
)
.await?
.json(&serde_json::json!({
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp
.json()
.await
@@ -294,35 +274,27 @@ pub async fn send_pentest_message(
#[server]
pub async fn stop_pentest_session(session_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/stop",
state.agent_api_url
);
let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/pentest/sessions/{session_id}/stop"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
#[server]
pub async fn pause_pentest_session(session_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/pause",
state.agent_api_url
);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/pentest/sessions/{session_id}/pause"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Pause failed: {text}")));
@@ -332,18 +304,14 @@ pub async fn pause_pentest_session(session_id: String) -> Result<(), ServerFnErr
#[server]
pub async fn resume_pentest_session(session_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/resume",
state.agent_api_url
);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/pentest/sessions/{session_id}/resume"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Resume failed: {text}")));
@@ -355,15 +323,12 @@ pub async fn resume_pentest_session(session_id: String) -> Result<(), ServerFnEr
pub async fn fetch_pentest_findings(
session_id: String,
) -> Result<DastFindingsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/findings",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_get(&format!("/api/v1/pentest/sessions/{session_id}/findings"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingsResponse = resp
.json()
.await
@@ -385,23 +350,19 @@ pub async fn export_pentest_report(
requester_name: String,
requester_email: String,
) -> Result<ExportReportResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/export",
state.agent_api_url
);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({
"password": password,
"requester_name": requester_name,
"requester_email": requester_email,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/pentest/sessions/{session_id}/export"),
)
.await?
.json(&serde_json::json!({
"password": password,
"requester_name": requester_name,
"requester_email": requester_email,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
return Err(ServerFnError::new(format!("Export failed: {text}")));
@@ -12,14 +12,10 @@ pub struct RepositoryListResponse {
#[server]
pub async fn fetch_repositories(page: u64) -> Result<RepositoryListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/repositories?page={page}&limit=20",
state.agent_api_url
);
let resp = reqwest::get(&url)
let path = format!("/api/v1/repositories?page={page}&limit=20");
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: RepositoryListResponse = resp
@@ -41,10 +37,6 @@ pub async fn add_repository(
tracker_repo: Option<String>,
tracker_token: Option<String>,
) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories", state.agent_api_url);
let mut body = serde_json::json!({
"name": name,
"git_url": git_url,
@@ -69,9 +61,8 @@ pub async fn add_repository(
body["tracker_token"] = serde_json::Value::String(tk);
}
let client = reqwest::Client::new();
let resp = client
.post(&url)
let resp = super::agent_client::agent_request(reqwest::Method::POST, "/api/v1/repositories")
.await?
.json(&body)
.send()
.await
@@ -100,10 +91,6 @@ pub async fn update_repository(
tracker_token: Option<String>,
scan_schedule: Option<String>,
) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories/{repo_id}", state.agent_api_url);
let mut body = serde_json::Map::new();
if let Some(v) = name.filter(|s| !s.is_empty()) {
body.insert("name".into(), serde_json::Value::String(v));
@@ -133,13 +120,15 @@ pub async fn update_repository(
body.insert("scan_schedule".into(), serde_json::Value::String(v));
}
let client = reqwest::Client::new();
let resp = client
.patch(&url)
.json(&body)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::PATCH,
&format!("/api/v1/repositories/{repo_id}"),
)
.await?
.json(&body)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let text = resp.text().await.unwrap_or_default();
@@ -153,11 +142,9 @@ pub async fn update_repository(
#[server]
pub async fn fetch_ssh_public_key() -> Result<String, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/settings/ssh-public-key", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/settings/ssh-public-key")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
@@ -179,16 +166,14 @@ pub async fn fetch_ssh_public_key() -> Result<String, ServerFnError> {
#[server]
pub async fn delete_repository(repo_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories/{repo_id}", state.agent_api_url);
let client = reqwest::Client::new();
let resp = client
.delete(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp = super::agent_client::agent_request(
reqwest::Method::DELETE,
&format!("/api/v1/repositories/{repo_id}"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
if !resp.status().is_success() {
let body = resp.text().await.unwrap_or_default();
@@ -202,16 +187,14 @@ pub async fn delete_repository(repo_id: String) -> Result<(), ServerFnError> {
#[server]
pub async fn trigger_repo_scan(repo_id: String) -> Result<(), ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/repositories/{repo_id}/scan", state.agent_api_url);
let client = reqwest::Client::new();
client
.post(&url)
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
super::agent_client::agent_request(
reqwest::Method::POST,
&format!("/api/v1/repositories/{repo_id}/scan"),
)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(())
}
@@ -224,16 +207,12 @@ pub struct WebhookConfigResponse {
#[server]
pub async fn fetch_webhook_config(repo_id: String) -> Result<WebhookConfigResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/repositories/{repo_id}/webhook-config",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let resp =
super::agent_client::agent_get(&format!("/api/v1/repositories/{repo_id}/webhook-config"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: WebhookConfigResponse = resp
.json()
.await
@@ -244,11 +223,9 @@ pub async fn fetch_webhook_config(repo_id: String) -> Result<WebhookConfigRespon
/// Check if a repository has any running scans
#[server]
pub async fn check_repo_scanning(repo_id: String) -> Result<bool, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/scan-runs?page=1&limit=1", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/scan-runs?page=1&limit=1")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp
+20 -35
View File
@@ -87,11 +87,9 @@ pub struct SbomFiltersResponse {
#[server]
pub async fn fetch_sbom_filters() -> Result<SbomFiltersResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/sbom/filters", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/sbom/filters")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp
@@ -112,9 +110,6 @@ pub async fn fetch_sbom_filtered(
license: Option<String>,
page: u64,
) -> Result<SbomListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut params = vec![format!("page={page}"), "limit=50".to_string()];
if let Some(r) = &repo_id {
if !r.is_empty() {
@@ -140,9 +135,10 @@ pub async fn fetch_sbom_filtered(
}
}
let url = format!("{}/api/v1/sbom?{}", state.agent_api_url, params.join("&"));
let resp = reqwest::get(&url)
let path = format!("/api/v1/sbom?{}", params.join("&"));
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp
@@ -156,15 +152,10 @@ pub async fn fetch_sbom_filtered(
#[server]
pub async fn fetch_sbom_export(repo_id: String, format: String) -> Result<String, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/sbom/export?repo_id={}&format={}",
state.agent_api_url, repo_id, format
);
let resp = reqwest::get(&url)
let path = format!("/api/v1/sbom/export?repo_id={repo_id}&format={format}");
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp
@@ -178,17 +169,16 @@ pub async fn fetch_sbom_export(repo_id: String, format: String) -> Result<String
pub async fn fetch_license_summary(
repo_id: Option<String>,
) -> Result<LicenseSummaryResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let mut url = format!("{}/api/v1/sbom/licenses", state.agent_api_url);
let mut path = "/api/v1/sbom/licenses".to_string();
if let Some(r) = &repo_id {
if !r.is_empty() {
url = format!("{url}?repo_id={r}");
path = format!("{path}?repo_id={r}");
}
}
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp
@@ -205,15 +195,10 @@ pub async fn fetch_sbom_diff(
repo_a: String,
repo_b: String,
) -> Result<SbomDiffResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/sbom/diff?repo_a={}&repo_b={}",
state.agent_api_url, repo_a, repo_b
);
let resp = reqwest::get(&url)
let path = format!("/api/v1/sbom/diff?repo_a={repo_a}&repo_b={repo_b}");
let resp = super::agent_client::agent_get(&path)
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let text = resp
@@ -12,14 +12,9 @@ pub struct ScansListResponse {
#[server]
pub async fn fetch_scan_runs(page: u64) -> Result<ScansListResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/scan-runs?page={page}&limit=20",
state.agent_api_url
);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get(&format!("/api/v1/scan-runs?page={page}&limit=20"))
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: ScansListResponse = resp
@@ -16,11 +16,9 @@ pub struct OverviewStats {
#[server]
pub async fn fetch_overview_stats() -> Result<OverviewStats, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/stats/overview", state.agent_api_url);
let resp = reqwest::get(&url)
let resp = super::agent_client::agent_get("/api/v1/stats/overview")
.await?
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: serde_json::Value = resp
+22
View File
@@ -0,0 +1,22 @@
[package]
name = "compliance-smoke"
version = "0.1.0"
edition = "2021"
description = "Tiny Axum service exercising compliance-core M7.1 tenant gating. Run smoke.sh against it before merging anything that touches the auth/tenant path."
[lints]
workspace = true
[[bin]]
name = "compliance-smoke"
path = "src/main.rs"
[dependencies]
compliance-core = { workspace = true, features = ["axum"] }
axum = "0.8"
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
reqwest = { workspace = true }
+111
View File
@@ -0,0 +1,111 @@
//! M7.1 smoke service.
//!
//! A standalone Axum binary whose only job is to host the
//! [`compliance_core::auth`] middleware + [`compliance_core::tenant_ctx`]
//! extractor on three endpoints, so `scripts/smoke.sh` can prove the
//! tenant-gating contract end-to-end before any auth-path PR merges.
//!
//! Endpoints:
//! * `GET /api/v1/health` — public, never authenticated.
//! * `GET /api/v1/echo` — protected read; returns the [`TenantContext`].
//! * `POST /api/v1/echo` — protected write; exercises the `Frozen → 402`
//! gate on the same handler.
//!
//! Configuration (env):
//! * `KEYCLOAK_URL` — e.g. `http://localhost:8080`. Required.
//! * `KEYCLOAK_REALM` — e.g. `certifai`. Required.
//! * `SMOKE_PORT` — defaults to `3010`.
use std::sync::Arc;
use axum::{middleware, routing::get, Extension, Json, Router};
use compliance_core::{
auth::{require_jwt_auth, require_tenant_status, JwksState},
tenant_ctx::TenantCtx,
};
use serde::Serialize;
use tokio::sync::RwLock;
#[derive(Serialize)]
struct EchoResponse {
method: &'static str,
tenant_id: String,
tenant_slug: String,
plan: String,
status: String,
products: Vec<String>,
org_roles: Vec<String>,
user_id: String,
user_name: Option<String>,
}
async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({ "ok": true }))
}
async fn echo_read(TenantCtx(ctx): TenantCtx) -> Json<EchoResponse> {
Json(echo(ctx, "GET"))
}
async fn echo_write(TenantCtx(ctx): TenantCtx) -> Json<EchoResponse> {
Json(echo(ctx, "POST"))
}
fn echo(ctx: compliance_core::TenantContext, method: &'static str) -> EchoResponse {
EchoResponse {
method,
tenant_id: ctx.tenant_id,
tenant_slug: ctx.tenant_slug,
plan: ctx.plan,
status: ctx.status.to_string(),
products: ctx.products,
org_roles: ctx.org_roles.iter().map(|r| format!("{r:?}")).collect(),
user_id: ctx.user_id,
user_name: ctx.user_name,
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let kc_url = std::env::var("KEYCLOAK_URL")
.map_err(|_| "KEYCLOAK_URL is required (e.g. http://localhost:8080)")?;
let kc_realm = std::env::var("KEYCLOAK_REALM")
.map_err(|_| "KEYCLOAK_REALM is required (e.g. certifai)")?;
let port: u16 = std::env::var("SMOKE_PORT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(3010);
let jwks_url = format!("{kc_url}/realms/{kc_realm}/protocol/openid-connect/certs");
let jwks_state = JwksState {
jwks: Arc::new(RwLock::new(None)),
jwks_url: jwks_url.clone(),
};
// Layers execute outermost-first. The Extension must be registered
// before `require_jwt_auth` so the middleware can read JwksState; the
// status gate must run after JWT so `TenantContext` is in extensions.
let app = Router::new()
.route("/api/v1/health", get(health))
.route("/api/v1/echo", get(echo_read).post(echo_write))
.layer(middleware::from_fn(require_tenant_status))
.layer(middleware::from_fn(require_jwt_auth))
.layer(Extension(jwks_state));
let addr = format!("0.0.0.0:{port}");
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!(
port,
jwks = %jwks_url,
"compliance-smoke listening — try `scripts/smoke.sh`"
);
axum::serve(listener, app).await?;
Ok(())
}
+136
View File
@@ -0,0 +1,136 @@
#!/usr/bin/env bash
# M7.1 tenant-gating smoke test.
#
# Drives compliance-smoke against a live Keycloak realm with five test
# users (one per tenant_status), asserts the response code on each
# endpoint, and exits non-zero on any mismatch.
#
# Pre-reqs (one-time):
# * KC up at $KC_URL with realm $KC_REALM
# * Client $KC_CLIENT has direct-access-grants enabled
# * Users + tenant_status mappers per certifai/keycloak/realm-export.json
# * compliance-smoke binary running and reachable at $SMOKE_URL
#
# Usage:
# scripts/smoke.sh # uses defaults below
# SMOKE_URL=... scripts/smoke.sh
set -euo pipefail
KC_URL="${KC_URL:-http://localhost:8080}"
KC_REALM="${KC_REALM:-certifai}"
KC_CLIENT="${KC_CLIENT:-certifai-dashboard}"
SMOKE_URL="${SMOKE_URL:-http://localhost:3010}"
readonly TOKEN_ENDPOINT="${KC_URL}/realms/${KC_REALM}/protocol/openid-connect/token"
PASS=0
FAIL=0
red() { printf '\033[31m%s\033[0m' "$*"; }
green() { printf '\033[32m%s\033[0m' "$*"; }
yellow() { printf '\033[33m%s\033[0m' "$*"; }
# Fetches an access token via direct access grant. Echoes the raw token.
get_token() {
local user="$1" pass="$2"
curl -sS -X POST "$TOKEN_ENDPOINT" \
-H 'Content-Type: application/x-www-form-urlencoded' \
-d "grant_type=password" \
-d "client_id=${KC_CLIENT}" \
-d "username=${user}" \
-d "password=${pass}" \
-d "scope=openid" \
| sed -n 's/.*"access_token":"\([^"]*\)".*/\1/p'
}
# Hits SMOKE_URL$path with the given method and (optional) bearer token,
# asserts the response status code matches $want.
assert_status() {
local label="$1" method="$2" path="$3" want="$4" token="${5:-}"
local args=(-sS -o /dev/null -w '%{http_code}' -X "$method" "${SMOKE_URL}${path}")
if [[ -n "$token" ]]; then
args+=(-H "Authorization: Bearer ${token}")
fi
local got
got=$(curl "${args[@]}")
if [[ "$got" == "$want" ]]; then
printf ' %s %s %-4s %-15s → %s\n' "$(green PASS)" "$label" "$method" "$path" "$got"
PASS=$((PASS + 1))
else
printf ' %s %s %-4s %-15s → got %s, want %s\n' "$(red FAIL)" "$label" "$method" "$path" "$got" "$want"
FAIL=$((FAIL + 1))
fi
}
header() {
printf '\n%s %s\n' "$(yellow '##')" "$1"
}
# ---- Pre-flight ----------------------------------------------------------
header "Pre-flight"
if ! curl -sS -o /dev/null -w '%{http_code}\n' "${SMOKE_URL}/api/v1/health" | grep -q '^200$'; then
printf ' %s smoke service not reachable at %s\n' "$(red ERR)" "$SMOKE_URL"
exit 2
fi
if ! curl -sS -o /dev/null -w '%{http_code}\n' "${KC_URL}/realms/${KC_REALM}/.well-known/openid-configuration" | grep -q '^200$'; then
printf ' %s Keycloak realm %s not reachable at %s\n' "$(red ERR)" "$KC_REALM" "$KC_URL"
exit 2
fi
printf ' %s smoke service + Keycloak both up\n' "$(green OK)"
# ---- Public endpoint --------------------------------------------------
header "Public endpoint (no auth required)"
assert_status anon GET /api/v1/health 200
# ---- Anonymous access to protected endpoints ----------------------------
header "Anonymous → 401 on protected endpoints"
assert_status anon GET /api/v1/echo 401
assert_status anon POST /api/v1/echo 401
# ---- Bad token ----------------------------------------------------------
header "Bad token → 401"
assert_status bogus GET /api/v1/echo 401 "not-a-real-jwt"
assert_status bogus POST /api/v1/echo 401 "not-a-real-jwt"
# ---- Active tenant (admin user) -----------------------------------------
header "admin@certifai.local (active) → full access"
TOKEN=$(get_token admin@certifai.local admin)
if [[ -z "$TOKEN" ]]; then
printf ' %s failed to fetch token for admin\n' "$(red ERR)"
exit 2
fi
assert_status active GET /api/v1/echo 200 "$TOKEN"
assert_status active POST /api/v1/echo 200 "$TOKEN"
# ---- Active tenant (USER role) ------------------------------------------
header "user@certifai.local (active) → full access"
TOKEN=$(get_token user@certifai.local user)
assert_status active GET /api/v1/echo 200 "$TOKEN"
assert_status active POST /api/v1/echo 200 "$TOKEN"
# ---- Trial tenant -------------------------------------------------------
header "trial@acme.local (trial) → full access"
TOKEN=$(get_token trial@acme.local trial)
assert_status trial GET /api/v1/echo 200 "$TOKEN"
assert_status trial POST /api/v1/echo 200 "$TOKEN"
# ---- Frozen tenant ------------------------------------------------------
header "frozen@acme.local (frozen) → read-only, writes 402"
TOKEN=$(get_token frozen@acme.local frozen)
assert_status frozen GET /api/v1/echo 200 "$TOKEN"
assert_status frozen POST /api/v1/echo 402 "$TOKEN"
# ---- Archived tenant ----------------------------------------------------
header "archived@acme.local (archived) → 410 everywhere"
TOKEN=$(get_token archived@acme.local archived)
assert_status archived GET /api/v1/echo 410 "$TOKEN"
assert_status archived POST /api/v1/echo 410 "$TOKEN"
# ---- Summary ------------------------------------------------------------
printf '\n'
if [[ "$FAIL" -gt 0 ]]; then
printf '%s %d passed, %d failed\n' "$(red FAIL)" "$PASS" "$FAIL"
exit 1
fi
printf '%s %d/%d assertions passed\n' "$(green PASS)" "$PASS" "$PASS"