refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s

This commit was merged in pull request #13.
2026-03-13 08:03:45 +00:00
parent acc5b86aa4
commit 3bb690e5bb
89 changed files with 11884 additions and 6046 deletions

View File

@@ -0,0 +1,481 @@
use compliance_core::models::TrackerType;
use serde::{Deserialize, Serialize};
use compliance_core::models::ScanRun;
#[derive(Deserialize)]
pub struct PaginationParams {
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
pub(crate) fn default_page() -> u64 {
1
}
pub(crate) fn default_limit() -> i64 {
50
}
#[derive(Deserialize)]
pub struct FindingsFilter {
#[serde(default)]
pub repo_id: Option<String>,
#[serde(default)]
pub severity: Option<String>,
#[serde(default)]
pub scan_type: Option<String>,
#[serde(default)]
pub status: Option<String>,
#[serde(default)]
pub q: Option<String>,
#[serde(default)]
pub sort_by: Option<String>,
#[serde(default)]
pub sort_order: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
pub data: T,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
}
#[derive(Serialize)]
pub struct OverviewStats {
pub total_repositories: u64,
pub total_findings: u64,
pub critical_findings: u64,
pub high_findings: u64,
pub medium_findings: u64,
pub low_findings: u64,
pub total_sbom_entries: u64,
pub total_cve_alerts: u64,
pub total_issues: u64,
pub recent_scans: Vec<ScanRun>,
}
#[derive(Deserialize)]
pub struct AddRepositoryRequest {
pub name: String,
pub git_url: String,
#[serde(default = "default_branch")]
pub default_branch: String,
pub auth_token: Option<String>,
pub auth_username: Option<String>,
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
pub scan_schedule: Option<String>,
}
#[derive(Deserialize)]
pub struct UpdateRepositoryRequest {
pub name: Option<String>,
pub default_branch: Option<String>,
pub auth_token: Option<String>,
pub auth_username: Option<String>,
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
pub scan_schedule: Option<String>,
}
fn default_branch() -> String {
"main".to_string()
}
#[derive(Deserialize)]
pub struct UpdateStatusRequest {
pub status: String,
}
#[derive(Deserialize)]
pub struct BulkUpdateStatusRequest {
pub ids: Vec<String>,
pub status: String,
}
#[derive(Deserialize)]
pub struct UpdateFeedbackRequest {
pub feedback: String,
}
#[derive(Deserialize)]
pub struct SbomFilter {
#[serde(default)]
pub repo_id: Option<String>,
#[serde(default)]
pub package_manager: Option<String>,
#[serde(default)]
pub q: Option<String>,
#[serde(default)]
pub has_vulns: Option<bool>,
#[serde(default)]
pub license: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
#[derive(Deserialize)]
pub struct SbomExportParams {
pub repo_id: String,
#[serde(default = "default_export_format")]
pub format: String,
}
fn default_export_format() -> String {
"cyclonedx".to_string()
}
#[derive(Deserialize)]
pub struct SbomDiffParams {
pub repo_a: String,
pub repo_b: String,
}
#[derive(Serialize)]
pub struct LicenseSummary {
pub license: String,
pub count: u64,
pub is_copyleft: bool,
pub packages: Vec<String>,
}
#[derive(Serialize)]
pub struct SbomDiffResult {
pub only_in_a: Vec<SbomDiffEntry>,
pub only_in_b: Vec<SbomDiffEntry>,
pub version_changed: Vec<SbomVersionDiff>,
pub common_count: u64,
}
#[derive(Serialize)]
pub struct SbomDiffEntry {
pub name: String,
pub version: String,
pub package_manager: String,
}
#[derive(Serialize)]
pub struct SbomVersionDiff {
pub name: String,
pub package_manager: String,
pub version_a: String,
pub version_b: String,
}
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
use futures_util::StreamExt;
let mut items = Vec::new();
while let Some(result) = cursor.next().await {
match result {
Ok(item) => items.push(item),
Err(e) => tracing::warn!("Failed to deserialize document: {e}"),
}
}
items
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ── PaginationParams ─────────────────────────────────────────
#[test]
fn pagination_params_defaults() {
let p: PaginationParams = serde_json::from_str("{}").unwrap();
assert_eq!(p.page, 1);
assert_eq!(p.limit, 50);
}
#[test]
fn pagination_params_custom_values() {
let p: PaginationParams = serde_json::from_str(r#"{"page":3,"limit":10}"#).unwrap();
assert_eq!(p.page, 3);
assert_eq!(p.limit, 10);
}
#[test]
fn pagination_params_partial_override() {
let p: PaginationParams = serde_json::from_str(r#"{"page":5}"#).unwrap();
assert_eq!(p.page, 5);
assert_eq!(p.limit, 50);
}
#[test]
fn pagination_params_zero_page() {
let p: PaginationParams = serde_json::from_str(r#"{"page":0}"#).unwrap();
assert_eq!(p.page, 0);
}
// ── FindingsFilter ───────────────────────────────────────────
#[test]
fn findings_filter_all_defaults() {
let f: FindingsFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.severity.is_none());
assert!(f.scan_type.is_none());
assert!(f.status.is_none());
assert!(f.q.is_none());
assert!(f.sort_by.is_none());
assert!(f.sort_order.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn findings_filter_with_all_fields() {
let f: FindingsFilter = serde_json::from_str(
r#"{
"repo_id": "abc",
"severity": "high",
"scan_type": "sast",
"status": "open",
"q": "sql injection",
"sort_by": "severity",
"sort_order": "desc",
"page": 2,
"limit": 25
}"#,
)
.unwrap();
assert_eq!(f.repo_id.as_deref(), Some("abc"));
assert_eq!(f.severity.as_deref(), Some("high"));
assert_eq!(f.scan_type.as_deref(), Some("sast"));
assert_eq!(f.status.as_deref(), Some("open"));
assert_eq!(f.q.as_deref(), Some("sql injection"));
assert_eq!(f.sort_by.as_deref(), Some("severity"));
assert_eq!(f.sort_order.as_deref(), Some("desc"));
assert_eq!(f.page, 2);
assert_eq!(f.limit, 25);
}
#[test]
fn findings_filter_empty_string_fields() {
let f: FindingsFilter = serde_json::from_str(r#"{"repo_id":"","severity":""}"#).unwrap();
assert_eq!(f.repo_id.as_deref(), Some(""));
assert_eq!(f.severity.as_deref(), Some(""));
}
// ── ApiResponse ──────────────────────────────────────────────
#[test]
fn api_response_serializes_with_all_fields() {
let resp = ApiResponse {
data: vec!["a", "b"],
total: Some(100),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], json!(["a", "b"]));
assert_eq!(v["total"], 100);
assert_eq!(v["page"], 1);
}
#[test]
fn api_response_skips_none_fields() {
let resp = ApiResponse {
data: "hello",
total: None,
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], "hello");
assert!(v.get("total").is_none());
assert!(v.get("page").is_none());
}
#[test]
fn api_response_with_nested_struct() {
#[derive(Serialize)]
struct Item {
id: u32,
}
let resp = ApiResponse {
data: Item { id: 42 },
total: Some(1),
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"]["id"], 42);
assert_eq!(v["total"], 1);
assert!(v.get("page").is_none());
}
#[test]
fn api_response_empty_vec() {
let resp: ApiResponse<Vec<String>> = ApiResponse {
data: vec![],
total: Some(0),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert!(v["data"].as_array().unwrap().is_empty());
}
// ── SbomFilter ───────────────────────────────────────────────
#[test]
fn sbom_filter_defaults() {
let f: SbomFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.package_manager.is_none());
assert!(f.q.is_none());
assert!(f.has_vulns.is_none());
assert!(f.license.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn sbom_filter_has_vulns_bool() {
let f: SbomFilter = serde_json::from_str(r#"{"has_vulns": true}"#).unwrap();
assert_eq!(f.has_vulns, Some(true));
}
// ── SbomExportParams ─────────────────────────────────────────
#[test]
fn sbom_export_params_default_format() {
let p: SbomExportParams = serde_json::from_str(r#"{"repo_id":"r1"}"#).unwrap();
assert_eq!(p.repo_id, "r1");
assert_eq!(p.format, "cyclonedx");
}
#[test]
fn sbom_export_params_custom_format() {
let p: SbomExportParams =
serde_json::from_str(r#"{"repo_id":"r1","format":"spdx"}"#).unwrap();
assert_eq!(p.format, "spdx");
}
// ── AddRepositoryRequest ─────────────────────────────────────
#[test]
fn add_repository_request_defaults() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "my-repo",
"git_url": "https://github.com/x/y.git"
}"#,
)
.unwrap();
assert_eq!(r.name, "my-repo");
assert_eq!(r.git_url, "https://github.com/x/y.git");
assert_eq!(r.default_branch, "main");
assert!(r.auth_token.is_none());
assert!(r.tracker_type.is_none());
assert!(r.scan_schedule.is_none());
}
#[test]
fn add_repository_request_custom_branch() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "repo",
"git_url": "url",
"default_branch": "develop"
}"#,
)
.unwrap();
assert_eq!(r.default_branch, "develop");
}
// ── UpdateStatusRequest / BulkUpdateStatusRequest ────────────
#[test]
fn update_status_request() {
let r: UpdateStatusRequest = serde_json::from_str(r#"{"status":"resolved"}"#).unwrap();
assert_eq!(r.status, "resolved");
}
#[test]
fn bulk_update_status_request() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":["a","b"],"status":"dismissed"}"#).unwrap();
assert_eq!(r.ids, vec!["a", "b"]);
assert_eq!(r.status, "dismissed");
}
#[test]
fn bulk_update_status_empty_ids() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":[],"status":"x"}"#).unwrap();
assert!(r.ids.is_empty());
}
// ── SbomDiffResult serialization ─────────────────────────────
#[test]
fn sbom_diff_result_serializes() {
let r = SbomDiffResult {
only_in_a: vec![SbomDiffEntry {
name: "pkg-a".to_string(),
version: "1.0".to_string(),
package_manager: "npm".to_string(),
}],
only_in_b: vec![],
version_changed: vec![SbomVersionDiff {
name: "shared".to_string(),
package_manager: "cargo".to_string(),
version_a: "0.1".to_string(),
version_b: "0.2".to_string(),
}],
common_count: 10,
};
let v = serde_json::to_value(&r).unwrap();
assert_eq!(v["only_in_a"].as_array().unwrap().len(), 1);
assert_eq!(v["only_in_b"].as_array().unwrap().len(), 0);
assert_eq!(v["version_changed"][0]["version_a"], "0.1");
assert_eq!(v["common_count"], 10);
}
// ── LicenseSummary ───────────────────────────────────────────
#[test]
fn license_summary_serializes() {
let ls = LicenseSummary {
license: "MIT".to_string(),
count: 42,
is_copyleft: false,
packages: vec!["serde".to_string()],
};
let v = serde_json::to_value(&ls).unwrap();
assert_eq!(v["license"], "MIT");
assert_eq!(v["is_copyleft"], false);
assert_eq!(v["count"], 42);
}
// ── Default helper functions ─────────────────────────────────
#[test]
fn default_page_returns_1() {
assert_eq!(default_page(), 1);
}
#[test]
fn default_limit_returns_50() {
assert_eq!(default_limit(), 50);
}
}

View File

@@ -0,0 +1,172 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::Finding;
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings(
Extension(agent): AgentExt,
Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(severity) = &filter.severity {
query.insert("severity", severity);
}
if let Some(scan_type) = &filter.scan_type {
query.insert("scan_type", scan_type);
}
if let Some(status) = &filter.status {
query.insert("status", status);
}
// Text search across title, description, file_path, rule_id
if let Some(q) = &filter.q {
if !q.is_empty() {
let regex = doc! { "$regex": q, "$options": "i" };
query.insert(
"$or",
mongodb::bson::bson!([
{ "title": regex.clone() },
{ "description": regex.clone() },
{ "file_path": regex.clone() },
{ "rule_id": regex },
]),
);
}
}
// Dynamic sort
let sort_field = filter.sort_by.as_deref().unwrap_or("created_at");
let sort_dir: i32 = match filter.sort_order.as_deref() {
Some("asc") => 1,
_ => -1,
};
let sort_doc = doc! { sort_field: sort_dir };
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.findings()
.count_documents(query.clone())
.await
.unwrap_or(0);
let findings = match db
.findings()
.find(query)
.sort(sort_doc)
.skip(skip)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(filter.page),
}))
}
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let finding = agent
.db
.findings()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(ApiResponse {
data: finding,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}
#[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status(
Extension(agent): AgentExt,
Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req
.ids
.iter()
.filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok())
.collect();
if oids.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
let result = agent
.db
.findings()
.update_many(
doc! { "_id": { "$in": oids } },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(
serde_json::json!({ "status": "updated", "modified_count": result.modified_count }),
))
}
#[tracing::instrument(skip_all)]
pub async fn update_finding_feedback(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}
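
As a side note, the query-string shape that list_findings accepts through Query<FindingsFilter> can be exercised directly; below is a minimal sketch, assuming the serde_urlencoded crate (the same deserializer axum's Query extractor relies on) is available as a dev-dependency. Illustration only, not part of this commit.

#[cfg(test)]
mod findings_query_shape_example {
    use super::super::dto::FindingsFilter;

    #[test]
    fn findings_filter_from_query_string() {
        // Same deserialization axum's Query extractor performs on the raw query string,
        // e.g. ?severity=high&status=open&sort_order=asc&page=2
        let f: FindingsFilter =
            serde_urlencoded::from_str("severity=high&status=open&sort_order=asc&page=2")
                .expect("query string should deserialize");
        assert_eq!(f.severity.as_deref(), Some("high"));
        assert_eq!(f.status.as_deref(), Some("open"));
        assert_eq!(f.sort_order.as_deref(), Some("asc"));
        assert_eq!(f.page, 2);
        assert_eq!(f.limit, 50); // default_limit still fills the missing field
    }
}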

View File

@@ -0,0 +1,84 @@
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
#[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({ "status": "ok" }))
}
#[tracing::instrument(skip_all)]
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
let db = &agent.db;
let total_repositories = db
.repositories()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0);
let critical_findings = db
.findings()
.count_documents(doc! { "severity": "critical" })
.await
.unwrap_or(0);
let high_findings = db
.findings()
.count_documents(doc! { "severity": "high" })
.await
.unwrap_or(0);
let medium_findings = db
.findings()
.count_documents(doc! { "severity": "medium" })
.await
.unwrap_or(0);
let low_findings = db
.findings()
.count_documents(doc! { "severity": "low" })
.await
.unwrap_or(0);
let total_sbom_entries = db
.sbom_entries()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0);
let total_issues = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
let recent_scans: Vec<ScanRun> = match db
.scan_runs()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.limit(10)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch recent scans: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: OverviewStats {
total_repositories,
total_findings,
critical_findings,
high_findings,
medium_findings,
low_findings,
total_sbom_entries,
total_cve_alerts,
total_issues,
recent_scans,
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,41 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::TrackerIssue;
#[tracing::instrument(skip_all)]
pub async fn list_issues(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
let issues = match db
.tracker_issues()
.find(doc! {})
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch tracker issues: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: issues,
total: Some(total),
page: Some(params.page),
}))
}

File diff suppressed because it is too large

View File

@@ -0,0 +1,131 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
#[derive(Deserialize)]
pub struct ExportBody {
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
if body.password.len() < 8 {
return Err((
StatusCode::BAD_REQUEST,
"Password must be at least 8 characters".to_string(),
));
}
// Fetch session
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
.flatten()
} else {
None
};
let target_name = target
.as_ref()
.map(|t| t.name.clone())
.unwrap_or_else(|| "Unknown Target".to_string());
let target_url = target
.as_ref()
.map(|t| t.base_url.clone())
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch DAST findings for this session
let findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let ctx = crate::pentest::report::ReportContext {
session,
target_name,
target_url,
findings,
attack_chain: nodes,
requester_name: if body.requester_name.is_empty() {
"Unknown".to_string()
} else {
body.requester_name
},
requester_email: body.requester_email,
};
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
let response = serde_json::json!({
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
"sha256": report.sha256,
"filename": format!("pentest-report-{id}.zip"),
});
Ok(Json(response).into_response())
}
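
For consumers of this endpoint, a hedged client-side sketch: decode archive_base64 and re-compute the digest to compare against the returned sha256 field (assuming that digest is hex-encoded). The ExportResponse struct name and the sha2/hex crates are illustrative assumptions; only the base64 engine mirrors the handler above.

use base64::Engine as _;
use sha2::{Digest, Sha256};

// JSON body produced by export_session_report; field names taken from the handler above.
#[derive(serde::Deserialize)]
struct ExportResponse {
    archive_base64: String,
    sha256: String,
    filename: String,
}

// Decode the archive and verify integrity against the reported digest.
fn decode_and_verify(resp: &ExportResponse) -> Result<Vec<u8>, String> {
    let archive = base64::engine::general_purpose::STANDARD
        .decode(&resp.archive_base64)
        .map_err(|e| format!("invalid base64: {e}"))?;
    let digest = hex::encode(Sha256::digest(&archive));
    if digest != resp.sha256 {
        return Err(format!("sha256 mismatch for {}", resp.filename));
    }
    Ok(archive)
}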

View File

@@ -0,0 +1,9 @@
mod export;
mod session;
mod stats;
mod stream;
pub use export::*;
pub use session::*;
pub use stats::*;
pub use stream::*;

View File

@@ -2,20 +2,16 @@ use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::response::IntoResponse;
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -160,8 +156,7 @@ pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
@@ -210,13 +205,12 @@ pub async fn send_message(
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
@@ -261,106 +255,6 @@ pub async fn send_message(
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session(
@@ -375,7 +269,12 @@ pub async fn stop_session(
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running {
@@ -397,15 +296,30 @@ pub async fn stop_session(
}},
)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?;
let updated = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found after update".to_string()))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Session not found after update".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: updated,
@@ -420,9 +334,7 @@ pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
@@ -453,8 +365,7 @@ pub async fn get_messages(
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -487,95 +398,14 @@ pub async fn get_messages(
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -607,112 +437,3 @@ pub async fn get_session_findings(
page: Some(params.page),
}))
}
#[derive(Deserialize)]
pub struct ExportBody {
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
if body.password.len() < 8 {
return Err((
StatusCode::BAD_REQUEST,
"Password must be at least 8 characters".to_string(),
));
}
// Fetch session
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
.flatten()
} else {
None
};
let target_name = target
.as_ref()
.map(|t| t.name.clone())
.unwrap_or_else(|| "Unknown Target".to_string());
let target_url = target
.as_ref()
.map(|t| t.base_url.clone())
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch DAST findings for this session
let findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let ctx = crate::pentest::report::ReportContext {
session,
target_name,
target_url,
findings,
attack_chain: nodes,
requester_name: if body.requester_name.is_empty() {
"Unknown".to_string()
} else {
body.requester_name
},
requester_email: body.requester_email,
};
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
let response = serde_json::json!({
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
"sha256": report.sha256,
"filename": format!("pentest-report-{id}.zip"),
});
Ok(Json(response).into_response())
}

View File

@@ -0,0 +1,102 @@
use std::sync::Arc;
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let critical = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" },
)
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" },
)
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" },
)
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" },
)
.await
.unwrap_or(0) as u32;
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,116 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use futures_util::stream;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<
Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>,
StatusCode,
> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
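
Since session_stream materializes a finite stream::iter of events, the SSE body ends once the stored messages, tool nodes, and status event have been written, so a single HTTP GET can read it to completion. A rough client sketch, assuming the reqwest and tokio crates and a caller-supplied base URL; the route path comes from the doc comment above.

// Reads the polling-style SSE body in one request (the server closes the stream after stored events).
async fn read_session_events(base_url: &str, session_id: &str) -> Result<(), reqwest::Error> {
    let url = format!("{base_url}/api/v1/pentest/sessions/{session_id}/stream");
    let body = reqwest::get(url).await?.text().await?;
    // SSE frames are "event: <name>" / "data: <json>" lines separated by a blank line.
    for frame in body.split("\n\n").filter(|f| !f.trim().is_empty()) {
        let event = frame.lines().find_map(|l| l.strip_prefix("event:")).map(str::trim);
        let data = frame.lines().find_map(|l| l.strip_prefix("data:")).map(str::trim);
        println!("{}: {}", event.unwrap_or("message"), data.unwrap_or(""));
    }
    Ok(())
}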

View File

@@ -0,0 +1,241 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::*;
#[tracing::instrument(skip_all)]
pub async fn list_repositories(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.repositories()
.count_documents(doc! {})
.await
.unwrap_or(0);
let repos = match db
.repositories()
.find(doc! {})
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch repositories: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: repos,
total: Some(total),
page: Some(params.page),
}))
}
#[tracing::instrument(skip_all)]
pub async fn add_repository(
Extension(agent): AgentExt,
Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
// Validate repository access before saving
let creds = crate::pipeline::git::RepoCredentials {
ssh_key_path: Some(agent.config.ssh_key_path.clone()),
auth_token: req.auth_token.clone(),
auth_username: req.auth_username.clone(),
};
if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) {
return Err((
StatusCode::BAD_REQUEST,
format!("Cannot access repository: {e}"),
));
}
let mut repo = TrackedRepository::new(req.name, req.git_url);
repo.default_branch = req.default_branch;
repo.auth_token = req.auth_token;
repo.auth_username = req.auth_username;
repo.tracker_type = req.tracker_type;
repo.tracker_owner = req.tracker_owner;
repo.tracker_repo = req.tracker_repo;
repo.tracker_token = req.tracker_token;
repo.scan_schedule = req.scan_schedule;
agent
.db
.repositories()
.insert_one(&repo)
.await
.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: repo,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
if let Some(name) = &req.name {
set_doc.insert("name", name);
}
if let Some(branch) = &req.default_branch {
set_doc.insert("default_branch", branch);
}
if let Some(token) = &req.auth_token {
set_doc.insert("auth_token", token);
}
if let Some(username) = &req.auth_username {
set_doc.insert("auth_username", username);
}
if let Some(tracker_type) = &req.tracker_type {
set_doc.insert("tracker_type", tracker_type.to_string());
}
if let Some(owner) = &req.tracker_owner {
set_doc.insert("tracker_owner", owner);
}
if let Some(repo) = &req.tracker_repo {
set_doc.insert("tracker_repo", repo);
}
if let Some(token) = &req.tracker_token {
set_doc.insert("tracker_token", token);
}
if let Some(schedule) = &req.scan_schedule {
set_doc.insert("scan_schedule", schedule);
}
let result = agent
.db
.repositories()
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
.await
.map_err(|e| {
tracing::warn!("Failed to update repository: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if result.matched_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
Ok(Json(serde_json::json!({ "status": "updated" })))
}
#[tracing::instrument(skip_all)]
pub async fn get_ssh_public_key(
Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
let public_path = format!("{}.pub", agent.config.ssh_key_path);
let public_key = std::fs::read_to_string(&public_path).map_err(|_| StatusCode::NOT_FOUND)?;
Ok(Json(serde_json::json!({ "public_key": public_key.trim() })))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone();
tokio::spawn(async move {
if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
tracing::error!("Manual scan failed for {id}: {e}");
}
});
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
}
/// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let repo = agent
.db
.repositories()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let tracker_type = repo
.tracker_type
.as_ref()
.map(|t| t.to_string())
.unwrap_or_else(|| "gitea".to_string());
Ok(Json(serde_json::json!({
"webhook_secret": repo.webhook_secret,
"tracker_type": tracker_type,
})))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = &agent.db;
// Delete the repository
let result = db
.repositories()
.delete_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if result.deleted_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
// Cascade delete all related data
let _ = db.findings().delete_many(doc! { "repo_id": &id }).await;
let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await;
let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await;
let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.tracker_issues()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.impact_analyses()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.code_embeddings()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.embedding_builds()
.delete_many(doc! { "repo_id": &id })
.await;
Ok(Json(serde_json::json!({ "status": "deleted" })))
}

View File

@@ -0,0 +1,379 @@
use axum::extract::{Extension, Query};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::SbomEntry;
const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0",
"GPL-2.0-only",
"GPL-2.0-or-later",
"GPL-3.0",
"GPL-3.0-only",
"GPL-3.0-or-later",
"AGPL-3.0",
"AGPL-3.0-only",
"AGPL-3.0-or-later",
"LGPL-2.1",
"LGPL-2.1-only",
"LGPL-2.1-or-later",
"LGPL-3.0",
"LGPL-3.0-only",
"LGPL-3.0-or-later",
"MPL-2.0",
];
#[tracing::instrument(skip_all)]
pub async fn sbom_filters(
Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = &agent.db;
let managers: Vec<String> = db
.sbom_entries()
.distinct("package_manager", doc! {})
.await
.unwrap_or_default()
.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.filter(|s| !s.is_empty() && s != "unknown" && s != "file")
.collect();
let licenses: Vec<String> = db
.sbom_entries()
.distinct("license", doc! {})
.await
.unwrap_or_default()
.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.filter(|s| !s.is_empty())
.collect();
Ok(Json(serde_json::json!({
"package_managers": managers,
"licenses": licenses,
})))
}
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom(
Extension(agent): AgentExt,
Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(pm) = &filter.package_manager {
query.insert("package_manager", pm);
}
if let Some(q) = &filter.q {
if !q.is_empty() {
query.insert("name", doc! { "$regex": q, "$options": "i" });
}
}
if let Some(has_vulns) = filter.has_vulns {
if has_vulns {
query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
} else {
query.insert("known_vulnerabilities", doc! { "$size": 0 });
}
}
if let Some(license) = &filter.license {
query.insert("license", license);
}
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.sbom_entries()
.count_documents(query.clone())
.await
.unwrap_or(0);
let entries = match db
.sbom_entries()
.find(query)
.sort(doc! { "name": 1 })
.skip(skip)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: entries,
total: Some(total),
page: Some(filter.page),
}))
}
#[tracing::instrument(skip_all)]
pub async fn export_sbom(
Extension(agent): AgentExt,
Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
let db = &agent.db;
let entries: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_id })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for export: {e}");
Vec::new()
}
};
let body = if params.format == "spdx" {
// SPDX 2.3 format
let packages: Vec<serde_json::Value> = entries
.iter()
.enumerate()
.map(|(i, e)| {
serde_json::json!({
"SPDXID": format!("SPDXRef-Package-{i}"),
"name": e.name,
"versionInfo": e.version,
"downloadLocation": "NOASSERTION",
"licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
"externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
"referenceCategory": "PACKAGE-MANAGER",
"referenceType": "purl",
"referenceLocator": p,
})]).unwrap_or_default(),
})
})
.collect();
serde_json::json!({
"spdxVersion": "SPDX-2.3",
"dataLicense": "CC0-1.0",
"SPDXID": "SPDXRef-DOCUMENT",
"name": format!("sbom-{}", params.repo_id),
"documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
"packages": packages,
})
} else {
// CycloneDX 1.5 format
let components: Vec<serde_json::Value> = entries
.iter()
.map(|e| {
let mut comp = serde_json::json!({
"type": "library",
"name": e.name,
"version": e.version,
"group": e.package_manager,
});
if let Some(purl) = &e.purl {
comp["purl"] = serde_json::Value::String(purl.clone());
}
if let Some(license) = &e.license {
comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
}
if !e.known_vulnerabilities.is_empty() {
comp["vulnerabilities"] = serde_json::json!(
e.known_vulnerabilities.iter().map(|v| serde_json::json!({
"id": v.id,
"source": { "name": v.source },
"ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
})).collect::<Vec<_>>()
);
}
comp
})
.collect();
serde_json::json!({
"bomFormat": "CycloneDX",
"specVersion": "1.5",
"version": 1,
"metadata": {
"component": {
"type": "application",
"name": format!("repo-{}", params.repo_id),
}
},
"components": components,
})
};
let json_str =
serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let filename = if params.format == "spdx" {
format!("sbom-{}-spdx.json", params.repo_id)
} else {
format!("sbom-{}-cyclonedx.json", params.repo_id)
};
let disposition = format!("attachment; filename=\"{filename}\"");
Ok((
[
(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
),
(
header::CONTENT_DISPOSITION,
header::HeaderValue::from_str(&disposition)
.unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
),
],
json_str,
))
}
#[tracing::instrument(skip_all)]
pub async fn license_summary(
Extension(agent): AgentExt,
Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id);
}
let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for license summary: {e}");
Vec::new()
}
};
let mut license_map: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for entry in &entries {
let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
license_map.entry(lic).or_default().push(entry.name.clone());
}
let mut summaries: Vec<LicenseSummary> = license_map
.into_iter()
.map(|(license, packages)| {
let is_copyleft = COPYLEFT_LICENSES
.iter()
.any(|c| license.to_uppercase().contains(&c.to_uppercase()));
LicenseSummary {
license,
count: packages.len() as u64,
is_copyleft,
packages,
}
})
.collect();
summaries.sort_by(|a, b| b.count.cmp(&a.count));
Ok(Json(ApiResponse {
data: summaries,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all)]
pub async fn sbom_diff(
Extension(agent): AgentExt,
Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
let db = &agent.db;
let entries_a: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_a })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}");
Vec::new()
}
};
let entries_b: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_b })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}");
Vec::new()
}
};
// Build maps by (name, package_manager) -> version
let map_a: std::collections::HashMap<(String, String), String> = entries_a
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let map_b: std::collections::HashMap<(String, String), String> = entries_b
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let mut only_in_a = Vec::new();
let mut version_changed = Vec::new();
let mut common_count: u64 = 0;
for (key, ver_a) in &map_a {
match map_b.get(key) {
None => only_in_a.push(SbomDiffEntry {
name: key.0.clone(),
version: ver_a.clone(),
package_manager: key.1.clone(),
}),
Some(ver_b) if ver_a != ver_b => {
version_changed.push(SbomVersionDiff {
name: key.0.clone(),
package_manager: key.1.clone(),
version_a: ver_a.clone(),
version_b: ver_b.clone(),
});
}
Some(_) => common_count += 1,
}
}
let only_in_b: Vec<SbomDiffEntry> = map_b
.iter()
.filter(|(key, _)| !map_a.contains_key(key))
.map(|(key, ver)| SbomDiffEntry {
name: key.0.clone(),
version: ver.clone(),
package_manager: key.1.clone(),
})
.collect();
Ok(Json(ApiResponse {
data: SbomDiffResult {
only_in_a,
only_in_b,
version_changed,
common_count,
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,37 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
let scans = match db
.scan_runs()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch scan runs: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: scans,
total: Some(total),
page: Some(params.page),
}))
}

View File

@@ -136,7 +136,10 @@ pub fn build_router() -> Router {
"/api/v1/pentest/sessions/{id}/export",
post(handlers::pentest::export_session_report),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
.route(
"/api/v1/pentest/stats",
get(handlers::pentest::pentest_stats),
)
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",