From 3bb690e5bbafd19d0b820d36f4512d2406eb9e5e Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Fri, 13 Mar 2026 08:03:45 +0000 Subject: [PATCH] refactor: modularize codebase and add 404 unit tests (#13) --- compliance-agent/src/api/handlers/dto.rs | 481 +++++++ compliance-agent/src/api/handlers/findings.rs | 172 +++ compliance-agent/src/api/handlers/health.rs | 84 ++ compliance-agent/src/api/handlers/issues.rs | 41 + compliance-agent/src/api/handlers/mod.rs | 1139 +---------------- .../api/handlers/pentest_handlers/export.rs | 131 ++ .../src/api/handlers/pentest_handlers/mod.rs | 9 + .../session.rs} | 351 +---- .../api/handlers/pentest_handlers/stats.rs | 102 ++ .../api/handlers/pentest_handlers/stream.rs | 116 ++ compliance-agent/src/api/handlers/repos.rs | 241 ++++ compliance-agent/src/api/handlers/sbom.rs | 379 ++++++ compliance-agent/src/api/handlers/scans.rs | 37 + compliance-agent/src/api/routes.rs | 5 +- compliance-agent/src/llm/client.rs | 209 +-- compliance-agent/src/llm/embedding.rs | 74 ++ compliance-agent/src/llm/mod.rs | 5 + compliance-agent/src/llm/triage.rs | 217 ++++ compliance-agent/src/llm/types.rs | 369 ++++++ compliance-agent/src/pentest/context.rs | 150 +++ compliance-agent/src/pentest/mod.rs | 2 + compliance-agent/src/pentest/orchestrator.rs | 420 +----- .../src/pentest/prompt_builder.rs | 504 ++++++++ .../src/pentest/report/archive.rs | 43 + .../src/pentest/{report.rs => report/html.rs} | 680 ++++++---- compliance-agent/src/pentest/report/mod.rs | 58 + compliance-agent/src/pentest/report/pdf.rs | 79 ++ compliance-agent/src/pipeline/dedup.rs | 48 + compliance-agent/src/pipeline/gitleaks.rs | 107 ++ compliance-agent/src/pipeline/graph_build.rs | 106 ++ .../src/pipeline/issue_creation.rs | 259 ++++ compliance-agent/src/pipeline/lint.rs | 366 ------ compliance-agent/src/pipeline/lint/clippy.rs | 251 ++++ compliance-agent/src/pipeline/lint/eslint.rs | 183 +++ compliance-agent/src/pipeline/lint/mod.rs | 97 ++ 
compliance-agent/src/pipeline/lint/ruff.rs | 150 +++ compliance-agent/src/pipeline/mod.rs | 4 + compliance-agent/src/pipeline/orchestrator.rs | 581 +-------- compliance-agent/src/pipeline/patterns.rs | 156 +++ compliance-agent/src/pipeline/pr_review.rs | 146 +++ .../src/pipeline/sbom/cargo_audit.rs | 72 ++ .../src/pipeline/{sbom.rs => sbom/mod.rs} | 202 +-- compliance-agent/src/pipeline/sbom/syft.rs | 355 +++++ compliance-agent/src/pipeline/semgrep.rs | 121 ++ .../src/pipeline/tracker_dispatch.rs | 81 ++ compliance-agent/tests/common/mod.rs | 3 + compliance-agent/tests/integration/mod.rs | 4 + compliance-core/src/models/pentest.rs | 6 +- compliance-core/tests/models.rs | 475 +++++++ .../src/components/attack_chain/helpers.rs | 283 ++++ .../src/components/attack_chain/mod.rs | 4 + .../src/components/attack_chain/view.rs | 363 ++++++ compliance-dashboard/src/components/mod.rs | 1 + .../src/infrastructure/pentest.rs | 15 +- .../src/pages/pentest_dashboard.rs | 5 +- .../src/pages/pentest_session.rs | 607 +-------- compliance-dast/src/tools/api_fuzzer.rs | 126 +- compliance-dast/src/tools/auth_bypass.rs | 104 +- .../src/tools/console_log_detector.rs | 399 +++--- compliance-dast/src/tools/cookie_analyzer.rs | 480 ++++--- compliance-dast/src/tools/cors_checker.rs | 652 +++++----- compliance-dast/src/tools/csp_analyzer.rs | 387 ++++-- compliance-dast/src/tools/dmarc_checker.rs | 497 ++++--- compliance-dast/src/tools/dns_checker.rs | 430 ++++--- compliance-dast/src/tools/mod.rs | 19 +- compliance-dast/src/tools/openapi_parser.rs | 383 ++++-- .../src/tools/rate_limit_tester.rs | 385 +++--- compliance-dast/src/tools/recon.rs | 107 +- compliance-dast/src/tools/security_headers.rs | 244 ++-- compliance-dast/src/tools/sql_injection.rs | 160 ++- compliance-dast/src/tools/ssrf.rs | 104 +- compliance-dast/src/tools/tls_analyzer.rs | 455 +++---- compliance-dast/src/tools/xss.rs | 189 ++- compliance-dast/tests/agents.rs | 4 + compliance-graph/src/graph/chunking.rs | 61 + 
compliance-graph/src/graph/community.rs | 212 +++ compliance-graph/src/graph/engine.rs | 182 +++ compliance-graph/src/graph/impact.rs | 375 ++++++ compliance-graph/src/parsers/registry.rs | 112 ++ compliance-graph/src/parsers/rust_parser.rs | 208 +++ compliance-graph/src/search/index.rs | 183 +++ compliance-graph/tests/parsers.rs | 4 + compliance-mcp/src/tools/dast.rs | 60 + compliance-mcp/src/tools/findings.rs | 83 ++ compliance-mcp/src/tools/pentest.rs | 84 ++ compliance-mcp/src/tools/sbom.rs | 60 + compliance-mcp/tests/tools.rs | 4 + fuzz/Cargo.toml | 16 + fuzz/fuzz_targets/fuzz_finding_dedup.rs | 12 + 89 files changed, 11884 insertions(+), 6046 deletions(-) create mode 100644 compliance-agent/src/api/handlers/dto.rs create mode 100644 compliance-agent/src/api/handlers/findings.rs create mode 100644 compliance-agent/src/api/handlers/health.rs create mode 100644 compliance-agent/src/api/handlers/issues.rs create mode 100644 compliance-agent/src/api/handlers/pentest_handlers/export.rs create mode 100644 compliance-agent/src/api/handlers/pentest_handlers/mod.rs rename compliance-agent/src/api/handlers/{pentest.rs => pentest_handlers/session.rs} (53%) create mode 100644 compliance-agent/src/api/handlers/pentest_handlers/stats.rs create mode 100644 compliance-agent/src/api/handlers/pentest_handlers/stream.rs create mode 100644 compliance-agent/src/api/handlers/repos.rs create mode 100644 compliance-agent/src/api/handlers/sbom.rs create mode 100644 compliance-agent/src/api/handlers/scans.rs create mode 100644 compliance-agent/src/llm/embedding.rs create mode 100644 compliance-agent/src/llm/types.rs create mode 100644 compliance-agent/src/pentest/context.rs create mode 100644 compliance-agent/src/pentest/prompt_builder.rs create mode 100644 compliance-agent/src/pentest/report/archive.rs rename compliance-agent/src/pentest/{report.rs => report/html.rs} (73%) create mode 100644 compliance-agent/src/pentest/report/mod.rs create mode 100644 
compliance-agent/src/pentest/report/pdf.rs create mode 100644 compliance-agent/src/pipeline/graph_build.rs create mode 100644 compliance-agent/src/pipeline/issue_creation.rs delete mode 100644 compliance-agent/src/pipeline/lint.rs create mode 100644 compliance-agent/src/pipeline/lint/clippy.rs create mode 100644 compliance-agent/src/pipeline/lint/eslint.rs create mode 100644 compliance-agent/src/pipeline/lint/mod.rs create mode 100644 compliance-agent/src/pipeline/lint/ruff.rs create mode 100644 compliance-agent/src/pipeline/pr_review.rs create mode 100644 compliance-agent/src/pipeline/sbom/cargo_audit.rs rename compliance-agent/src/pipeline/{sbom.rs => sbom/mod.rs} (55%) create mode 100644 compliance-agent/src/pipeline/sbom/syft.rs create mode 100644 compliance-agent/src/pipeline/tracker_dispatch.rs create mode 100644 compliance-agent/tests/common/mod.rs create mode 100644 compliance-agent/tests/integration/mod.rs create mode 100644 compliance-core/tests/models.rs create mode 100644 compliance-dashboard/src/components/attack_chain/helpers.rs create mode 100644 compliance-dashboard/src/components/attack_chain/mod.rs create mode 100644 compliance-dashboard/src/components/attack_chain/view.rs create mode 100644 compliance-dast/tests/agents.rs create mode 100644 compliance-graph/tests/parsers.rs create mode 100644 compliance-mcp/tests/tools.rs create mode 100644 fuzz/Cargo.toml create mode 100644 fuzz/fuzz_targets/fuzz_finding_dedup.rs diff --git a/compliance-agent/src/api/handlers/dto.rs b/compliance-agent/src/api/handlers/dto.rs new file mode 100644 index 0000000..b6c0992 --- /dev/null +++ b/compliance-agent/src/api/handlers/dto.rs @@ -0,0 +1,481 @@ +use compliance_core::models::TrackerType; +use serde::{Deserialize, Serialize}; + +use compliance_core::models::ScanRun; + +#[derive(Deserialize)] +pub struct PaginationParams { + #[serde(default = "default_page")] + pub page: u64, + #[serde(default = "default_limit")] + pub limit: i64, +} + +pub(crate) fn 
default_page() -> u64 { + 1 +} +pub(crate) fn default_limit() -> i64 { + 50 +} + +#[derive(Deserialize)] +pub struct FindingsFilter { + #[serde(default)] + pub repo_id: Option, + #[serde(default)] + pub severity: Option, + #[serde(default)] + pub scan_type: Option, + #[serde(default)] + pub status: Option, + #[serde(default)] + pub q: Option, + #[serde(default)] + pub sort_by: Option, + #[serde(default)] + pub sort_order: Option, + #[serde(default = "default_page")] + pub page: u64, + #[serde(default = "default_limit")] + pub limit: i64, +} + +#[derive(Serialize)] +pub struct ApiResponse { + pub data: T, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub page: Option, +} + +#[derive(Serialize)] +pub struct OverviewStats { + pub total_repositories: u64, + pub total_findings: u64, + pub critical_findings: u64, + pub high_findings: u64, + pub medium_findings: u64, + pub low_findings: u64, + pub total_sbom_entries: u64, + pub total_cve_alerts: u64, + pub total_issues: u64, + pub recent_scans: Vec, +} + +#[derive(Deserialize)] +pub struct AddRepositoryRequest { + pub name: String, + pub git_url: String, + #[serde(default = "default_branch")] + pub default_branch: String, + pub auth_token: Option, + pub auth_username: Option, + pub tracker_type: Option, + pub tracker_owner: Option, + pub tracker_repo: Option, + pub tracker_token: Option, + pub scan_schedule: Option, +} + +#[derive(Deserialize)] +pub struct UpdateRepositoryRequest { + pub name: Option, + pub default_branch: Option, + pub auth_token: Option, + pub auth_username: Option, + pub tracker_type: Option, + pub tracker_owner: Option, + pub tracker_repo: Option, + pub tracker_token: Option, + pub scan_schedule: Option, +} + +fn default_branch() -> String { + "main".to_string() +} + +#[derive(Deserialize)] +pub struct UpdateStatusRequest { + pub status: String, +} + +#[derive(Deserialize)] +pub struct BulkUpdateStatusRequest { + pub ids: 
Vec, + pub status: String, +} + +#[derive(Deserialize)] +pub struct UpdateFeedbackRequest { + pub feedback: String, +} + +#[derive(Deserialize)] +pub struct SbomFilter { + #[serde(default)] + pub repo_id: Option, + #[serde(default)] + pub package_manager: Option, + #[serde(default)] + pub q: Option, + #[serde(default)] + pub has_vulns: Option, + #[serde(default)] + pub license: Option, + #[serde(default = "default_page")] + pub page: u64, + #[serde(default = "default_limit")] + pub limit: i64, +} + +#[derive(Deserialize)] +pub struct SbomExportParams { + pub repo_id: String, + #[serde(default = "default_export_format")] + pub format: String, +} + +fn default_export_format() -> String { + "cyclonedx".to_string() +} + +#[derive(Deserialize)] +pub struct SbomDiffParams { + pub repo_a: String, + pub repo_b: String, +} + +#[derive(Serialize)] +pub struct LicenseSummary { + pub license: String, + pub count: u64, + pub is_copyleft: bool, + pub packages: Vec, +} + +#[derive(Serialize)] +pub struct SbomDiffResult { + pub only_in_a: Vec, + pub only_in_b: Vec, + pub version_changed: Vec, + pub common_count: u64, +} + +#[derive(Serialize)] +pub struct SbomDiffEntry { + pub name: String, + pub version: String, + pub package_manager: String, +} + +#[derive(Serialize)] +pub struct SbomVersionDiff { + pub name: String, + pub package_manager: String, + pub version_a: String, + pub version_b: String, +} + +pub(crate) type AgentExt = axum::extract::Extension>; +pub(crate) type ApiResult = Result>, axum::http::StatusCode>; + +pub(crate) async fn collect_cursor_async( + mut cursor: mongodb::Cursor, +) -> Vec { + use futures_util::StreamExt; + let mut items = Vec::new(); + while let Some(result) = cursor.next().await { + match result { + Ok(item) => items.push(item), + Err(e) => tracing::warn!("Failed to deserialize document: {e}"), + } + } + items +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + // ── PaginationParams ───────────────────────────────────────── 
+ + #[test] + fn pagination_params_defaults() { + let p: PaginationParams = serde_json::from_str("{}").unwrap(); + assert_eq!(p.page, 1); + assert_eq!(p.limit, 50); + } + + #[test] + fn pagination_params_custom_values() { + let p: PaginationParams = serde_json::from_str(r#"{"page":3,"limit":10}"#).unwrap(); + assert_eq!(p.page, 3); + assert_eq!(p.limit, 10); + } + + #[test] + fn pagination_params_partial_override() { + let p: PaginationParams = serde_json::from_str(r#"{"page":5}"#).unwrap(); + assert_eq!(p.page, 5); + assert_eq!(p.limit, 50); + } + + #[test] + fn pagination_params_zero_page() { + let p: PaginationParams = serde_json::from_str(r#"{"page":0}"#).unwrap(); + assert_eq!(p.page, 0); + } + + // ── FindingsFilter ─────────────────────────────────────────── + + #[test] + fn findings_filter_all_defaults() { + let f: FindingsFilter = serde_json::from_str("{}").unwrap(); + assert!(f.repo_id.is_none()); + assert!(f.severity.is_none()); + assert!(f.scan_type.is_none()); + assert!(f.status.is_none()); + assert!(f.q.is_none()); + assert!(f.sort_by.is_none()); + assert!(f.sort_order.is_none()); + assert_eq!(f.page, 1); + assert_eq!(f.limit, 50); + } + + #[test] + fn findings_filter_with_all_fields() { + let f: FindingsFilter = serde_json::from_str( + r#"{ + "repo_id": "abc", + "severity": "high", + "scan_type": "sast", + "status": "open", + "q": "sql injection", + "sort_by": "severity", + "sort_order": "desc", + "page": 2, + "limit": 25 + }"#, + ) + .unwrap(); + assert_eq!(f.repo_id.as_deref(), Some("abc")); + assert_eq!(f.severity.as_deref(), Some("high")); + assert_eq!(f.scan_type.as_deref(), Some("sast")); + assert_eq!(f.status.as_deref(), Some("open")); + assert_eq!(f.q.as_deref(), Some("sql injection")); + assert_eq!(f.sort_by.as_deref(), Some("severity")); + assert_eq!(f.sort_order.as_deref(), Some("desc")); + assert_eq!(f.page, 2); + assert_eq!(f.limit, 25); + } + + #[test] + fn findings_filter_empty_string_fields() { + let f: FindingsFilter = 
serde_json::from_str(r#"{"repo_id":"","severity":""}"#).unwrap(); + assert_eq!(f.repo_id.as_deref(), Some("")); + assert_eq!(f.severity.as_deref(), Some("")); + } + + // ── ApiResponse ────────────────────────────────────────────── + + #[test] + fn api_response_serializes_with_all_fields() { + let resp = ApiResponse { + data: vec!["a", "b"], + total: Some(100), + page: Some(1), + }; + let v = serde_json::to_value(&resp).unwrap(); + assert_eq!(v["data"], json!(["a", "b"])); + assert_eq!(v["total"], 100); + assert_eq!(v["page"], 1); + } + + #[test] + fn api_response_skips_none_fields() { + let resp = ApiResponse { + data: "hello", + total: None, + page: None, + }; + let v = serde_json::to_value(&resp).unwrap(); + assert_eq!(v["data"], "hello"); + assert!(v.get("total").is_none()); + assert!(v.get("page").is_none()); + } + + #[test] + fn api_response_with_nested_struct() { + #[derive(Serialize)] + struct Item { + id: u32, + } + let resp = ApiResponse { + data: Item { id: 42 }, + total: Some(1), + page: None, + }; + let v = serde_json::to_value(&resp).unwrap(); + assert_eq!(v["data"]["id"], 42); + assert_eq!(v["total"], 1); + assert!(v.get("page").is_none()); + } + + #[test] + fn api_response_empty_vec() { + let resp: ApiResponse> = ApiResponse { + data: vec![], + total: Some(0), + page: Some(1), + }; + let v = serde_json::to_value(&resp).unwrap(); + assert!(v["data"].as_array().unwrap().is_empty()); + } + + // ── SbomFilter ─────────────────────────────────────────────── + + #[test] + fn sbom_filter_defaults() { + let f: SbomFilter = serde_json::from_str("{}").unwrap(); + assert!(f.repo_id.is_none()); + assert!(f.package_manager.is_none()); + assert!(f.q.is_none()); + assert!(f.has_vulns.is_none()); + assert!(f.license.is_none()); + assert_eq!(f.page, 1); + assert_eq!(f.limit, 50); + } + + #[test] + fn sbom_filter_has_vulns_bool() { + let f: SbomFilter = serde_json::from_str(r#"{"has_vulns": true}"#).unwrap(); + assert_eq!(f.has_vulns, Some(true)); + } + + // ── 
SbomExportParams ───────────────────────────────────────── + + #[test] + fn sbom_export_params_default_format() { + let p: SbomExportParams = serde_json::from_str(r#"{"repo_id":"r1"}"#).unwrap(); + assert_eq!(p.repo_id, "r1"); + assert_eq!(p.format, "cyclonedx"); + } + + #[test] + fn sbom_export_params_custom_format() { + let p: SbomExportParams = + serde_json::from_str(r#"{"repo_id":"r1","format":"spdx"}"#).unwrap(); + assert_eq!(p.format, "spdx"); + } + + // ── AddRepositoryRequest ───────────────────────────────────── + + #[test] + fn add_repository_request_defaults() { + let r: AddRepositoryRequest = serde_json::from_str( + r#"{ + "name": "my-repo", + "git_url": "https://github.com/x/y.git" + }"#, + ) + .unwrap(); + assert_eq!(r.name, "my-repo"); + assert_eq!(r.git_url, "https://github.com/x/y.git"); + assert_eq!(r.default_branch, "main"); + assert!(r.auth_token.is_none()); + assert!(r.tracker_type.is_none()); + assert!(r.scan_schedule.is_none()); + } + + #[test] + fn add_repository_request_custom_branch() { + let r: AddRepositoryRequest = serde_json::from_str( + r#"{ + "name": "repo", + "git_url": "url", + "default_branch": "develop" + }"#, + ) + .unwrap(); + assert_eq!(r.default_branch, "develop"); + } + + // ── UpdateStatusRequest / BulkUpdateStatusRequest ──────────── + + #[test] + fn update_status_request() { + let r: UpdateStatusRequest = serde_json::from_str(r#"{"status":"resolved"}"#).unwrap(); + assert_eq!(r.status, "resolved"); + } + + #[test] + fn bulk_update_status_request() { + let r: BulkUpdateStatusRequest = + serde_json::from_str(r#"{"ids":["a","b"],"status":"dismissed"}"#).unwrap(); + assert_eq!(r.ids, vec!["a", "b"]); + assert_eq!(r.status, "dismissed"); + } + + #[test] + fn bulk_update_status_empty_ids() { + let r: BulkUpdateStatusRequest = + serde_json::from_str(r#"{"ids":[],"status":"x"}"#).unwrap(); + assert!(r.ids.is_empty()); + } + + // ── SbomDiffResult serialization ───────────────────────────── + + #[test] + fn 
sbom_diff_result_serializes() { + let r = SbomDiffResult { + only_in_a: vec![SbomDiffEntry { + name: "pkg-a".to_string(), + version: "1.0".to_string(), + package_manager: "npm".to_string(), + }], + only_in_b: vec![], + version_changed: vec![SbomVersionDiff { + name: "shared".to_string(), + package_manager: "cargo".to_string(), + version_a: "0.1".to_string(), + version_b: "0.2".to_string(), + }], + common_count: 10, + }; + let v = serde_json::to_value(&r).unwrap(); + assert_eq!(v["only_in_a"].as_array().unwrap().len(), 1); + assert_eq!(v["only_in_b"].as_array().unwrap().len(), 0); + assert_eq!(v["version_changed"][0]["version_a"], "0.1"); + assert_eq!(v["common_count"], 10); + } + + // ── LicenseSummary ─────────────────────────────────────────── + + #[test] + fn license_summary_serializes() { + let ls = LicenseSummary { + license: "MIT".to_string(), + count: 42, + is_copyleft: false, + packages: vec!["serde".to_string()], + }; + let v = serde_json::to_value(&ls).unwrap(); + assert_eq!(v["license"], "MIT"); + assert_eq!(v["is_copyleft"], false); + assert_eq!(v["count"], 42); + } + + // ── Default helper functions ───────────────────────────────── + + #[test] + fn default_page_returns_1() { + assert_eq!(default_page(), 1); + } + + #[test] + fn default_limit_returns_50() { + assert_eq!(default_limit(), 50); + } +} diff --git a/compliance-agent/src/api/handlers/findings.rs b/compliance-agent/src/api/handlers/findings.rs new file mode 100644 index 0000000..d20a5e9 --- /dev/null +++ b/compliance-agent/src/api/handlers/findings.rs @@ -0,0 +1,172 @@ +use axum::extract::{Extension, Path, Query}; +use axum::http::StatusCode; +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::Finding; + +#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))] +pub async fn list_findings( + Extension(agent): AgentExt, + Query(filter): Query, +) -> ApiResult> { + let db = 
&agent.db; + let mut query = doc! {}; + if let Some(repo_id) = &filter.repo_id { + query.insert("repo_id", repo_id); + } + if let Some(severity) = &filter.severity { + query.insert("severity", severity); + } + if let Some(scan_type) = &filter.scan_type { + query.insert("scan_type", scan_type); + } + if let Some(status) = &filter.status { + query.insert("status", status); + } + // Text search across title, description, file_path, rule_id + if let Some(q) = &filter.q { + if !q.is_empty() { + let regex = doc! { "$regex": q, "$options": "i" }; + query.insert( + "$or", + mongodb::bson::bson!([ + { "title": regex.clone() }, + { "description": regex.clone() }, + { "file_path": regex.clone() }, + { "rule_id": regex }, + ]), + ); + } + } + + // Dynamic sort + let sort_field = filter.sort_by.as_deref().unwrap_or("created_at"); + let sort_dir: i32 = match filter.sort_order.as_deref() { + Some("asc") => 1, + _ => -1, + }; + let sort_doc = doc! { sort_field: sort_dir }; + + let skip = (filter.page.saturating_sub(1)) * filter.limit as u64; + let total = db + .findings() + .count_documents(query.clone()) + .await + .unwrap_or(0); + + let findings = match db + .findings() + .find(query) + .sort(sort_doc) + .skip(skip) + .limit(filter.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch findings: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: findings, + total: Some(total), + page: Some(filter.page), + })) +} + +#[tracing::instrument(skip_all, fields(finding_id = %id))] +pub async fn get_finding( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let finding = agent + .db + .findings() + .find_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 
+ .ok_or(StatusCode::NOT_FOUND)?; + + Ok(Json(ApiResponse { + data: finding, + total: None, + page: None, + })) +} + +#[tracing::instrument(skip_all, fields(finding_id = %id))] +pub async fn update_finding_status( + Extension(agent): AgentExt, + Path(id): Path, + Json(req): Json, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + agent + .db + .findings() + .update_one( + doc! { "_id": oid }, + doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(serde_json::json!({ "status": "updated" }))) +} + +#[tracing::instrument(skip_all)] +pub async fn bulk_update_finding_status( + Extension(agent): AgentExt, + Json(req): Json, +) -> Result, StatusCode> { + let oids: Vec = req + .ids + .iter() + .filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok()) + .collect(); + + if oids.is_empty() { + return Err(StatusCode::BAD_REQUEST); + } + + let result = agent + .db + .findings() + .update_many( + doc! { "_id": { "$in": oids } }, + doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json( + serde_json::json!({ "status": "updated", "modified_count": result.modified_count }), + )) +} + +#[tracing::instrument(skip_all)] +pub async fn update_finding_feedback( + Extension(agent): AgentExt, + Path(id): Path, + Json(req): Json, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + agent + .db + .findings() + .update_one( + doc! { "_id": oid }, + doc! 
{ "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } }, + ) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + Ok(Json(serde_json::json!({ "status": "updated" }))) +} diff --git a/compliance-agent/src/api/handlers/health.rs b/compliance-agent/src/api/handlers/health.rs new file mode 100644 index 0000000..264ae10 --- /dev/null +++ b/compliance-agent/src/api/handlers/health.rs @@ -0,0 +1,84 @@ +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::ScanRun; + +#[tracing::instrument(skip_all)] +pub async fn health() -> Json { + Json(serde_json::json!({ "status": "ok" })) +} + +#[tracing::instrument(skip_all)] +pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult { + let db = &agent.db; + + let total_repositories = db + .repositories() + .count_documents(doc! {}) + .await + .unwrap_or(0); + let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0); + let critical_findings = db + .findings() + .count_documents(doc! { "severity": "critical" }) + .await + .unwrap_or(0); + let high_findings = db + .findings() + .count_documents(doc! { "severity": "high" }) + .await + .unwrap_or(0); + let medium_findings = db + .findings() + .count_documents(doc! { "severity": "medium" }) + .await + .unwrap_or(0); + let low_findings = db + .findings() + .count_documents(doc! { "severity": "low" }) + .await + .unwrap_or(0); + let total_sbom_entries = db + .sbom_entries() + .count_documents(doc! {}) + .await + .unwrap_or(0); + let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0); + let total_issues = db + .tracker_issues() + .count_documents(doc! {}) + .await + .unwrap_or(0); + + let recent_scans: Vec = match db + .scan_runs() + .find(doc! {}) + .sort(doc! 
{ "started_at": -1 }) + .limit(10) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch recent scans: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: OverviewStats { + total_repositories, + total_findings, + critical_findings, + high_findings, + medium_findings, + low_findings, + total_sbom_entries, + total_cve_alerts, + total_issues, + recent_scans, + }, + total: None, + page: None, + })) +} diff --git a/compliance-agent/src/api/handlers/issues.rs b/compliance-agent/src/api/handlers/issues.rs new file mode 100644 index 0000000..d808445 --- /dev/null +++ b/compliance-agent/src/api/handlers/issues.rs @@ -0,0 +1,41 @@ +use axum::extract::{Extension, Query}; +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::TrackerIssue; + +#[tracing::instrument(skip_all)] +pub async fn list_issues( + Extension(agent): AgentExt, + Query(params): Query, +) -> ApiResult> { + let db = &agent.db; + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = db + .tracker_issues() + .count_documents(doc! {}) + .await + .unwrap_or(0); + + let issues = match db + .tracker_issues() + .find(doc! {}) + .sort(doc! 
{ "created_at": -1 }) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch tracker issues: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: issues, + total: Some(total), + page: Some(params.page), + })) +} diff --git a/compliance-agent/src/api/handlers/mod.rs b/compliance-agent/src/api/handlers/mod.rs index 63d64ef..c860312 100644 --- a/compliance-agent/src/api/handlers/mod.rs +++ b/compliance-agent/src/api/handlers/mod.rs @@ -1,1124 +1,21 @@ pub mod chat; pub mod dast; +pub mod dto; +pub mod findings; pub mod graph; -pub mod pentest; - -use std::sync::Arc; - -#[allow(unused_imports)] -use axum::extract::{Extension, Path, Query}; -use axum::http::{header, StatusCode}; -use axum::response::IntoResponse; -use axum::Json; -use mongodb::bson::doc; -use serde::{Deserialize, Serialize}; - -use compliance_core::models::*; - -use crate::agent::ComplianceAgent; - -#[derive(Deserialize)] -pub struct PaginationParams { - #[serde(default = "default_page")] - pub page: u64, - #[serde(default = "default_limit")] - pub limit: i64, -} - -fn default_page() -> u64 { - 1 -} -fn default_limit() -> i64 { - 50 -} - -#[derive(Deserialize)] -pub struct FindingsFilter { - #[serde(default)] - pub repo_id: Option, - #[serde(default)] - pub severity: Option, - #[serde(default)] - pub scan_type: Option, - #[serde(default)] - pub status: Option, - #[serde(default)] - pub q: Option, - #[serde(default)] - pub sort_by: Option, - #[serde(default)] - pub sort_order: Option, - #[serde(default = "default_page")] - pub page: u64, - #[serde(default = "default_limit")] - pub limit: i64, -} - -#[derive(Serialize)] -pub struct ApiResponse { - pub data: T, - #[serde(skip_serializing_if = "Option::is_none")] - pub total: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub page: Option, -} - -#[derive(Serialize)] -pub struct OverviewStats { - pub total_repositories: u64, - pub 
total_findings: u64, - pub critical_findings: u64, - pub high_findings: u64, - pub medium_findings: u64, - pub low_findings: u64, - pub total_sbom_entries: u64, - pub total_cve_alerts: u64, - pub total_issues: u64, - pub recent_scans: Vec, -} - -#[derive(Deserialize)] -pub struct AddRepositoryRequest { - pub name: String, - pub git_url: String, - #[serde(default = "default_branch")] - pub default_branch: String, - pub auth_token: Option, - pub auth_username: Option, - pub tracker_type: Option, - pub tracker_owner: Option, - pub tracker_repo: Option, - pub tracker_token: Option, - pub scan_schedule: Option, -} - -#[derive(Deserialize)] -pub struct UpdateRepositoryRequest { - pub name: Option, - pub default_branch: Option, - pub auth_token: Option, - pub auth_username: Option, - pub tracker_type: Option, - pub tracker_owner: Option, - pub tracker_repo: Option, - pub tracker_token: Option, - pub scan_schedule: Option, -} - -fn default_branch() -> String { - "main".to_string() -} - -#[derive(Deserialize)] -pub struct UpdateStatusRequest { - pub status: String, -} - -#[derive(Deserialize)] -pub struct BulkUpdateStatusRequest { - pub ids: Vec, - pub status: String, -} - -#[derive(Deserialize)] -pub struct UpdateFeedbackRequest { - pub feedback: String, -} - -#[derive(Deserialize)] -pub struct SbomFilter { - #[serde(default)] - pub repo_id: Option, - #[serde(default)] - pub package_manager: Option, - #[serde(default)] - pub q: Option, - #[serde(default)] - pub has_vulns: Option, - #[serde(default)] - pub license: Option, - #[serde(default = "default_page")] - pub page: u64, - #[serde(default = "default_limit")] - pub limit: i64, -} - -#[derive(Deserialize)] -pub struct SbomExportParams { - pub repo_id: String, - #[serde(default = "default_export_format")] - pub format: String, -} - -fn default_export_format() -> String { - "cyclonedx".to_string() -} - -#[derive(Deserialize)] -pub struct SbomDiffParams { - pub repo_a: String, - pub repo_b: String, -} - 
-#[derive(Serialize)] -pub struct LicenseSummary { - pub license: String, - pub count: u64, - pub is_copyleft: bool, - pub packages: Vec, -} - -#[derive(Serialize)] -pub struct SbomDiffResult { - pub only_in_a: Vec, - pub only_in_b: Vec, - pub version_changed: Vec, - pub common_count: u64, -} - -#[derive(Serialize)] -pub struct SbomDiffEntry { - pub name: String, - pub version: String, - pub package_manager: String, -} - -#[derive(Serialize)] -pub struct SbomVersionDiff { - pub name: String, - pub package_manager: String, - pub version_a: String, - pub version_b: String, -} - -type AgentExt = Extension>; -type ApiResult = Result>, StatusCode>; - -#[tracing::instrument(skip_all)] -pub async fn health() -> Json { - Json(serde_json::json!({ "status": "ok" })) -} - -#[tracing::instrument(skip_all)] -pub async fn stats_overview(Extension(agent): AgentExt) -> ApiResult { - let db = &agent.db; - - let total_repositories = db - .repositories() - .count_documents(doc! {}) - .await - .unwrap_or(0); - let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0); - let critical_findings = db - .findings() - .count_documents(doc! { "severity": "critical" }) - .await - .unwrap_or(0); - let high_findings = db - .findings() - .count_documents(doc! { "severity": "high" }) - .await - .unwrap_or(0); - let medium_findings = db - .findings() - .count_documents(doc! { "severity": "medium" }) - .await - .unwrap_or(0); - let low_findings = db - .findings() - .count_documents(doc! { "severity": "low" }) - .await - .unwrap_or(0); - let total_sbom_entries = db - .sbom_entries() - .count_documents(doc! {}) - .await - .unwrap_or(0); - let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0); - let total_issues = db - .tracker_issues() - .count_documents(doc! {}) - .await - .unwrap_or(0); - - let recent_scans: Vec = match db - .scan_runs() - .find(doc! {}) - .sort(doc! 
{ "started_at": -1 }) - .limit(10) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch recent scans: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: OverviewStats { - total_repositories, - total_findings, - critical_findings, - high_findings, - medium_findings, - low_findings, - total_sbom_entries, - total_cve_alerts, - total_issues, - recent_scans, - }, - total: None, - page: None, - })) -} - -#[tracing::instrument(skip_all)] -pub async fn list_repositories( - Extension(agent): AgentExt, - Query(params): Query, -) -> ApiResult> { - let db = &agent.db; - let skip = (params.page.saturating_sub(1)) * params.limit as u64; - let total = db - .repositories() - .count_documents(doc! {}) - .await - .unwrap_or(0); - - let repos = match db - .repositories() - .find(doc! {}) - .skip(skip) - .limit(params.limit) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch repositories: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: repos, - total: Some(total), - page: Some(params.page), - })) -} - -#[tracing::instrument(skip_all)] -pub async fn add_repository( - Extension(agent): AgentExt, - Json(req): Json, -) -> Result>, (StatusCode, String)> { - // Validate repository access before saving - let creds = crate::pipeline::git::RepoCredentials { - ssh_key_path: Some(agent.config.ssh_key_path.clone()), - auth_token: req.auth_token.clone(), - auth_username: req.auth_username.clone(), - }; - - if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) { - return Err(( - StatusCode::BAD_REQUEST, - format!("Cannot access repository: {e}"), - )); - } - - let mut repo = TrackedRepository::new(req.name, req.git_url); - repo.default_branch = req.default_branch; - repo.auth_token = req.auth_token; - repo.auth_username = req.auth_username; - repo.tracker_type = req.tracker_type; - repo.tracker_owner = req.tracker_owner; - 
repo.tracker_repo = req.tracker_repo; - repo.tracker_token = req.tracker_token; - repo.scan_schedule = req.scan_schedule; - - agent - .db - .repositories() - .insert_one(&repo) - .await - .map_err(|_| { - ( - StatusCode::CONFLICT, - "Repository already exists".to_string(), - ) - })?; - - Ok(Json(ApiResponse { - data: repo, - total: None, - page: None, - })) -} - -#[tracing::instrument(skip_all, fields(repo_id = %id))] -pub async fn update_repository( - Extension(agent): AgentExt, - Path(id): Path, - Json(req): Json, -) -> Result, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - - let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() }; - - if let Some(name) = &req.name { - set_doc.insert("name", name); - } - if let Some(branch) = &req.default_branch { - set_doc.insert("default_branch", branch); - } - if let Some(token) = &req.auth_token { - set_doc.insert("auth_token", token); - } - if let Some(username) = &req.auth_username { - set_doc.insert("auth_username", username); - } - if let Some(tracker_type) = &req.tracker_type { - set_doc.insert("tracker_type", tracker_type.to_string()); - } - if let Some(owner) = &req.tracker_owner { - set_doc.insert("tracker_owner", owner); - } - if let Some(repo) = &req.tracker_repo { - set_doc.insert("tracker_repo", repo); - } - if let Some(token) = &req.tracker_token { - set_doc.insert("tracker_token", token); - } - if let Some(schedule) = &req.scan_schedule { - set_doc.insert("scan_schedule", schedule); - } - - let result = agent - .db - .repositories() - .update_one(doc! { "_id": oid }, doc! 
{ "$set": set_doc }) - .await - .map_err(|e| { - tracing::warn!("Failed to update repository: {e}"); - StatusCode::INTERNAL_SERVER_ERROR - })?; - - if result.matched_count == 0 { - return Err(StatusCode::NOT_FOUND); - } - - Ok(Json(serde_json::json!({ "status": "updated" }))) -} - -#[tracing::instrument(skip_all)] -pub async fn get_ssh_public_key( - Extension(agent): AgentExt, -) -> Result, StatusCode> { - let public_path = format!("{}.pub", agent.config.ssh_key_path); - let public_key = std::fs::read_to_string(&public_path).map_err(|_| StatusCode::NOT_FOUND)?; - Ok(Json(serde_json::json!({ "public_key": public_key.trim() }))) -} - -#[tracing::instrument(skip_all, fields(repo_id = %id))] -pub async fn trigger_scan( - Extension(agent): AgentExt, - Path(id): Path, -) -> Result, StatusCode> { - let agent_clone = (*agent).clone(); - tokio::spawn(async move { - if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await { - tracing::error!("Manual scan failed for {id}: {e}"); - } - }); - - Ok(Json(serde_json::json!({ "status": "scan_triggered" }))) -} - -/// Return the webhook secret for a repository (used by dashboard to display it) -pub async fn get_webhook_config( - Extension(agent): AgentExt, - Path(id): Path, -) -> Result, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let repo = agent - .db - .repositories() - .find_one(doc! { "_id": oid }) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 
- .ok_or(StatusCode::NOT_FOUND)?; - - let tracker_type = repo - .tracker_type - .as_ref() - .map(|t| t.to_string()) - .unwrap_or_else(|| "gitea".to_string()); - - Ok(Json(serde_json::json!({ - "webhook_secret": repo.webhook_secret, - "tracker_type": tracker_type, - }))) -} - -#[tracing::instrument(skip_all, fields(repo_id = %id))] -pub async fn delete_repository( - Extension(agent): AgentExt, - Path(id): Path, -) -> Result, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let db = &agent.db; - - // Delete the repository - let result = db - .repositories() - .delete_one(doc! { "_id": oid }) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - if result.deleted_count == 0 { - return Err(StatusCode::NOT_FOUND); - } - - // Cascade delete all related data - let _ = db.findings().delete_many(doc! { "repo_id": &id }).await; - let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await; - let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await; - let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await; - let _ = db - .tracker_issues() - .delete_many(doc! { "repo_id": &id }) - .await; - let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await; - let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await; - let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await; - let _ = db - .impact_analyses() - .delete_many(doc! { "repo_id": &id }) - .await; - let _ = db - .code_embeddings() - .delete_many(doc! { "repo_id": &id }) - .await; - let _ = db - .embedding_builds() - .delete_many(doc! 
{ "repo_id": &id }) - .await; - - Ok(Json(serde_json::json!({ "status": "deleted" }))) -} - -#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))] -pub async fn list_findings( - Extension(agent): AgentExt, - Query(filter): Query, -) -> ApiResult> { - let db = &agent.db; - let mut query = doc! {}; - if let Some(repo_id) = &filter.repo_id { - query.insert("repo_id", repo_id); - } - if let Some(severity) = &filter.severity { - query.insert("severity", severity); - } - if let Some(scan_type) = &filter.scan_type { - query.insert("scan_type", scan_type); - } - if let Some(status) = &filter.status { - query.insert("status", status); - } - // Text search across title, description, file_path, rule_id - if let Some(q) = &filter.q { - if !q.is_empty() { - let regex = doc! { "$regex": q, "$options": "i" }; - query.insert( - "$or", - mongodb::bson::bson!([ - { "title": regex.clone() }, - { "description": regex.clone() }, - { "file_path": regex.clone() }, - { "rule_id": regex }, - ]), - ); - } - } - - // Dynamic sort - let sort_field = filter.sort_by.as_deref().unwrap_or("created_at"); - let sort_dir: i32 = match filter.sort_order.as_deref() { - Some("asc") => 1, - _ => -1, - }; - let sort_doc = doc! 
{ sort_field: sort_dir }; - - let skip = (filter.page.saturating_sub(1)) * filter.limit as u64; - let total = db - .findings() - .count_documents(query.clone()) - .await - .unwrap_or(0); - - let findings = match db - .findings() - .find(query) - .sort(sort_doc) - .skip(skip) - .limit(filter.limit) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch findings: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: findings, - total: Some(total), - page: Some(filter.page), - })) -} - -#[tracing::instrument(skip_all, fields(finding_id = %id))] -pub async fn get_finding( - Extension(agent): AgentExt, - Path(id): Path, -) -> Result>, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let finding = agent - .db - .findings() - .find_one(doc! { "_id": oid }) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? - .ok_or(StatusCode::NOT_FOUND)?; - - Ok(Json(ApiResponse { - data: finding, - total: None, - page: None, - })) -} - -#[tracing::instrument(skip_all, fields(finding_id = %id))] -pub async fn update_finding_status( - Extension(agent): AgentExt, - Path(id): Path, - Json(req): Json, -) -> Result, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - - agent - .db - .findings() - .update_one( - doc! { "_id": oid }, - doc! 
{ "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, - ) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - Ok(Json(serde_json::json!({ "status": "updated" }))) -} - -#[tracing::instrument(skip_all)] -pub async fn bulk_update_finding_status( - Extension(agent): AgentExt, - Json(req): Json, -) -> Result, StatusCode> { - let oids: Vec = req - .ids - .iter() - .filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok()) - .collect(); - - if oids.is_empty() { - return Err(StatusCode::BAD_REQUEST); - } - - let result = agent - .db - .findings() - .update_many( - doc! { "_id": { "$in": oids } }, - doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } }, - ) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - Ok(Json( - serde_json::json!({ "status": "updated", "modified_count": result.modified_count }), - )) -} - -#[tracing::instrument(skip_all)] -pub async fn update_finding_feedback( - Extension(agent): AgentExt, - Path(id): Path, - Json(req): Json, -) -> Result, StatusCode> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - - agent - .db - .findings() - .update_one( - doc! { "_id": oid }, - doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } }, - ) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - Ok(Json(serde_json::json!({ "status": "updated" }))) -} - -#[tracing::instrument(skip_all)] -pub async fn sbom_filters( - Extension(agent): AgentExt, -) -> Result, StatusCode> { - let db = &agent.db; - - let managers: Vec = db - .sbom_entries() - .distinct("package_manager", doc! {}) - .await - .unwrap_or_default() - .into_iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .filter(|s| !s.is_empty() && s != "unknown" && s != "file") - .collect(); - - let licenses: Vec = db - .sbom_entries() - .distinct("license", doc! 
{}) - .await - .unwrap_or_default() - .into_iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .filter(|s| !s.is_empty()) - .collect(); - - Ok(Json(serde_json::json!({ - "package_managers": managers, - "licenses": licenses, - }))) -} - -#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))] -pub async fn list_sbom( - Extension(agent): AgentExt, - Query(filter): Query, -) -> ApiResult> { - let db = &agent.db; - let mut query = doc! {}; - - if let Some(repo_id) = &filter.repo_id { - query.insert("repo_id", repo_id); - } - if let Some(pm) = &filter.package_manager { - query.insert("package_manager", pm); - } - if let Some(q) = &filter.q { - if !q.is_empty() { - query.insert("name", doc! { "$regex": q, "$options": "i" }); - } - } - if let Some(has_vulns) = filter.has_vulns { - if has_vulns { - query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] }); - } else { - query.insert("known_vulnerabilities", doc! { "$size": 0 }); - } - } - if let Some(license) = &filter.license { - query.insert("license", license); - } - - let skip = (filter.page.saturating_sub(1)) * filter.limit as u64; - let total = db - .sbom_entries() - .count_documents(query.clone()) - .await - .unwrap_or(0); - - let entries = match db - .sbom_entries() - .find(query) - .sort(doc! { "name": 1 }) - .skip(skip) - .limit(filter.limit) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch SBOM entries: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: entries, - total: Some(total), - page: Some(filter.page), - })) -} - -#[tracing::instrument(skip_all)] -pub async fn export_sbom( - Extension(agent): AgentExt, - Query(params): Query, -) -> Result { - let db = &agent.db; - let entries: Vec = match db - .sbom_entries() - .find(doc! 
{ "repo_id": ¶ms.repo_id }) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch SBOM entries for export: {e}"); - Vec::new() - } - }; - - let body = if params.format == "spdx" { - // SPDX 2.3 format - let packages: Vec = entries - .iter() - .enumerate() - .map(|(i, e)| { - serde_json::json!({ - "SPDXID": format!("SPDXRef-Package-{i}"), - "name": e.name, - "versionInfo": e.version, - "downloadLocation": "NOASSERTION", - "licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"), - "externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({ - "referenceCategory": "PACKAGE-MANAGER", - "referenceType": "purl", - "referenceLocator": p, - })]).unwrap_or_default(), - }) - }) - .collect(); - - serde_json::json!({ - "spdxVersion": "SPDX-2.3", - "dataLicense": "CC0-1.0", - "SPDXID": "SPDXRef-DOCUMENT", - "name": format!("sbom-{}", params.repo_id), - "documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id), - "packages": packages, - }) - } else { - // CycloneDX 1.5 format - let components: Vec = entries - .iter() - .map(|e| { - let mut comp = serde_json::json!({ - "type": "library", - "name": e.name, - "version": e.version, - "group": e.package_manager, - }); - if let Some(purl) = &e.purl { - comp["purl"] = serde_json::Value::String(purl.clone()); - } - if let Some(license) = &e.license { - comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]); - } - if !e.known_vulnerabilities.is_empty() { - comp["vulnerabilities"] = serde_json::json!( - e.known_vulnerabilities.iter().map(|v| serde_json::json!({ - "id": v.id, - "source": { "name": v.source }, - "ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(), - })).collect::>() - ); - } - comp - }) - .collect(); - - serde_json::json!({ - "bomFormat": "CycloneDX", - "specVersion": "1.5", - "version": 1, - "metadata": { - "component": { - "type": "application", - "name": 
format!("repo-{}", params.repo_id), - } - }, - "components": components, - }) - }; - - let json_str = - serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - let filename = if params.format == "spdx" { - format!("sbom-{}-spdx.json", params.repo_id) - } else { - format!("sbom-{}-cyclonedx.json", params.repo_id) - }; - - let disposition = format!("attachment; filename=\"{filename}\""); - Ok(( - [ - ( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ), - ( - header::CONTENT_DISPOSITION, - header::HeaderValue::from_str(&disposition) - .unwrap_or_else(|_| header::HeaderValue::from_static("attachment")), - ), - ], - json_str, - )) -} - -const COPYLEFT_LICENSES: &[&str] = &[ - "GPL-2.0", - "GPL-2.0-only", - "GPL-2.0-or-later", - "GPL-3.0", - "GPL-3.0-only", - "GPL-3.0-or-later", - "AGPL-3.0", - "AGPL-3.0-only", - "AGPL-3.0-or-later", - "LGPL-2.1", - "LGPL-2.1-only", - "LGPL-2.1-or-later", - "LGPL-3.0", - "LGPL-3.0-only", - "LGPL-3.0-or-later", - "MPL-2.0", -]; - -#[tracing::instrument(skip_all)] -pub async fn license_summary( - Extension(agent): AgentExt, - Query(params): Query, -) -> ApiResult> { - let db = &agent.db; - let mut query = doc! 
{}; - if let Some(repo_id) = ¶ms.repo_id { - query.insert("repo_id", repo_id); - } - - let entries: Vec = match db.sbom_entries().find(query).await { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch SBOM entries for license summary: {e}"); - Vec::new() - } - }; - - let mut license_map: std::collections::HashMap> = - std::collections::HashMap::new(); - for entry in &entries { - let lic = entry.license.as_deref().unwrap_or("Unknown").to_string(); - license_map.entry(lic).or_default().push(entry.name.clone()); - } - - let mut summaries: Vec = license_map - .into_iter() - .map(|(license, packages)| { - let is_copyleft = COPYLEFT_LICENSES - .iter() - .any(|c| license.to_uppercase().contains(&c.to_uppercase())); - LicenseSummary { - license, - count: packages.len() as u64, - is_copyleft, - packages, - } - }) - .collect(); - summaries.sort_by(|a, b| b.count.cmp(&a.count)); - - Ok(Json(ApiResponse { - data: summaries, - total: None, - page: None, - })) -} - -#[tracing::instrument(skip_all)] -pub async fn sbom_diff( - Extension(agent): AgentExt, - Query(params): Query, -) -> ApiResult { - let db = &agent.db; - - let entries_a: Vec = match db - .sbom_entries() - .find(doc! { "repo_id": ¶ms.repo_a }) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}"); - Vec::new() - } - }; - - let entries_b: Vec = match db - .sbom_entries() - .find(doc! 
{ "repo_id": ¶ms.repo_b }) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}"); - Vec::new() - } - }; - - // Build maps by (name, package_manager) -> version - let map_a: std::collections::HashMap<(String, String), String> = entries_a - .iter() - .map(|e| { - ( - (e.name.clone(), e.package_manager.clone()), - e.version.clone(), - ) - }) - .collect(); - let map_b: std::collections::HashMap<(String, String), String> = entries_b - .iter() - .map(|e| { - ( - (e.name.clone(), e.package_manager.clone()), - e.version.clone(), - ) - }) - .collect(); - - let mut only_in_a = Vec::new(); - let mut version_changed = Vec::new(); - let mut common_count: u64 = 0; - - for (key, ver_a) in &map_a { - match map_b.get(key) { - None => only_in_a.push(SbomDiffEntry { - name: key.0.clone(), - version: ver_a.clone(), - package_manager: key.1.clone(), - }), - Some(ver_b) if ver_a != ver_b => { - version_changed.push(SbomVersionDiff { - name: key.0.clone(), - package_manager: key.1.clone(), - version_a: ver_a.clone(), - version_b: ver_b.clone(), - }); - } - Some(_) => common_count += 1, - } - } - - let only_in_b: Vec = map_b - .iter() - .filter(|(key, _)| !map_a.contains_key(key)) - .map(|(key, ver)| SbomDiffEntry { - name: key.0.clone(), - version: ver.clone(), - package_manager: key.1.clone(), - }) - .collect(); - - Ok(Json(ApiResponse { - data: SbomDiffResult { - only_in_a, - only_in_b, - version_changed, - common_count, - }, - total: None, - page: None, - })) -} - -#[tracing::instrument(skip_all)] -pub async fn list_issues( - Extension(agent): AgentExt, - Query(params): Query, -) -> ApiResult> { - let db = &agent.db; - let skip = (params.page.saturating_sub(1)) * params.limit as u64; - let total = db - .tracker_issues() - .count_documents(doc! {}) - .await - .unwrap_or(0); - - let issues = match db - .tracker_issues() - .find(doc! {}) - .sort(doc! 
{ "created_at": -1 }) - .skip(skip) - .limit(params.limit) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch tracker issues: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: issues, - total: Some(total), - page: Some(params.page), - })) -} - -#[tracing::instrument(skip_all)] -pub async fn list_scan_runs( - Extension(agent): AgentExt, - Query(params): Query, -) -> ApiResult> { - let db = &agent.db; - let skip = (params.page.saturating_sub(1)) * params.limit as u64; - let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0); - - let scans = match db - .scan_runs() - .find(doc! {}) - .sort(doc! { "started_at": -1 }) - .skip(skip) - .limit(params.limit) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(e) => { - tracing::warn!("Failed to fetch scan runs: {e}"); - Vec::new() - } - }; - - Ok(Json(ApiResponse { - data: scans, - total: Some(total), - page: Some(params.page), - })) -} - -pub(crate) async fn collect_cursor_async( - mut cursor: mongodb::Cursor, -) -> Vec { - use futures_util::StreamExt; - let mut items = Vec::new(); - while let Some(result) = cursor.next().await { - match result { - Ok(item) => items.push(item), - Err(e) => tracing::warn!("Failed to deserialize document: {e}"), - } - } - items -} +pub mod health; +pub mod issues; +pub mod pentest_handlers; +pub use pentest_handlers as pentest; +pub mod repos; +pub mod sbom; +pub mod scans; + +// Re-export all handler functions so routes.rs can use `handlers::function_name` +pub use dto::*; +pub use findings::*; +pub use health::*; +pub use issues::*; +pub use repos::*; +pub use sbom::*; +pub use scans::*; diff --git a/compliance-agent/src/api/handlers/pentest_handlers/export.rs b/compliance-agent/src/api/handlers/pentest_handlers/export.rs new file mode 100644 index 0000000..e4396c4 --- /dev/null +++ b/compliance-agent/src/api/handlers/pentest_handlers/export.rs @@ -0,0 +1,131 @@ +use 
std::sync::Arc; + +use axum::extract::{Extension, Path}; +use axum::http::StatusCode; +use axum::response::IntoResponse; +use axum::Json; +use mongodb::bson::doc; +use serde::Deserialize; + +use compliance_core::models::dast::DastFinding; +use compliance_core::models::pentest::*; + +use crate::agent::ComplianceAgent; + +use super::super::dto::collect_cursor_async; + +type AgentExt = Extension>; + +#[derive(Deserialize)] +pub struct ExportBody { + pub password: String, + /// Requester display name (from auth) + #[serde(default)] + pub requester_name: String, + /// Requester email (from auth) + #[serde(default)] + pub requester_email: String, +} + +/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn export_session_report( + Extension(agent): AgentExt, + Path(id): Path, + Json(body): Json, +) -> Result { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id) + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + + if body.password.len() < 8 { + return Err(( + StatusCode::BAD_REQUEST, + "Password must be at least 8 characters".to_string(), + )); + } + + // Fetch session + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; + + // Resolve target name + let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) { + agent + .db + .dast_targets() + .find_one(doc! 
{ "_id": tid }) + .await + .ok() + .flatten() + } else { + None + }; + let target_name = target + .as_ref() + .map(|t| t.name.clone()) + .unwrap_or_else(|| "Unknown Target".to_string()); + let target_url = target + .as_ref() + .map(|t| t.base_url.clone()) + .unwrap_or_default(); + + // Fetch attack chain nodes + let nodes: Vec = match agent + .db + .attack_chain_nodes() + .find(doc! { "session_id": &id }) + .sort(doc! { "started_at": 1 }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + // Fetch DAST findings for this session + let findings: Vec = match agent + .db + .dast_findings() + .find(doc! { "session_id": &id }) + .sort(doc! { "severity": -1, "created_at": -1 }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + let ctx = crate::pentest::report::ReportContext { + session, + target_name, + target_url, + findings, + attack_chain: nodes, + requester_name: if body.requester_name.is_empty() { + "Unknown".to_string() + } else { + body.requester_name + }, + requester_email: body.requester_email, + }; + + let report = crate::pentest::generate_encrypted_report(&ctx, &body.password) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let response = serde_json::json!({ + "archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive), + "sha256": report.sha256, + "filename": format!("pentest-report-{id}.zip"), + }); + + Ok(Json(response).into_response()) +} diff --git a/compliance-agent/src/api/handlers/pentest_handlers/mod.rs b/compliance-agent/src/api/handlers/pentest_handlers/mod.rs new file mode 100644 index 0000000..c444890 --- /dev/null +++ b/compliance-agent/src/api/handlers/pentest_handlers/mod.rs @@ -0,0 +1,9 @@ +mod export; +mod session; +mod stats; +mod stream; + +pub use export::*; +pub use session::*; +pub use stats::*; +pub use stream::*; diff --git a/compliance-agent/src/api/handlers/pentest.rs 
b/compliance-agent/src/api/handlers/pentest_handlers/session.rs similarity index 53% rename from compliance-agent/src/api/handlers/pentest.rs rename to compliance-agent/src/api/handlers/pentest_handlers/session.rs index 57ecf4e..c768625 100644 --- a/compliance-agent/src/api/handlers/pentest.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/session.rs @@ -2,20 +2,16 @@ use std::sync::Arc; use axum::extract::{Extension, Path, Query}; use axum::http::StatusCode; -use axum::response::sse::{Event, Sse}; -use axum::response::IntoResponse; use axum::Json; -use futures_util::stream; use mongodb::bson::doc; use serde::Deserialize; -use compliance_core::models::dast::DastFinding; use compliance_core::models::pentest::*; use crate::agent::ComplianceAgent; use crate::pentest::PentestOrchestrator; -use super::{collect_cursor_async, ApiResponse, PaginationParams}; +use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams}; type AgentExt = Extension>; @@ -160,8 +156,7 @@ pub async fn get_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, StatusCode> { - let oid = - mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let session = agent .db @@ -210,13 +205,12 @@ pub async fn send_message( } // Look up the target - let target_oid = - mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Invalid target_id in session".to_string(), - ) - })?; + let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Invalid target_id in session".to_string(), + ) + })?; let target = agent .db @@ -261,106 +255,6 @@ pub async fn send_message( })) } -/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events -/// -/// Returns recent messages as SSE events (polling approach). 
-/// True real-time streaming with broadcast channels will be added in a future iteration. -#[tracing::instrument(skip_all, fields(session_id = %id))] -pub async fn session_stream( - Extension(agent): AgentExt, - Path(id): Path, -) -> Result>>, StatusCode> -{ - let oid = - mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - - // Verify session exists - let _session = agent - .db - .pentest_sessions() - .find_one(doc! { "_id": oid }) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? - .ok_or(StatusCode::NOT_FOUND)?; - - // Fetch recent messages for this session - let messages: Vec = match agent - .db - .pentest_messages() - .find(doc! { "session_id": &id }) - .sort(doc! { "created_at": 1 }) - .limit(100) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(_) => Vec::new(), - }; - - // Fetch recent attack chain nodes - let nodes: Vec = match agent - .db - .attack_chain_nodes() - .find(doc! { "session_id": &id }) - .sort(doc! { "started_at": 1 }) - .limit(100) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(_) => Vec::new(), - }; - - // Build SSE events from stored data - let mut events: Vec> = Vec::new(); - - for msg in &messages { - let event_data = serde_json::json!({ - "type": "message", - "role": msg.role, - "content": msg.content, - "created_at": msg.created_at.to_rfc3339(), - }); - if let Ok(data) = serde_json::to_string(&event_data) { - events.push(Ok(Event::default().event("message").data(data))); - } - } - - for node in &nodes { - let event_data = serde_json::json!({ - "type": "tool_execution", - "node_id": node.node_id, - "tool_name": node.tool_name, - "status": node.status, - "findings_produced": node.findings_produced, - }); - if let Ok(data) = serde_json::to_string(&event_data) { - events.push(Ok(Event::default().event("tool").data(data))); - } - } - - // Add session status event - let session = agent - .db - .pentest_sessions() - .find_one(doc! 
{ "_id": oid }) - .await - .ok() - .flatten(); - - if let Some(s) = session { - let status_data = serde_json::json!({ - "type": "status", - "status": s.status, - "findings_count": s.findings_count, - "tool_invocations": s.tool_invocations, - }); - if let Ok(data) = serde_json::to_string(&status_data) { - events.push(Ok(Event::default().event("status").data(data))); - } - } - - Ok(Sse::new(stream::iter(events))) -} - /// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn stop_session( @@ -375,7 +269,12 @@ pub async fn stop_session( .pentest_sessions() .find_one(doc! { "_id": oid }) .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))? + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running { @@ -397,15 +296,30 @@ pub async fn stop_session( }}, ) .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?; + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })?; let updated = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))? - .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found after update".to_string()))?; + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? 
+ .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + "Session not found after update".to_string(), + ) + })?; Ok(Json(ApiResponse { data: updated, @@ -420,9 +334,7 @@ pub async fn get_attack_chain( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { - // Verify the session ID is valid - let _oid = - mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let nodes = match agent .db @@ -453,8 +365,7 @@ pub async fn get_messages( Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { - let _oid = - mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent @@ -487,95 +398,14 @@ pub async fn get_messages( })) } -/// GET /api/v1/pentest/stats — Aggregated pentest statistics -#[tracing::instrument(skip_all)] -pub async fn pentest_stats( - Extension(agent): AgentExt, -) -> Result>, StatusCode> { - let db = &agent.db; - - let running_sessions = db - .pentest_sessions() - .count_documents(doc! { "status": "running" }) - .await - .unwrap_or(0) as u32; - - // Count DAST findings from pentest sessions - let total_vulnerabilities = db - .dast_findings() - .count_documents(doc! { "session_id": { "$exists": true, "$ne": null } }) - .await - .unwrap_or(0) as u32; - - // Aggregate tool invocations from all sessions - let sessions: Vec = match db.pentest_sessions().find(doc! 
{}).await { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(_) => Vec::new(), - }; - - let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum(); - let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum(); - let tool_success_rate = if total_tool_invocations == 0 { - 100.0 - } else { - (total_successes as f64 / total_tool_invocations as f64) * 100.0 - }; - - // Severity distribution from pentest-related DAST findings - let critical = db - .dast_findings() - .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }) - .await - .unwrap_or(0) as u32; - let high = db - .dast_findings() - .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" }) - .await - .unwrap_or(0) as u32; - let medium = db - .dast_findings() - .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" }) - .await - .unwrap_or(0) as u32; - let low = db - .dast_findings() - .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" }) - .await - .unwrap_or(0) as u32; - let info = db - .dast_findings() - .count_documents(doc! 
{ "session_id": { "$exists": true, "$ne": null }, "severity": "info" }) - .await - .unwrap_or(0) as u32; - - Ok(Json(ApiResponse { - data: PentestStats { - running_sessions, - total_vulnerabilities, - total_tool_invocations, - tool_success_rate, - severity_distribution: SeverityDistribution { - critical, - high, - medium, - low, - info, - }, - }, - total: None, - page: None, - })) -} - /// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session_findings( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, -) -> Result>>, StatusCode> { - let _oid = - mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; +) -> Result>>, StatusCode> { + let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent @@ -607,112 +437,3 @@ pub async fn get_session_findings( page: Some(params.page), })) } - -#[derive(Deserialize)] -pub struct ExportBody { - pub password: String, - /// Requester display name (from auth) - #[serde(default)] - pub requester_name: String, - /// Requester email (from auth) - #[serde(default)] - pub requester_email: String, -} - -/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive -#[tracing::instrument(skip_all, fields(session_id = %id))] -pub async fn export_session_report( - Extension(agent): AgentExt, - Path(id): Path, - Json(body): Json, -) -> Result { - let oid = mongodb::bson::oid::ObjectId::parse_str(&id) - .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; - - if body.password.len() < 8 { - return Err(( - StatusCode::BAD_REQUEST, - "Password must be at least 8 characters".to_string(), - )); - } - - // Fetch session - let session = agent - .db - .pentest_sessions() - .find_one(doc! 
{ "_id": oid }) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))? - .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; - - // Resolve target name - let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) { - agent - .db - .dast_targets() - .find_one(doc! { "_id": tid }) - .await - .ok() - .flatten() - } else { - None - }; - let target_name = target - .as_ref() - .map(|t| t.name.clone()) - .unwrap_or_else(|| "Unknown Target".to_string()); - let target_url = target - .as_ref() - .map(|t| t.base_url.clone()) - .unwrap_or_default(); - - // Fetch attack chain nodes - let nodes: Vec = match agent - .db - .attack_chain_nodes() - .find(doc! { "session_id": &id }) - .sort(doc! { "started_at": 1 }) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(_) => Vec::new(), - }; - - // Fetch DAST findings for this session - let findings: Vec = match agent - .db - .dast_findings() - .find(doc! { "session_id": &id }) - .sort(doc! 
{ "severity": -1, "created_at": -1 }) - .await - { - Ok(cursor) => collect_cursor_async(cursor).await, - Err(_) => Vec::new(), - }; - - let ctx = crate::pentest::report::ReportContext { - session, - target_name, - target_url, - findings, - attack_chain: nodes, - requester_name: if body.requester_name.is_empty() { - "Unknown".to_string() - } else { - body.requester_name - }, - requester_email: body.requester_email, - }; - - let report = crate::pentest::generate_encrypted_report(&ctx, &body.password) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; - - let response = serde_json::json!({ - "archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive), - "sha256": report.sha256, - "filename": format!("pentest-report-{id}.zip"), - }); - - Ok(Json(response).into_response()) -} diff --git a/compliance-agent/src/api/handlers/pentest_handlers/stats.rs b/compliance-agent/src/api/handlers/pentest_handlers/stats.rs new file mode 100644 index 0000000..6333408 --- /dev/null +++ b/compliance-agent/src/api/handlers/pentest_handlers/stats.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use axum::extract::Extension; +use axum::http::StatusCode; +use axum::Json; +use mongodb::bson::doc; + +use compliance_core::models::pentest::*; + +use crate::agent::ComplianceAgent; + +use super::super::dto::{collect_cursor_async, ApiResponse}; + +type AgentExt = Extension>; + +/// GET /api/v1/pentest/stats — Aggregated pentest statistics +#[tracing::instrument(skip_all)] +pub async fn pentest_stats( + Extension(agent): AgentExt, +) -> Result>, StatusCode> { + let db = &agent.db; + + let running_sessions = db + .pentest_sessions() + .count_documents(doc! { "status": "running" }) + .await + .unwrap_or(0) as u32; + + // Count DAST findings from pentest sessions + let total_vulnerabilities = db + .dast_findings() + .count_documents(doc! 
{ "session_id": { "$exists": true, "$ne": null } }) + .await + .unwrap_or(0) as u32; + + // Aggregate tool invocations from all sessions + let sessions: Vec = match db.pentest_sessions().find(doc! {}).await { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum(); + let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum(); + let tool_success_rate = if total_tool_invocations == 0 { + 100.0 + } else { + (total_successes as f64 / total_tool_invocations as f64) * 100.0 + }; + + // Severity distribution from pentest-related DAST findings + let critical = db + .dast_findings() + .count_documents( + doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }, + ) + .await + .unwrap_or(0) as u32; + let high = db + .dast_findings() + .count_documents( + doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" }, + ) + .await + .unwrap_or(0) as u32; + let medium = db + .dast_findings() + .count_documents( + doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" }, + ) + .await + .unwrap_or(0) as u32; + let low = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" }) + .await + .unwrap_or(0) as u32; + let info = db + .dast_findings() + .count_documents( + doc! 
{ "session_id": { "$exists": true, "$ne": null }, "severity": "info" }, + ) + .await + .unwrap_or(0) as u32; + + Ok(Json(ApiResponse { + data: PentestStats { + running_sessions, + total_vulnerabilities, + total_tool_invocations, + tool_success_rate, + severity_distribution: SeverityDistribution { + critical, + high, + medium, + low, + info, + }, + }, + total: None, + page: None, + })) +} diff --git a/compliance-agent/src/api/handlers/pentest_handlers/stream.rs b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs new file mode 100644 index 0000000..aa29cab --- /dev/null +++ b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs @@ -0,0 +1,116 @@ +use std::sync::Arc; + +use axum::extract::{Extension, Path}; +use axum::http::StatusCode; +use axum::response::sse::{Event, Sse}; +use futures_util::stream; +use mongodb::bson::doc; + +use compliance_core::models::pentest::*; + +use crate::agent::ComplianceAgent; + +use super::super::dto::collect_cursor_async; + +type AgentExt = Extension>; + +/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events +/// +/// Returns recent messages as SSE events (polling approach). +/// True real-time streaming with broadcast channels will be added in a future iteration. +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn session_stream( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result< + Sse>>, + StatusCode, +> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + // Verify session exists + let _session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; + + // Fetch recent messages for this session + let messages: Vec = match agent + .db + .pentest_messages() + .find(doc! { "session_id": &id }) + .sort(doc! 
{ "created_at": 1 }) + .limit(100) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + // Fetch recent attack chain nodes + let nodes: Vec = match agent + .db + .attack_chain_nodes() + .find(doc! { "session_id": &id }) + .sort(doc! { "started_at": 1 }) + .limit(100) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + // Build SSE events from stored data + let mut events: Vec> = Vec::new(); + + for msg in &messages { + let event_data = serde_json::json!({ + "type": "message", + "role": msg.role, + "content": msg.content, + "created_at": msg.created_at.to_rfc3339(), + }); + if let Ok(data) = serde_json::to_string(&event_data) { + events.push(Ok(Event::default().event("message").data(data))); + } + } + + for node in &nodes { + let event_data = serde_json::json!({ + "type": "tool_execution", + "node_id": node.node_id, + "tool_name": node.tool_name, + "status": node.status, + "findings_produced": node.findings_produced, + }); + if let Ok(data) = serde_json::to_string(&event_data) { + events.push(Ok(Event::default().event("tool").data(data))); + } + } + + // Add session status event + let session = agent + .db + .pentest_sessions() + .find_one(doc! 
{ "_id": oid }) + .await + .ok() + .flatten(); + + if let Some(s) = session { + let status_data = serde_json::json!({ + "type": "status", + "status": s.status, + "findings_count": s.findings_count, + "tool_invocations": s.tool_invocations, + }); + if let Ok(data) = serde_json::to_string(&status_data) { + events.push(Ok(Event::default().event("status").data(data))); + } + } + + Ok(Sse::new(stream::iter(events))) +} diff --git a/compliance-agent/src/api/handlers/repos.rs b/compliance-agent/src/api/handlers/repos.rs new file mode 100644 index 0000000..7dfd77b --- /dev/null +++ b/compliance-agent/src/api/handlers/repos.rs @@ -0,0 +1,241 @@ +use axum::extract::{Extension, Path, Query}; +use axum::http::StatusCode; +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::*; + +#[tracing::instrument(skip_all)] +pub async fn list_repositories( + Extension(agent): AgentExt, + Query(params): Query, +) -> ApiResult> { + let db = &agent.db; + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = db + .repositories() + .count_documents(doc! {}) + .await + .unwrap_or(0); + + let repos = match db + .repositories() + .find(doc! 
{}) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch repositories: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: repos, + total: Some(total), + page: Some(params.page), + })) +} + +#[tracing::instrument(skip_all)] +pub async fn add_repository( + Extension(agent): AgentExt, + Json(req): Json, +) -> Result>, (StatusCode, String)> { + // Validate repository access before saving + let creds = crate::pipeline::git::RepoCredentials { + ssh_key_path: Some(agent.config.ssh_key_path.clone()), + auth_token: req.auth_token.clone(), + auth_username: req.auth_username.clone(), + }; + + if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) { + return Err(( + StatusCode::BAD_REQUEST, + format!("Cannot access repository: {e}"), + )); + } + + let mut repo = TrackedRepository::new(req.name, req.git_url); + repo.default_branch = req.default_branch; + repo.auth_token = req.auth_token; + repo.auth_username = req.auth_username; + repo.tracker_type = req.tracker_type; + repo.tracker_owner = req.tracker_owner; + repo.tracker_repo = req.tracker_repo; + repo.tracker_token = req.tracker_token; + repo.scan_schedule = req.scan_schedule; + + agent + .db + .repositories() + .insert_one(&repo) + .await + .map_err(|_| { + ( + StatusCode::CONFLICT, + "Repository already exists".to_string(), + ) + })?; + + Ok(Json(ApiResponse { + data: repo, + total: None, + page: None, + })) +} + +#[tracing::instrument(skip_all, fields(repo_id = %id))] +pub async fn update_repository( + Extension(agent): AgentExt, + Path(id): Path, + Json(req): Json, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + let mut set_doc = doc! 
{ "updated_at": mongodb::bson::DateTime::now() }; + + if let Some(name) = &req.name { + set_doc.insert("name", name); + } + if let Some(branch) = &req.default_branch { + set_doc.insert("default_branch", branch); + } + if let Some(token) = &req.auth_token { + set_doc.insert("auth_token", token); + } + if let Some(username) = &req.auth_username { + set_doc.insert("auth_username", username); + } + if let Some(tracker_type) = &req.tracker_type { + set_doc.insert("tracker_type", tracker_type.to_string()); + } + if let Some(owner) = &req.tracker_owner { + set_doc.insert("tracker_owner", owner); + } + if let Some(repo) = &req.tracker_repo { + set_doc.insert("tracker_repo", repo); + } + if let Some(token) = &req.tracker_token { + set_doc.insert("tracker_token", token); + } + if let Some(schedule) = &req.scan_schedule { + set_doc.insert("scan_schedule", schedule); + } + + let result = agent + .db + .repositories() + .update_one(doc! { "_id": oid }, doc! { "$set": set_doc }) + .await + .map_err(|e| { + tracing::warn!("Failed to update repository: {e}"); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if result.matched_count == 0 { + return Err(StatusCode::NOT_FOUND); + } + + Ok(Json(serde_json::json!({ "status": "updated" }))) +} + +#[tracing::instrument(skip_all)] +pub async fn get_ssh_public_key( + Extension(agent): AgentExt, +) -> Result, StatusCode> { + let public_path = format!("{}.pub", agent.config.ssh_key_path); + let public_key = std::fs::read_to_string(&public_path).map_err(|_| StatusCode::NOT_FOUND)?; + Ok(Json(serde_json::json!({ "public_key": public_key.trim() }))) +} + +#[tracing::instrument(skip_all, fields(repo_id = %id))] +pub async fn trigger_scan( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result, StatusCode> { + let agent_clone = (*agent).clone(); + tokio::spawn(async move { + if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await { + tracing::error!("Manual scan failed for {id}: {e}"); + } + }); + + Ok(Json(serde_json::json!({ 
"status": "scan_triggered" }))) +} + +/// Return the webhook secret for a repository (used by dashboard to display it) +pub async fn get_webhook_config( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let repo = agent + .db + .repositories() + .find_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; + + let tracker_type = repo + .tracker_type + .as_ref() + .map(|t| t.to_string()) + .unwrap_or_else(|| "gitea".to_string()); + + Ok(Json(serde_json::json!({ + "webhook_secret": repo.webhook_secret, + "tracker_type": tracker_type, + }))) +} + +#[tracing::instrument(skip_all, fields(repo_id = %id))] +pub async fn delete_repository( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result, StatusCode> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + let db = &agent.db; + + // Delete the repository + let result = db + .repositories() + .delete_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if result.deleted_count == 0 { + return Err(StatusCode::NOT_FOUND); + } + + // Cascade delete all related data + let _ = db.findings().delete_many(doc! { "repo_id": &id }).await; + let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await; + let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await; + let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await; + let _ = db + .tracker_issues() + .delete_many(doc! { "repo_id": &id }) + .await; + let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await; + let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await; + let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await; + let _ = db + .impact_analyses() + .delete_many(doc! 
{ "repo_id": &id }) + .await; + let _ = db + .code_embeddings() + .delete_many(doc! { "repo_id": &id }) + .await; + let _ = db + .embedding_builds() + .delete_many(doc! { "repo_id": &id }) + .await; + + Ok(Json(serde_json::json!({ "status": "deleted" }))) +} diff --git a/compliance-agent/src/api/handlers/sbom.rs b/compliance-agent/src/api/handlers/sbom.rs new file mode 100644 index 0000000..e9ec8ff --- /dev/null +++ b/compliance-agent/src/api/handlers/sbom.rs @@ -0,0 +1,379 @@ +use axum::extract::{Extension, Query}; +use axum::http::{header, StatusCode}; +use axum::response::IntoResponse; +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::SbomEntry; + +const COPYLEFT_LICENSES: &[&str] = &[ + "GPL-2.0", + "GPL-2.0-only", + "GPL-2.0-or-later", + "GPL-3.0", + "GPL-3.0-only", + "GPL-3.0-or-later", + "AGPL-3.0", + "AGPL-3.0-only", + "AGPL-3.0-or-later", + "LGPL-2.1", + "LGPL-2.1-only", + "LGPL-2.1-or-later", + "LGPL-3.0", + "LGPL-3.0-only", + "LGPL-3.0-or-later", + "MPL-2.0", +]; + +#[tracing::instrument(skip_all)] +pub async fn sbom_filters( + Extension(agent): AgentExt, +) -> Result, StatusCode> { + let db = &agent.db; + + let managers: Vec = db + .sbom_entries() + .distinct("package_manager", doc! {}) + .await + .unwrap_or_default() + .into_iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .filter(|s| !s.is_empty() && s != "unknown" && s != "file") + .collect(); + + let licenses: Vec = db + .sbom_entries() + .distinct("license", doc! 
{}) + .await + .unwrap_or_default() + .into_iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .filter(|s| !s.is_empty()) + .collect(); + + Ok(Json(serde_json::json!({ + "package_managers": managers, + "licenses": licenses, + }))) +} + +#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))] +pub async fn list_sbom( + Extension(agent): AgentExt, + Query(filter): Query, +) -> ApiResult> { + let db = &agent.db; + let mut query = doc! {}; + + if let Some(repo_id) = &filter.repo_id { + query.insert("repo_id", repo_id); + } + if let Some(pm) = &filter.package_manager { + query.insert("package_manager", pm); + } + if let Some(q) = &filter.q { + if !q.is_empty() { + query.insert("name", doc! { "$regex": q, "$options": "i" }); + } + } + if let Some(has_vulns) = filter.has_vulns { + if has_vulns { + query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] }); + } else { + query.insert("known_vulnerabilities", doc! { "$size": 0 }); + } + } + if let Some(license) = &filter.license { + query.insert("license", license); + } + + let skip = (filter.page.saturating_sub(1)) * filter.limit as u64; + let total = db + .sbom_entries() + .count_documents(query.clone()) + .await + .unwrap_or(0); + + let entries = match db + .sbom_entries() + .find(query) + .sort(doc! { "name": 1 }) + .skip(skip) + .limit(filter.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch SBOM entries: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: entries, + total: Some(total), + page: Some(filter.page), + })) +} + +#[tracing::instrument(skip_all)] +pub async fn export_sbom( + Extension(agent): AgentExt, + Query(params): Query, +) -> Result { + let db = &agent.db; + let entries: Vec = match db + .sbom_entries() + .find(doc! 
{ "repo_id": ¶ms.repo_id }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch SBOM entries for export: {e}"); + Vec::new() + } + }; + + let body = if params.format == "spdx" { + // SPDX 2.3 format + let packages: Vec = entries + .iter() + .enumerate() + .map(|(i, e)| { + serde_json::json!({ + "SPDXID": format!("SPDXRef-Package-{i}"), + "name": e.name, + "versionInfo": e.version, + "downloadLocation": "NOASSERTION", + "licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"), + "externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({ + "referenceCategory": "PACKAGE-MANAGER", + "referenceType": "purl", + "referenceLocator": p, + })]).unwrap_or_default(), + }) + }) + .collect(); + + serde_json::json!({ + "spdxVersion": "SPDX-2.3", + "dataLicense": "CC0-1.0", + "SPDXID": "SPDXRef-DOCUMENT", + "name": format!("sbom-{}", params.repo_id), + "documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id), + "packages": packages, + }) + } else { + // CycloneDX 1.5 format + let components: Vec = entries + .iter() + .map(|e| { + let mut comp = serde_json::json!({ + "type": "library", + "name": e.name, + "version": e.version, + "group": e.package_manager, + }); + if let Some(purl) = &e.purl { + comp["purl"] = serde_json::Value::String(purl.clone()); + } + if let Some(license) = &e.license { + comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]); + } + if !e.known_vulnerabilities.is_empty() { + comp["vulnerabilities"] = serde_json::json!( + e.known_vulnerabilities.iter().map(|v| serde_json::json!({ + "id": v.id, + "source": { "name": v.source }, + "ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(), + })).collect::>() + ); + } + comp + }) + .collect(); + + serde_json::json!({ + "bomFormat": "CycloneDX", + "specVersion": "1.5", + "version": 1, + "metadata": { + "component": { + "type": "application", + "name": 
format!("repo-{}", params.repo_id), + } + }, + "components": components, + }) + }; + + let json_str = + serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let filename = if params.format == "spdx" { + format!("sbom-{}-spdx.json", params.repo_id) + } else { + format!("sbom-{}-cyclonedx.json", params.repo_id) + }; + + let disposition = format!("attachment; filename=\"{filename}\""); + Ok(( + [ + ( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ), + ( + header::CONTENT_DISPOSITION, + header::HeaderValue::from_str(&disposition) + .unwrap_or_else(|_| header::HeaderValue::from_static("attachment")), + ), + ], + json_str, + )) +} + +#[tracing::instrument(skip_all)] +pub async fn license_summary( + Extension(agent): AgentExt, + Query(params): Query, +) -> ApiResult> { + let db = &agent.db; + let mut query = doc! {}; + if let Some(repo_id) = ¶ms.repo_id { + query.insert("repo_id", repo_id); + } + + let entries: Vec = match db.sbom_entries().find(query).await { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch SBOM entries for license summary: {e}"); + Vec::new() + } + }; + + let mut license_map: std::collections::HashMap> = + std::collections::HashMap::new(); + for entry in &entries { + let lic = entry.license.as_deref().unwrap_or("Unknown").to_string(); + license_map.entry(lic).or_default().push(entry.name.clone()); + } + + let mut summaries: Vec = license_map + .into_iter() + .map(|(license, packages)| { + let is_copyleft = COPYLEFT_LICENSES + .iter() + .any(|c| license.to_uppercase().contains(&c.to_uppercase())); + LicenseSummary { + license, + count: packages.len() as u64, + is_copyleft, + packages, + } + }) + .collect(); + summaries.sort_by(|a, b| b.count.cmp(&a.count)); + + Ok(Json(ApiResponse { + data: summaries, + total: None, + page: None, + })) +} + +#[tracing::instrument(skip_all)] +pub async fn sbom_diff( + Extension(agent): AgentExt, + 
Query(params): Query, +) -> ApiResult { + let db = &agent.db; + + let entries_a: Vec = match db + .sbom_entries() + .find(doc! { "repo_id": ¶ms.repo_a }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}"); + Vec::new() + } + }; + + let entries_b: Vec = match db + .sbom_entries() + .find(doc! { "repo_id": ¶ms.repo_b }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}"); + Vec::new() + } + }; + + // Build maps by (name, package_manager) -> version + let map_a: std::collections::HashMap<(String, String), String> = entries_a + .iter() + .map(|e| { + ( + (e.name.clone(), e.package_manager.clone()), + e.version.clone(), + ) + }) + .collect(); + let map_b: std::collections::HashMap<(String, String), String> = entries_b + .iter() + .map(|e| { + ( + (e.name.clone(), e.package_manager.clone()), + e.version.clone(), + ) + }) + .collect(); + + let mut only_in_a = Vec::new(); + let mut version_changed = Vec::new(); + let mut common_count: u64 = 0; + + for (key, ver_a) in &map_a { + match map_b.get(key) { + None => only_in_a.push(SbomDiffEntry { + name: key.0.clone(), + version: ver_a.clone(), + package_manager: key.1.clone(), + }), + Some(ver_b) if ver_a != ver_b => { + version_changed.push(SbomVersionDiff { + name: key.0.clone(), + package_manager: key.1.clone(), + version_a: ver_a.clone(), + version_b: ver_b.clone(), + }); + } + Some(_) => common_count += 1, + } + } + + let only_in_b: Vec = map_b + .iter() + .filter(|(key, _)| !map_a.contains_key(key)) + .map(|(key, ver)| SbomDiffEntry { + name: key.0.clone(), + version: ver.clone(), + package_manager: key.1.clone(), + }) + .collect(); + + Ok(Json(ApiResponse { + data: SbomDiffResult { + only_in_a, + only_in_b, + version_changed, + common_count, + }, + total: None, + page: None, + })) +} diff --git a/compliance-agent/src/api/handlers/scans.rs 
b/compliance-agent/src/api/handlers/scans.rs new file mode 100644 index 0000000..ec16468 --- /dev/null +++ b/compliance-agent/src/api/handlers/scans.rs @@ -0,0 +1,37 @@ +use axum::extract::{Extension, Query}; +use axum::Json; +use mongodb::bson::doc; + +use super::dto::*; +use compliance_core::models::ScanRun; + +#[tracing::instrument(skip_all)] +pub async fn list_scan_runs( + Extension(agent): AgentExt, + Query(params): Query, +) -> ApiResult> { + let db = &agent.db; + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0); + + let scans = match db + .scan_runs() + .find(doc! {}) + .sort(doc! { "started_at": -1 }) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch scan runs: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: scans, + total: Some(total), + page: Some(params.page), + })) +} diff --git a/compliance-agent/src/api/routes.rs b/compliance-agent/src/api/routes.rs index f878d7e..0b72262 100644 --- a/compliance-agent/src/api/routes.rs +++ b/compliance-agent/src/api/routes.rs @@ -136,7 +136,10 @@ pub fn build_router() -> Router { "/api/v1/pentest/sessions/{id}/export", post(handlers::pentest::export_session_report), ) - .route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats)) + .route( + "/api/v1/pentest/stats", + get(handlers::pentest::pentest_stats), + ) // Webhook endpoints (proxied through dashboard) .route( "/webhook/github/{repo_id}", diff --git a/compliance-agent/src/llm/client.rs b/compliance-agent/src/llm/client.rs index b6bc657..d3bfae6 100644 --- a/compliance-agent/src/llm/client.rs +++ b/compliance-agent/src/llm/client.rs @@ -1,147 +1,17 @@ use secrecy::{ExposeSecret, SecretString}; -use serde::{Deserialize, Serialize}; +use super::types::*; use crate::error::AgentError; #[derive(Clone)] pub struct LlmClient { - base_url: String, - api_key: 
SecretString, - model: String, - embed_model: String, - http: reqwest::Client, + pub(crate) base_url: String, + pub(crate) api_key: SecretString, + pub(crate) model: String, + pub(crate) embed_model: String, + pub(crate) http: reqwest::Client, } -// ── Request types ────────────────────────────────────────────── - -#[derive(Serialize, Clone, Debug)] -pub struct ChatMessage { - pub role: String, - pub content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, -} - -#[derive(Serialize)] -struct ChatCompletionRequest { - model: String, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, -} - -#[derive(Serialize)] -struct ToolDefinitionPayload { - r#type: String, - function: ToolFunctionPayload, -} - -#[derive(Serialize)] -struct ToolFunctionPayload { - name: String, - description: String, - parameters: serde_json::Value, -} - -// ── Response types ───────────────────────────────────────────── - -#[derive(Deserialize)] -struct ChatCompletionResponse { - choices: Vec, -} - -#[derive(Deserialize)] -struct ChatChoice { - message: ChatResponseMessage, -} - -#[derive(Deserialize)] -struct ChatResponseMessage { - #[serde(default)] - content: Option, - #[serde(default)] - tool_calls: Option>, -} - -#[derive(Deserialize)] -struct ToolCallResponse { - id: String, - function: ToolCallFunction, -} - -#[derive(Deserialize)] -struct ToolCallFunction { - name: String, - arguments: String, -} - -// ── Public types for tool calling ────────────────────────────── - -/// Definition of a tool that the LLM can invoke -#[derive(Debug, Clone, Serialize)] -pub struct ToolDefinition { - pub name: String, - pub description: String, - pub parameters: serde_json::Value, -} - -/// A tool call request 
from the LLM -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LlmToolCall { - pub id: String, - pub name: String, - pub arguments: serde_json::Value, -} - -/// A tool call in the request message format (for sending back tool_calls in assistant messages) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCallRequest { - pub id: String, - pub r#type: String, - pub function: ToolCallRequestFunction, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCallRequestFunction { - pub name: String, - pub arguments: String, -} - -/// Response from the LLM — either content or tool calls -#[derive(Debug, Clone)] -pub enum LlmResponse { - Content(String), - /// Tool calls with optional reasoning text from the LLM - ToolCalls { calls: Vec, reasoning: String }, -} - -// ── Embedding types ──────────────────────────────────────────── - -#[derive(Serialize)] -struct EmbeddingRequest { - model: String, - input: Vec, -} - -#[derive(Deserialize)] -struct EmbeddingResponse { - data: Vec, -} - -#[derive(Deserialize)] -struct EmbeddingData { - embedding: Vec, - index: usize, -} - -// ── Implementation ───────────────────────────────────────────── - impl LlmClient { pub fn new( base_url: String, @@ -158,18 +28,14 @@ impl LlmClient { } } - pub fn embed_model(&self) -> &str { - &self.embed_model - } - - fn chat_url(&self) -> String { + pub(crate) fn chat_url(&self) -> String { format!( "{}/v1/chat/completions", self.base_url.trim_end_matches('/') ) } - fn auth_header(&self) -> Option { + pub(crate) fn auth_header(&self) -> Option { let key = self.api_key.expose_secret(); if key.is_empty() { None @@ -241,12 +107,12 @@ impl LlmClient { tools: None, }; - self.send_chat_request(&request_body).await.map(|resp| { - match resp { + self.send_chat_request(&request_body) + .await + .map(|resp| match resp { LlmResponse::Content(c) => c, LlmResponse::ToolCalls { .. 
} => String::new(), - } - }) + }) } /// Chat with tool definitions — returns either content or tool calls. @@ -292,7 +158,7 @@ impl LlmClient { ) -> Result { let mut req = self .http - .post(&self.chat_url()) + .post(self.chat_url()) .header("content-type", "application/json") .json(request_body); @@ -345,54 +211,7 @@ impl LlmClient { } // Otherwise return content - let content = choice - .message - .content - .clone() - .unwrap_or_default(); + let content = choice.message.content.clone().unwrap_or_default(); Ok(LlmResponse::Content(content)) } - - /// Generate embeddings for a batch of texts - pub async fn embed(&self, texts: Vec) -> Result>, AgentError> { - let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/')); - - let request_body = EmbeddingRequest { - model: self.embed_model.clone(), - input: texts, - }; - - let mut req = self - .http - .post(&url) - .header("content-type", "application/json") - .json(&request_body); - - if let Some(auth) = self.auth_header() { - req = req.header("Authorization", auth); - } - - let resp = req - .send() - .await - .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(AgentError::Other(format!( - "Embedding API returned {status}: {body}" - ))); - } - - let body: EmbeddingResponse = resp - .json() - .await - .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?; - - let mut data = body.data; - data.sort_by_key(|d| d.index); - - Ok(data.into_iter().map(|d| d.embedding).collect()) - } } diff --git a/compliance-agent/src/llm/embedding.rs b/compliance-agent/src/llm/embedding.rs new file mode 100644 index 0000000..3e8e0d9 --- /dev/null +++ b/compliance-agent/src/llm/embedding.rs @@ -0,0 +1,74 @@ +use serde::{Deserialize, Serialize}; + +use super::client::LlmClient; +use crate::error::AgentError; + +// ── Embedding types 
──────────────────────────────────────────── + +#[derive(Serialize)] +struct EmbeddingRequest { + model: String, + input: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingResponse { + data: Vec, +} + +#[derive(Deserialize)] +struct EmbeddingData { + embedding: Vec, + index: usize, +} + +// ── Embedding implementation ─────────────────────────────────── + +impl LlmClient { + pub fn embed_model(&self) -> &str { + &self.embed_model + } + + /// Generate embeddings for a batch of texts + pub async fn embed(&self, texts: Vec) -> Result>, AgentError> { + let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/')); + + let request_body = EmbeddingRequest { + model: self.embed_model.clone(), + input: texts, + }; + + let mut req = self + .http + .post(&url) + .header("content-type", "application/json") + .json(&request_body); + + if let Some(auth) = self.auth_header() { + req = req.header("Authorization", auth); + } + + let resp = req + .send() + .await + .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(AgentError::Other(format!( + "Embedding API returned {status}: {body}" + ))); + } + + let body: EmbeddingResponse = resp + .json() + .await + .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?; + + let mut data = body.data; + data.sort_by_key(|d| d.index); + + Ok(data.into_iter().map(|d| d.embedding).collect()) + } +} diff --git a/compliance-agent/src/llm/mod.rs b/compliance-agent/src/llm/mod.rs index 17e8550..9fd6b9a 100644 --- a/compliance-agent/src/llm/mod.rs +++ b/compliance-agent/src/llm/mod.rs @@ -1,11 +1,16 @@ pub mod client; #[allow(dead_code)] pub mod descriptions; +pub mod embedding; #[allow(dead_code)] pub mod fixes; #[allow(dead_code)] pub mod pr_review; pub mod review_prompts; pub mod triage; +pub mod types; pub use client::LlmClient; +pub use types::{ + 
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition, +}; diff --git a/compliance-agent/src/llm/triage.rs b/compliance-agent/src/llm/triage.rs index 62d056d..0050bde 100644 --- a/compliance-agent/src/llm/triage.rs +++ b/compliance-agent/src/llm/triage.rs @@ -278,3 +278,220 @@ struct TriageResult { fn default_action() -> String { "confirm".to_string() } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::Severity; + + // ── classify_file_path ─────────────────────────────────────── + + #[test] + fn classify_none_path() { + assert_eq!(classify_file_path(None), "unknown"); + } + + #[test] + fn classify_production_path() { + assert_eq!(classify_file_path(Some("src/main.rs")), "production"); + assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production"); + } + + #[test] + fn classify_test_paths() { + assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test"); + assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test"); + assert_eq!(classify_file_path(Some("foo_test.go")), "test"); + assert_eq!(classify_file_path(Some("bar.test.js")), "test"); + assert_eq!(classify_file_path(Some("baz.spec.ts")), "test"); + assert_eq!( + classify_file_path(Some("data/fixtures/sample.json")), + "test" + ); + assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test"); + } + + #[test] + fn classify_example_paths() { + assert_eq!( + classify_file_path(Some("docs/examples/basic.rs")), + "example" + ); + // /example matches because contains("/example") + assert_eq!(classify_file_path(Some("src/example/main.py")), "example"); + assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example"); + assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example"); + } + + #[test] + fn classify_generated_paths() { + assert_eq!( + classify_file_path(Some("src/generated/api.rs")), + "generated" + ); + assert_eq!( + classify_file_path(Some("proto/gen/service.go")), + "generated" + ); + 
assert_eq!(classify_file_path(Some("api.generated.ts")), "generated"); + assert_eq!(classify_file_path(Some("service.pb.go")), "generated"); + assert_eq!(classify_file_path(Some("model_generated.rs")), "generated"); + } + + #[test] + fn classify_vendored_paths() { + // Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes) + assert_eq!( + classify_file_path(Some("src/vendor/lib/foo.go")), + "vendored" + ); + assert_eq!( + classify_file_path(Some("src/node_modules/pkg/index.js")), + "vendored" + ); + assert_eq!( + classify_file_path(Some("src/third_party/lib.c")), + "vendored" + ); + } + + #[test] + fn classify_is_case_insensitive() { + assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test"); + assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored"); + assert_eq!( + classify_file_path(Some("src/GENERATED/foo.ts")), + "generated" + ); + } + + // ── adjust_confidence ──────────────────────────────────────── + + #[test] + fn adjust_confidence_production() { + assert_eq!(adjust_confidence(8.0, "production"), 8.0); + } + + #[test] + fn adjust_confidence_test() { + assert_eq!(adjust_confidence(10.0, "test"), 5.0); + } + + #[test] + fn adjust_confidence_example() { + assert_eq!(adjust_confidence(10.0, "example"), 6.0); + } + + #[test] + fn adjust_confidence_generated() { + assert_eq!(adjust_confidence(10.0, "generated"), 3.0); + } + + #[test] + fn adjust_confidence_vendored() { + assert_eq!(adjust_confidence(10.0, "vendored"), 4.0); + } + + #[test] + fn adjust_confidence_unknown_classification() { + assert_eq!(adjust_confidence(7.0, "unknown"), 7.0); + assert_eq!(adjust_confidence(7.0, "something_else"), 7.0); + } + + #[test] + fn adjust_confidence_zero() { + assert_eq!(adjust_confidence(0.0, "test"), 0.0); + assert_eq!(adjust_confidence(0.0, "production"), 0.0); + } + + // ── downgrade_severity ─────────────────────────────────────── + + #[test] + fn downgrade_severity_all_levels() { + 
assert_eq!(downgrade_severity(&Severity::Critical), Severity::High); + assert_eq!(downgrade_severity(&Severity::High), Severity::Medium); + assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low); + assert_eq!(downgrade_severity(&Severity::Low), Severity::Info); + assert_eq!(downgrade_severity(&Severity::Info), Severity::Info); + } + + #[test] + fn downgrade_severity_info_is_floor() { + // Downgrading Info twice should still be Info + let s = downgrade_severity(&Severity::Info); + assert_eq!(downgrade_severity(&s), Severity::Info); + } + + // ── upgrade_severity ───────────────────────────────────────── + + #[test] + fn upgrade_severity_all_levels() { + assert_eq!(upgrade_severity(&Severity::Info), Severity::Low); + assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium); + assert_eq!(upgrade_severity(&Severity::Medium), Severity::High); + assert_eq!(upgrade_severity(&Severity::High), Severity::Critical); + assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical); + } + + #[test] + fn upgrade_severity_critical_is_ceiling() { + let s = upgrade_severity(&Severity::Critical); + assert_eq!(upgrade_severity(&s), Severity::Critical); + } + + // ── upgrade/downgrade roundtrip ────────────────────────────── + + #[test] + fn upgrade_then_downgrade_is_identity_for_middle_values() { + for sev in [Severity::Low, Severity::Medium, Severity::High] { + assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev); + } + } + + // ── TriageResult deserialization ───────────────────────────── + + #[test] + fn triage_result_full() { + let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#; + let r: TriageResult = serde_json::from_str(json).unwrap(); + assert_eq!(r.action, "dismiss"); + assert_eq!(r.confidence, 8.5); + assert_eq!(r.rationale, "false positive"); + assert_eq!(r.remediation.as_deref(), Some("remove code")); + } + + #[test] + fn triage_result_defaults() { + let json = r#"{}"#; + let r: 
TriageResult = serde_json::from_str(json).unwrap(); + assert_eq!(r.action, "confirm"); + assert_eq!(r.confidence, 0.0); + assert_eq!(r.rationale, ""); + assert!(r.remediation.is_none()); + } + + #[test] + fn triage_result_partial() { + let json = r#"{"action":"downgrade","confidence":6.0}"#; + let r: TriageResult = serde_json::from_str(json).unwrap(); + assert_eq!(r.action, "downgrade"); + assert_eq!(r.confidence, 6.0); + assert_eq!(r.rationale, ""); + assert!(r.remediation.is_none()); + } + + #[test] + fn triage_result_with_markdown_fences() { + // Simulate LLM wrapping response in markdown code fences + let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```"; + let cleaned = raw + .trim() + .trim_start_matches("```json") + .trim_start_matches("```") + .trim_end_matches("```") + .trim(); + let r: TriageResult = serde_json::from_str(cleaned).unwrap(); + assert_eq!(r.action, "upgrade"); + assert_eq!(r.confidence, 9.0); + } +} diff --git a/compliance-agent/src/llm/types.rs b/compliance-agent/src/llm/types.rs new file mode 100644 index 0000000..110d790 --- /dev/null +++ b/compliance-agent/src/llm/types.rs @@ -0,0 +1,369 @@ +use serde::{Deserialize, Serialize}; + +// ── Request types ────────────────────────────────────────────── + +#[derive(Serialize, Clone, Debug)] +pub struct ChatMessage { + pub role: String, + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +#[derive(Serialize)] +pub(crate) struct ChatCompletionRequest { + pub(crate) model: String, + pub(crate) messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) tools: Option>, +} + +#[derive(Serialize)] +pub(crate) struct 
ToolDefinitionPayload { + pub(crate) r#type: String, + pub(crate) function: ToolFunctionPayload, +} + +#[derive(Serialize)] +pub(crate) struct ToolFunctionPayload { + pub(crate) name: String, + pub(crate) description: String, + pub(crate) parameters: serde_json::Value, +} + +// ── Response types ───────────────────────────────────────────── + +#[derive(Deserialize)] +pub(crate) struct ChatCompletionResponse { + pub(crate) choices: Vec, +} + +#[derive(Deserialize)] +pub(crate) struct ChatChoice { + pub(crate) message: ChatResponseMessage, +} + +#[derive(Deserialize)] +pub(crate) struct ChatResponseMessage { + #[serde(default)] + pub(crate) content: Option, + #[serde(default)] + pub(crate) tool_calls: Option>, +} + +#[derive(Deserialize)] +pub(crate) struct ToolCallResponse { + pub(crate) id: String, + pub(crate) function: ToolCallFunction, +} + +#[derive(Deserialize)] +pub(crate) struct ToolCallFunction { + pub(crate) name: String, + pub(crate) arguments: String, +} + +// ── Public types for tool calling ────────────────────────────── + +/// Definition of a tool that the LLM can invoke +#[derive(Debug, Clone, Serialize)] +pub struct ToolDefinition { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +/// A tool call request from the LLM +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +/// A tool call in the request message format (for sending back tool_calls in assistant messages) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequest { + pub id: String, + pub r#type: String, + pub function: ToolCallRequestFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequestFunction { + pub name: String, + pub arguments: String, +} + +/// Response from the LLM — either content or tool calls +#[derive(Debug, Clone)] +pub enum LlmResponse { + Content(String), + /// Tool calls with 
optional reasoning text from the LLM + ToolCalls { + calls: Vec, + reasoning: String, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + // ── ChatMessage ────────────────────────────────────────────── + + #[test] + fn chat_message_serializes_minimal() { + let msg = ChatMessage { + role: "user".to_string(), + content: Some("hello".to_string()), + tool_calls: None, + tool_call_id: None, + }; + let v = serde_json::to_value(&msg).unwrap(); + assert_eq!(v["role"], "user"); + assert_eq!(v["content"], "hello"); + // None fields with skip_serializing_if should be absent + assert!(v.get("tool_calls").is_none()); + assert!(v.get("tool_call_id").is_none()); + } + + #[test] + fn chat_message_serializes_with_tool_calls() { + let msg = ChatMessage { + role: "assistant".to_string(), + content: None, + tool_calls: Some(vec![ToolCallRequest { + id: "call_1".to_string(), + r#type: "function".to_string(), + function: ToolCallRequestFunction { + name: "get_weather".to_string(), + arguments: r#"{"city":"NYC"}"#.to_string(), + }, + }]), + tool_call_id: None, + }; + let v = serde_json::to_value(&msg).unwrap(); + assert!(v["tool_calls"].is_array()); + assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather"); + } + + #[test] + fn chat_message_content_null_when_none() { + let msg = ChatMessage { + role: "assistant".to_string(), + content: None, + tool_calls: None, + tool_call_id: None, + }; + let v = serde_json::to_value(&msg).unwrap(); + assert!(v["content"].is_null()); + } + + // ── ToolDefinition ─────────────────────────────────────────── + + #[test] + fn tool_definition_serializes() { + let td = ToolDefinition { + name: "search".to_string(), + description: "Search the web".to_string(), + parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}), + }; + let v = serde_json::to_value(&td).unwrap(); + assert_eq!(v["name"], "search"); + assert_eq!(v["parameters"]["type"], "object"); + } + + #[test] + fn 
tool_definition_empty_parameters() { + let td = ToolDefinition { + name: "noop".to_string(), + description: "".to_string(), + parameters: json!({}), + }; + let v = serde_json::to_value(&td).unwrap(); + assert_eq!(v["parameters"], json!({})); + } + + // ── LlmToolCall ────────────────────────────────────────────── + + #[test] + fn llm_tool_call_roundtrip() { + let call = LlmToolCall { + id: "tc_42".to_string(), + name: "run_scan".to_string(), + arguments: json!({"path": "/tmp", "verbose": true}), + }; + let serialized = serde_json::to_string(&call).unwrap(); + let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized.id, "tc_42"); + assert_eq!(deserialized.name, "run_scan"); + assert_eq!(deserialized.arguments["path"], "/tmp"); + assert_eq!(deserialized.arguments["verbose"], true); + } + + #[test] + fn llm_tool_call_empty_arguments() { + let call = LlmToolCall { + id: "tc_0".to_string(), + name: "noop".to_string(), + arguments: json!({}), + }; + let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap(); + assert!(rt.arguments.as_object().unwrap().is_empty()); + } + + // ── ToolCallRequest / ToolCallRequestFunction ──────────────── + + #[test] + fn tool_call_request_roundtrip() { + let req = ToolCallRequest { + id: "call_abc".to_string(), + r#type: "function".to_string(), + function: ToolCallRequestFunction { + name: "my_func".to_string(), + arguments: r#"{"x":1}"#.to_string(), + }, + }; + let json_str = serde_json::to_string(&req).unwrap(); + let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap(); + assert_eq!(back.id, "call_abc"); + assert_eq!(back.r#type, "function"); + assert_eq!(back.function.name, "my_func"); + assert_eq!(back.function.arguments, r#"{"x":1}"#); + } + + #[test] + fn tool_call_request_type_field_serializes_as_type() { + let req = ToolCallRequest { + id: "id".to_string(), + r#type: "function".to_string(), + function: ToolCallRequestFunction { + name: 
"f".to_string(), + arguments: "{}".to_string(), + }, + }; + let v = serde_json::to_value(&req).unwrap(); + // The field should be "type" in JSON, not "r#type" + assert!(v.get("type").is_some()); + assert!(v.get("r#type").is_none()); + } + + // ── ChatCompletionRequest ──────────────────────────────────── + + #[test] + fn chat_completion_request_skips_none_fields() { + let req = ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![], + temperature: None, + max_tokens: None, + tools: None, + }; + let v = serde_json::to_value(&req).unwrap(); + assert_eq!(v["model"], "gpt-4"); + assert!(v.get("temperature").is_none()); + assert!(v.get("max_tokens").is_none()); + assert!(v.get("tools").is_none()); + } + + #[test] + fn chat_completion_request_includes_set_fields() { + let req = ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![], + temperature: Some(0.7), + max_tokens: Some(1024), + tools: Some(vec![]), + }; + let v = serde_json::to_value(&req).unwrap(); + assert_eq!(v["temperature"], 0.7); + assert_eq!(v["max_tokens"], 1024); + assert!(v["tools"].is_array()); + } + + // ── ChatCompletionResponse deserialization ─────────────────── + + #[test] + fn chat_completion_response_deserializes_content() { + let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#; + let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap(); + assert_eq!(resp.choices.len(), 1); + assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!")); + assert!(resp.choices[0].message.tool_calls.is_none()); + } + + #[test] + fn chat_completion_response_deserializes_tool_calls() { + let json_str = r#"{ + "choices": [{ + "message": { + "tool_calls": [{ + "id": "call_1", + "function": {"name": "search", "arguments": "{\"q\":\"rust\"}"} + }] + } + }] + }"#; + let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap(); + let tc = resp.choices[0].message.tool_calls.as_ref().unwrap(); + assert_eq!(tc.len(), 1); + 
assert_eq!(tc[0].id, "call_1"); + assert_eq!(tc[0].function.name, "search"); + } + + #[test] + fn chat_completion_response_defaults_missing_fields() { + // content and tool_calls are both missing — should default to None + let json_str = r#"{"choices":[{"message":{}}]}"#; + let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap(); + assert!(resp.choices[0].message.content.is_none()); + assert!(resp.choices[0].message.tool_calls.is_none()); + } + + // ── LlmResponse ───────────────────────────────────────────── + + #[test] + fn llm_response_content_variant() { + let resp = LlmResponse::Content("answer".to_string()); + match resp { + LlmResponse::Content(s) => assert_eq!(s, "answer"), + _ => panic!("expected Content variant"), + } + } + + #[test] + fn llm_response_tool_calls_variant() { + let resp = LlmResponse::ToolCalls { + calls: vec![LlmToolCall { + id: "1".to_string(), + name: "f".to_string(), + arguments: json!({}), + }], + reasoning: "because".to_string(), + }; + match resp { + LlmResponse::ToolCalls { calls, reasoning } => { + assert_eq!(calls.len(), 1); + assert_eq!(reasoning, "because"); + } + _ => panic!("expected ToolCalls variant"), + } + } + + #[test] + fn llm_response_empty_content() { + let resp = LlmResponse::Content(String::new()); + match resp { + LlmResponse::Content(s) => assert!(s.is_empty()), + _ => panic!("expected Content variant"), + } + } +} diff --git a/compliance-agent/src/pentest/context.rs b/compliance-agent/src/pentest/context.rs new file mode 100644 index 0000000..6ec9980 --- /dev/null +++ b/compliance-agent/src/pentest/context.rs @@ -0,0 +1,150 @@ +use futures_util::StreamExt; +use mongodb::bson::doc; + +use compliance_core::models::dast::DastTarget; +use compliance_core::models::finding::Finding; +use compliance_core::models::pentest::CodeContextHint; +use compliance_core::models::sbom::SbomEntry; + +use super::orchestrator::PentestOrchestrator; + +impl PentestOrchestrator { + /// Fetch SAST findings, SBOM entries 
(with CVEs), and code graph entry points + /// for the repo linked to this DAST target. + pub(crate) async fn gather_repo_context( + &self, + target: &DastTarget, + ) -> (Vec, Vec, Vec) { + let Some(repo_id) = &target.repo_id else { + return (Vec::new(), Vec::new(), Vec::new()); + }; + + let sast_findings = self.fetch_sast_findings(repo_id).await; + let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await; + let code_context = self.fetch_code_context(repo_id, &sast_findings).await; + + tracing::info!( + repo_id, + sast_findings = sast_findings.len(), + vulnerable_deps = sbom_entries.len(), + code_hints = code_context.len(), + "Gathered code-awareness context for pentest" + ); + + (sast_findings, sbom_entries, code_context) + } + + /// Fetch open/triaged SAST findings for the repo (not false positives or resolved) + async fn fetch_sast_findings(&self, repo_id: &str) -> Vec { + let cursor = self + .db + .findings() + .find(doc! { + "repo_id": repo_id, + "status": { "$in": ["open", "triaged"] }, + }) + .sort(doc! { "severity": -1 }) + .limit(100) + .await; + + match cursor { + Ok(mut c) => { + let mut results = Vec::new(); + while let Some(Ok(f)) = c.next().await { + results.push(f); + } + results + } + Err(e) => { + tracing::warn!("Failed to fetch SAST findings for pentest: {e}"); + Vec::new() + } + } + } + + /// Fetch SBOM entries that have known vulnerabilities + async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec { + let cursor = self + .db + .sbom_entries() + .find(doc! { + "repo_id": repo_id, + "known_vulnerabilities": { "$exists": true, "$ne": [] }, + }) + .limit(50) + .await; + + match cursor { + Ok(mut c) => { + let mut results = Vec::new(); + while let Some(Ok(e)) = c.next().await { + results.push(e); + } + results + } + Err(e) => { + tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}"); + Vec::new() + } + } + } + + /// Build CodeContextHint objects from the code knowledge graph. 
+ /// Maps entry points to their source files and links SAST findings. + async fn fetch_code_context( + &self, + repo_id: &str, + sast_findings: &[Finding], + ) -> Vec { + // Get entry point nodes from the code graph + let cursor = self + .db + .graph_nodes() + .find(doc! { + "repo_id": repo_id, + "is_entry_point": true, + }) + .limit(50) + .await; + + let nodes = match cursor { + Ok(mut c) => { + let mut results = Vec::new(); + while let Some(Ok(n)) = c.next().await { + results.push(n); + } + results + } + Err(_) => return Vec::new(), + }; + + // Build hints by matching graph nodes to SAST findings by file path + nodes + .into_iter() + .map(|node| { + // Find SAST findings in the same file + let linked_vulns: Vec = sast_findings + .iter() + .filter(|f| f.file_path.as_deref() == Some(&node.file_path)) + .map(|f| { + format!( + "[{}] {}: {} (line {})", + f.severity, + f.scanner, + f.title, + f.line_number.unwrap_or(0) + ) + }) + .collect(); + + CodeContextHint { + endpoint_pattern: node.qualified_name.clone(), + handler_function: node.name.clone(), + file_path: node.file_path.clone(), + code_snippet: String::new(), // Could fetch from embeddings + known_vulnerabilities: linked_vulns, + } + }) + .collect() + } +} diff --git a/compliance-agent/src/pentest/mod.rs b/compliance-agent/src/pentest/mod.rs index 934315a..6aa5bfb 100644 --- a/compliance-agent/src/pentest/mod.rs +++ b/compliance-agent/src/pentest/mod.rs @@ -1,4 +1,6 @@ +mod context; pub mod orchestrator; +mod prompt_builder; pub mod report; pub use orchestrator::PentestOrchestrator; diff --git a/compliance-agent/src/pentest/orchestrator.rs b/compliance-agent/src/pentest/orchestrator.rs index 184db47..2c88ce5 100644 --- a/compliance-agent/src/pentest/orchestrator.rs +++ b/compliance-agent/src/pentest/orchestrator.rs @@ -1,31 +1,27 @@ use std::sync::Arc; use std::time::Duration; -use futures_util::StreamExt; use mongodb::bson::doc; use tokio::sync::broadcast; use compliance_core::models::dast::DastTarget; -use 
compliance_core::models::finding::{Finding, FindingStatus, Severity}; use compliance_core::models::pentest::*; -use compliance_core::models::sbom::SbomEntry; use compliance_core::traits::pentest_tool::PentestToolContext; use compliance_dast::ToolRegistry; use crate::database::Database; -use crate::llm::client::{ - ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition, +use crate::llm::{ + ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition, }; -use crate::llm::LlmClient; /// Maximum duration for a single pentest session before timeout const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes pub struct PentestOrchestrator { - tool_registry: ToolRegistry, - llm: Arc, - db: Database, - event_tx: broadcast::Sender, + pub(crate) tool_registry: ToolRegistry, + pub(crate) llm: Arc, + pub(crate) db: Database, + pub(crate) event_tx: broadcast::Sender, } impl PentestOrchestrator { @@ -39,10 +35,12 @@ impl PentestOrchestrator { } } + #[allow(dead_code)] pub fn subscribe(&self) -> broadcast::Receiver { self.event_tx.subscribe() } + #[allow(dead_code)] pub fn event_sender(&self) -> broadcast::Sender { self.event_tx.clone() } @@ -111,18 +109,20 @@ impl PentestOrchestrator { target: &DastTarget, initial_message: &str, ) -> Result<(), crate::error::AgentError> { - let session_id = session - .id - .map(|oid| oid.to_hex()) - .unwrap_or_default(); + let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); // Gather code-awareness context from linked repo - let (sast_findings, sbom_entries, code_context) = - self.gather_repo_context(target).await; + let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await; // Build system prompt with code context let system_prompt = self - .build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context) + .build_system_prompt( + session, + target, + &sast_findings, + &sbom_entries, + 
&code_context, + ) .await; // Build tool definitions for LLM @@ -182,8 +182,7 @@ impl PentestOrchestrator { match response { LlmResponse::Content(content) => { - let msg = - PentestMessage::assistant(session_id.clone(), content.clone()); + let msg = PentestMessage::assistant(session_id.clone(), content.clone()); let _ = self.db.pentest_messages().insert_one(&msg).await; let _ = self.event_tx.send(PentestEvent::Message { content: content.clone(), @@ -213,7 +212,10 @@ impl PentestOrchestrator { } break; } - LlmResponse::ToolCalls { calls: tool_calls, reasoning } => { + LlmResponse::ToolCalls { + calls: tool_calls, + reasoning, + } => { let tc_requests: Vec = tool_calls .iter() .map(|tc| ToolCallRequest { @@ -221,15 +223,18 @@ impl PentestOrchestrator { r#type: "function".to_string(), function: ToolCallRequestFunction { name: tc.name.clone(), - arguments: serde_json::to_string(&tc.arguments) - .unwrap_or_default(), + arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(), }, }) .collect(); messages.push(ChatMessage { role: "assistant".to_string(), - content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) }, + content: if reasoning.is_empty() { + None + } else { + Some(reasoning.clone()) + }, tool_calls: Some(tc_requests), tool_call_id: None, }); @@ -274,24 +279,30 @@ impl PentestOrchestrator { let insert_result = self.db.dast_findings().insert_one(&finding).await; if let Ok(res) = &insert_result { - finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default()); - } - let _ = - self.event_tx.send(PentestEvent::Finding { - finding_id: finding - .id + finding_ids.push( + res.inserted_id + .as_object_id() .map(|oid| oid.to_hex()) .unwrap_or_default(), - title: finding.title.clone(), - severity: finding.severity.to_string(), - }); + ); + } + let _ = self.event_tx.send(PentestEvent::Finding { + finding_id: finding + .id + .map(|oid| oid.to_hex()) + .unwrap_or_default(), + title: finding.title.clone(), + severity: 
finding.severity.to_string(), + }); } // Compute risk score based on findings severity let risk_score: Option = if findings_count > 0 { Some(std::cmp::min( 100, - (findings_count as u8).saturating_mul(15).saturating_add(20), + (findings_count as u8) + .saturating_mul(15) + .saturating_add(20), )) } else { None @@ -415,347 +426,4 @@ impl PentestOrchestrator { Ok(()) } - - // ── Code-Awareness: Gather context from linked repo ───────── - - /// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points - /// for the repo linked to this DAST target. - async fn gather_repo_context( - &self, - target: &DastTarget, - ) -> (Vec, Vec, Vec) { - let Some(repo_id) = &target.repo_id else { - return (Vec::new(), Vec::new(), Vec::new()); - }; - - let sast_findings = self.fetch_sast_findings(repo_id).await; - let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await; - let code_context = self.fetch_code_context(repo_id, &sast_findings).await; - - tracing::info!( - repo_id, - sast_findings = sast_findings.len(), - vulnerable_deps = sbom_entries.len(), - code_hints = code_context.len(), - "Gathered code-awareness context for pentest" - ); - - (sast_findings, sbom_entries, code_context) - } - - /// Fetch open/triaged SAST findings for the repo (not false positives or resolved) - async fn fetch_sast_findings(&self, repo_id: &str) -> Vec { - let cursor = self - .db - .findings() - .find(doc! { - "repo_id": repo_id, - "status": { "$in": ["open", "triaged"] }, - }) - .sort(doc! { "severity": -1 }) - .limit(100) - .await; - - match cursor { - Ok(mut c) => { - let mut results = Vec::new(); - while let Some(Ok(f)) = c.next().await { - results.push(f); - } - results - } - Err(e) => { - tracing::warn!("Failed to fetch SAST findings for pentest: {e}"); - Vec::new() - } - } - } - - /// Fetch SBOM entries that have known vulnerabilities - async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec { - let cursor = self - .db - .sbom_entries() - .find(doc! 
{ - "repo_id": repo_id, - "known_vulnerabilities": { "$exists": true, "$ne": [] }, - }) - .limit(50) - .await; - - match cursor { - Ok(mut c) => { - let mut results = Vec::new(); - while let Some(Ok(e)) = c.next().await { - results.push(e); - } - results - } - Err(e) => { - tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}"); - Vec::new() - } - } - } - - /// Build CodeContextHint objects from the code knowledge graph. - /// Maps entry points to their source files and links SAST findings. - async fn fetch_code_context( - &self, - repo_id: &str, - sast_findings: &[Finding], - ) -> Vec { - // Get entry point nodes from the code graph - let cursor = self - .db - .graph_nodes() - .find(doc! { - "repo_id": repo_id, - "is_entry_point": true, - }) - .limit(50) - .await; - - let nodes = match cursor { - Ok(mut c) => { - let mut results = Vec::new(); - while let Some(Ok(n)) = c.next().await { - results.push(n); - } - results - } - Err(_) => return Vec::new(), - }; - - // Build hints by matching graph nodes to SAST findings by file path - nodes - .into_iter() - .map(|node| { - // Find SAST findings in the same file - let linked_vulns: Vec = sast_findings - .iter() - .filter(|f| { - f.file_path.as_deref() == Some(&node.file_path) - }) - .map(|f| { - format!( - "[{}] {}: {} (line {})", - f.severity, - f.scanner, - f.title, - f.line_number.unwrap_or(0) - ) - }) - .collect(); - - CodeContextHint { - endpoint_pattern: node.qualified_name.clone(), - handler_function: node.name.clone(), - file_path: node.file_path.clone(), - code_snippet: String::new(), // Could fetch from embeddings - known_vulnerabilities: linked_vulns, - } - }) - .collect() - } - - // ── System Prompt Builder ─────────────────────────────────── - - async fn build_system_prompt( - &self, - session: &PentestSession, - target: &DastTarget, - sast_findings: &[Finding], - sbom_entries: &[SbomEntry], - code_context: &[CodeContextHint], - ) -> String { - let tool_names = 
self.tool_registry.list_names().join(", "); - let strategy_guidance = match session.strategy { - PentestStrategy::Quick => { - "Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas." - } - PentestStrategy::Comprehensive => { - "Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface." - } - PentestStrategy::Targeted => { - "Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses." - } - PentestStrategy::Aggressive => { - "Use all available tools aggressively. Test with maximum payloads and attempt full exploitation." - } - PentestStrategy::Stealth => { - "Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes." - } - }; - - // Build SAST findings section - let sast_section = if sast_findings.is_empty() { - String::from("No SAST findings available for this target.") - } else { - let critical = sast_findings - .iter() - .filter(|f| f.severity == Severity::Critical) - .count(); - let high = sast_findings - .iter() - .filter(|f| f.severity == Severity::High) - .count(); - - let mut section = format!( - "{} open findings ({} critical, {} high):\n", - sast_findings.len(), - critical, - high - ); - - // List the most important findings (critical/high first, up to 20) - for f in sast_findings.iter().take(20) { - let file_info = f - .file_path - .as_ref() - .map(|p| { - format!( - " in {}:{}", - p, - f.line_number.unwrap_or(0) - ) - }) - .unwrap_or_default(); - let status_note = match f.status { - FindingStatus::Triaged => " [TRIAGED]", - _ => "", - }; - section.push_str(&format!( - "- [{sev}] {title}{file}{status}\n", - sev = f.severity, - title = f.title, - file = file_info, - status = status_note, - )); - if let Some(cwe) = &f.cwe { - section.push_str(&format!(" CWE: {cwe}\n")); - } - } - if sast_findings.len() > 20 { - 
section.push_str(&format!( - "... and {} more findings\n", - sast_findings.len() - 20 - )); - } - section - }; - - // Build SBOM/CVE section - let sbom_section = if sbom_entries.is_empty() { - String::from("No vulnerable dependencies identified.") - } else { - let mut section = format!( - "{} dependencies with known vulnerabilities:\n", - sbom_entries.len() - ); - for entry in sbom_entries.iter().take(15) { - let cve_ids: Vec<&str> = entry - .known_vulnerabilities - .iter() - .map(|v| v.id.as_str()) - .collect(); - section.push_str(&format!( - "- {} {} ({}): {}\n", - entry.name, - entry.version, - entry.package_manager, - cve_ids.join(", ") - )); - } - if sbom_entries.len() > 15 { - section.push_str(&format!( - "... and {} more vulnerable dependencies\n", - sbom_entries.len() - 15 - )); - } - section - }; - - // Build code context section - let code_section = if code_context.is_empty() { - String::from("No code knowledge graph available for this target.") - } else { - let with_vulns = code_context - .iter() - .filter(|c| !c.known_vulnerabilities.is_empty()) - .count(); - - let mut section = format!( - "{} entry points identified ({} with linked SAST findings):\n", - code_context.len(), - with_vulns - ); - - for hint in code_context.iter().take(20) { - section.push_str(&format!( - "- {} ({})\n", - hint.endpoint_pattern, hint.file_path - )); - for vuln in &hint.known_vulnerabilities { - section.push_str(&format!(" SAST: {vuln}\n")); - } - } - section - }; - - format!( - r#"You are an expert penetration tester conducting an authorized security assessment. 
- -## Target -- **Name**: {target_name} -- **URL**: {base_url} -- **Type**: {target_type} -- **Rate Limit**: {rate_limit} req/s -- **Destructive Tests Allowed**: {allow_destructive} -- **Linked Repository**: {repo_linked} - -## Strategy -{strategy_guidance} - -## SAST Findings (Static Analysis) -{sast_section} - -## Vulnerable Dependencies (SBOM) -{sbom_section} - -## Code Entry Points (Knowledge Graph) -{code_section} - -## Available Tools -{tool_names} - -## Instructions -1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies. -2. Run the OpenAPI parser to discover API endpoints from specs. -3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS. -4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code. -5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability. -6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application. -7. Test rate limiting on critical endpoints (login, API). -8. Check for console.log leakage in frontend JavaScript. -9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain. -10. When testing is complete, provide a structured summary with severity and remediation. -11. Always explain your reasoning before invoking each tool. -12. When done, say "Testing complete" followed by a final summary. - -## Important -- This is an authorized penetration test. All testing is permitted within the target scope. -- Respect the rate limit of {rate_limit} requests per second. -- Only use destructive tests if explicitly allowed ({allow_destructive}). -- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist. -- Use SBOM data to understand what technologies and versions the target runs. 
-"#, - target_name = target.name, - base_url = target.base_url, - target_type = target.target_type, - rate_limit = target.rate_limit, - allow_destructive = target.allow_destructive, - repo_linked = target.repo_id.as_deref().unwrap_or("None"), - ) - } } diff --git a/compliance-agent/src/pentest/prompt_builder.rs b/compliance-agent/src/pentest/prompt_builder.rs new file mode 100644 index 0000000..ac8ee97 --- /dev/null +++ b/compliance-agent/src/pentest/prompt_builder.rs @@ -0,0 +1,504 @@ +use compliance_core::models::dast::DastTarget; +use compliance_core::models::finding::{Finding, FindingStatus, Severity}; +use compliance_core::models::pentest::*; +use compliance_core::models::sbom::SbomEntry; + +use super::orchestrator::PentestOrchestrator; + +/// Return strategy guidance text for the given strategy. +fn strategy_guidance(strategy: &PentestStrategy) -> &'static str { + match strategy { + PentestStrategy::Quick => { + "Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas." + } + PentestStrategy::Comprehensive => { + "Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface." + } + PentestStrategy::Targeted => { + "Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses." + } + PentestStrategy::Aggressive => { + "Use all available tools aggressively. Test with maximum payloads and attempt full exploitation." + } + PentestStrategy::Stealth => { + "Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes." + } + } +} + +/// Build the SAST findings section for the system prompt. 
+fn build_sast_section(sast_findings: &[Finding]) -> String { + if sast_findings.is_empty() { + return String::from("No SAST findings available for this target."); + } + + let critical = sast_findings + .iter() + .filter(|f| f.severity == Severity::Critical) + .count(); + let high = sast_findings + .iter() + .filter(|f| f.severity == Severity::High) + .count(); + + let mut section = format!( + "{} open findings ({} critical, {} high):\n", + sast_findings.len(), + critical, + high + ); + + // List the most important findings (critical/high first, up to 20) + for f in sast_findings.iter().take(20) { + let file_info = f + .file_path + .as_ref() + .map(|p| format!(" in {}:{}", p, f.line_number.unwrap_or(0))) + .unwrap_or_default(); + let status_note = match f.status { + FindingStatus::Triaged => " [TRIAGED]", + _ => "", + }; + section.push_str(&format!( + "- [{sev}] {title}{file}{status}\n", + sev = f.severity, + title = f.title, + file = file_info, + status = status_note, + )); + if let Some(cwe) = &f.cwe { + section.push_str(&format!(" CWE: {cwe}\n")); + } + } + if sast_findings.len() > 20 { + section.push_str(&format!( + "... and {} more findings\n", + sast_findings.len() - 20 + )); + } + section +} + +/// Build the SBOM/CVE section for the system prompt. +fn build_sbom_section(sbom_entries: &[SbomEntry]) -> String { + if sbom_entries.is_empty() { + return String::from("No vulnerable dependencies identified."); + } + + let mut section = format!( + "{} dependencies with known vulnerabilities:\n", + sbom_entries.len() + ); + for entry in sbom_entries.iter().take(15) { + let cve_ids: Vec<&str> = entry + .known_vulnerabilities + .iter() + .map(|v| v.id.as_str()) + .collect(); + section.push_str(&format!( + "- {} {} ({}): {}\n", + entry.name, + entry.version, + entry.package_manager, + cve_ids.join(", ") + )); + } + if sbom_entries.len() > 15 { + section.push_str(&format!( + "... 
and {} more vulnerable dependencies\n", + sbom_entries.len() - 15 + )); + } + section +} + +/// Build the code context section for the system prompt. +fn build_code_section(code_context: &[CodeContextHint]) -> String { + if code_context.is_empty() { + return String::from("No code knowledge graph available for this target."); + } + + let with_vulns = code_context + .iter() + .filter(|c| !c.known_vulnerabilities.is_empty()) + .count(); + + let mut section = format!( + "{} entry points identified ({} with linked SAST findings):\n", + code_context.len(), + with_vulns + ); + + for hint in code_context.iter().take(20) { + section.push_str(&format!( + "- {} ({})\n", + hint.endpoint_pattern, hint.file_path + )); + for vuln in &hint.known_vulnerabilities { + section.push_str(&format!(" SAST: {vuln}\n")); + } + } + section +} + +impl PentestOrchestrator { + pub(crate) async fn build_system_prompt( + &self, + session: &PentestSession, + target: &DastTarget, + sast_findings: &[Finding], + sbom_entries: &[SbomEntry], + code_context: &[CodeContextHint], + ) -> String { + let tool_names = self.tool_registry.list_names().join(", "); + let guidance = strategy_guidance(&session.strategy); + let sast_section = build_sast_section(sast_findings); + let sbom_section = build_sbom_section(sbom_entries); + let code_section = build_code_section(code_context); + + format!( + r#"You are an expert penetration tester conducting an authorized security assessment. + +## Target +- **Name**: {target_name} +- **URL**: {base_url} +- **Type**: {target_type} +- **Rate Limit**: {rate_limit} req/s +- **Destructive Tests Allowed**: {allow_destructive} +- **Linked Repository**: {repo_linked} + +## Strategy +{strategy_guidance} + +## SAST Findings (Static Analysis) +{sast_section} + +## Vulnerable Dependencies (SBOM) +{sbom_section} + +## Code Entry Points (Knowledge Graph) +{code_section} + +## Available Tools +{tool_names} + +## Instructions +1. 
Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies. +2. Run the OpenAPI parser to discover API endpoints from specs. +3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS. +4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code. +5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability. +6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application. +7. Test rate limiting on critical endpoints (login, API). +8. Check for console.log leakage in frontend JavaScript. +9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain. +10. When testing is complete, provide a structured summary with severity and remediation. +11. Always explain your reasoning before invoking each tool. +12. When done, say "Testing complete" followed by a final summary. + +## Important +- This is an authorized penetration test. All testing is permitted within the target scope. +- Respect the rate limit of {rate_limit} requests per second. +- Only use destructive tests if explicitly allowed ({allow_destructive}). +- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist. +- Use SBOM data to understand what technologies and versions the target runs. 
+"#, + target_name = target.name, + base_url = target.base_url, + target_type = target.target_type, + rate_limit = target.rate_limit, + allow_destructive = target.allow_destructive, + repo_linked = target.repo_id.as_deref().unwrap_or("None"), + strategy_guidance = guidance, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::finding::Severity; + use compliance_core::models::sbom::VulnRef; + use compliance_core::models::scan::ScanType; + + fn make_finding( + severity: Severity, + title: &str, + file_path: Option<&str>, + line: Option, + status: FindingStatus, + cwe: Option<&str>, + ) -> Finding { + let mut f = Finding::new( + "repo-1".into(), + format!("fp-{title}"), + "semgrep".into(), + ScanType::Sast, + title.into(), + "desc".into(), + severity, + ); + f.file_path = file_path.map(|s| s.to_string()); + f.line_number = line; + f.status = status; + f.cwe = cwe.map(|s| s.to_string()); + f + } + + fn make_sbom_entry(name: &str, version: &str, cves: &[&str]) -> SbomEntry { + let mut entry = SbomEntry::new("repo-1".into(), name.into(), version.into(), "npm".into()); + entry.known_vulnerabilities = cves + .iter() + .map(|id| VulnRef { + id: id.to_string(), + source: "nvd".into(), + severity: None, + url: None, + }) + .collect(); + entry + } + + fn make_code_hint(endpoint: &str, file: &str, vulns: Vec) -> CodeContextHint { + CodeContextHint { + endpoint_pattern: endpoint.into(), + handler_function: "handler".into(), + file_path: file.into(), + code_snippet: String::new(), + known_vulnerabilities: vulns, + } + } + + // ── strategy_guidance ──────────────────────────────────────────── + + #[test] + fn strategy_guidance_quick() { + let g = strategy_guidance(&PentestStrategy::Quick); + assert!(g.contains("most common")); + assert!(g.contains("quick recon")); + } + + #[test] + fn strategy_guidance_comprehensive() { + let g = strategy_guidance(&PentestStrategy::Comprehensive); + assert!(g.contains("thorough assessment")); + } + + #[test] + fn 
strategy_guidance_targeted() { + let g = strategy_guidance(&PentestStrategy::Targeted); + assert!(g.contains("SAST findings")); + assert!(g.contains("known CVEs")); + } + + #[test] + fn strategy_guidance_aggressive() { + let g = strategy_guidance(&PentestStrategy::Aggressive); + assert!(g.contains("aggressively")); + assert!(g.contains("full exploitation")); + } + + #[test] + fn strategy_guidance_stealth() { + let g = strategy_guidance(&PentestStrategy::Stealth); + assert!(g.contains("Minimize noise")); + assert!(g.contains("passive analysis")); + } + + // ── build_sast_section ─────────────────────────────────────────── + + #[test] + fn sast_section_empty() { + let section = build_sast_section(&[]); + assert_eq!(section, "No SAST findings available for this target."); + } + + #[test] + fn sast_section_single_critical() { + let findings = vec![make_finding( + Severity::Critical, + "SQL Injection", + Some("src/db.rs"), + Some(42), + FindingStatus::Open, + Some("CWE-89"), + )]; + let section = build_sast_section(&findings); + assert!(section.contains("1 open findings (1 critical, 0 high)")); + assert!(section.contains("[critical] SQL Injection in src/db.rs:42")); + assert!(section.contains("CWE: CWE-89")); + } + + #[test] + fn sast_section_triaged_finding_shows_marker() { + let findings = vec![make_finding( + Severity::High, + "XSS", + None, + None, + FindingStatus::Triaged, + None, + )]; + let section = build_sast_section(&findings); + assert!(section.contains("[TRIAGED]")); + } + + #[test] + fn sast_section_no_file_path_omits_location() { + let findings = vec![make_finding( + Severity::Medium, + "Open Redirect", + None, + None, + FindingStatus::Open, + None, + )]; + let section = build_sast_section(&findings); + assert!(section.contains("- [medium] Open Redirect\n")); + assert!(!section.contains(" in ")); + } + + #[test] + fn sast_section_counts_critical_and_high() { + let findings = vec![ + make_finding( + Severity::Critical, + "F1", + None, + None, + 
FindingStatus::Open, + None, + ), + make_finding( + Severity::Critical, + "F2", + None, + None, + FindingStatus::Open, + None, + ), + make_finding(Severity::High, "F3", None, None, FindingStatus::Open, None), + make_finding( + Severity::Medium, + "F4", + None, + None, + FindingStatus::Open, + None, + ), + ]; + let section = build_sast_section(&findings); + assert!(section.contains("4 open findings (2 critical, 1 high)")); + } + + #[test] + fn sast_section_truncates_at_20() { + let findings: Vec = (0..25) + .map(|i| { + make_finding( + Severity::Low, + &format!("Finding {i}"), + None, + None, + FindingStatus::Open, + None, + ) + }) + .collect(); + let section = build_sast_section(&findings); + assert!(section.contains("... and 5 more findings")); + // Should contain Finding 19 (the 20th) but not Finding 20 (the 21st) + assert!(section.contains("Finding 19")); + assert!(!section.contains("Finding 20")); + } + + // ── build_sbom_section ─────────────────────────────────────────── + + #[test] + fn sbom_section_empty() { + let section = build_sbom_section(&[]); + assert_eq!(section, "No vulnerable dependencies identified."); + } + + #[test] + fn sbom_section_single_entry() { + let entries = vec![make_sbom_entry("lodash", "4.17.20", &["CVE-2021-23337"])]; + let section = build_sbom_section(&entries); + assert!(section.contains("1 dependencies with known vulnerabilities")); + assert!(section.contains("- lodash 4.17.20 (npm): CVE-2021-23337")); + } + + #[test] + fn sbom_section_multiple_cves() { + let entries = vec![make_sbom_entry( + "openssl", + "1.1.1", + &["CVE-2022-0001", "CVE-2022-0002"], + )]; + let section = build_sbom_section(&entries); + assert!(section.contains("CVE-2022-0001, CVE-2022-0002")); + } + + #[test] + fn sbom_section_truncates_at_15() { + let entries: Vec = (0..18) + .map(|i| make_sbom_entry(&format!("pkg-{i}"), "1.0.0", &["CVE-2024-0001"])) + .collect(); + let section = build_sbom_section(&entries); + assert!(section.contains("... 
and 3 more vulnerable dependencies")); + assert!(section.contains("pkg-14")); + assert!(!section.contains("pkg-15")); + } + + // ── build_code_section ─────────────────────────────────────────── + + #[test] + fn code_section_empty() { + let section = build_code_section(&[]); + assert_eq!( + section, + "No code knowledge graph available for this target." + ); + } + + #[test] + fn code_section_single_entry_no_vulns() { + let hints = vec![make_code_hint("GET /api/users", "src/routes.rs", vec![])]; + let section = build_code_section(&hints); + assert!(section.contains("1 entry points identified (0 with linked SAST findings)")); + assert!(section.contains("- GET /api/users (src/routes.rs)")); + } + + #[test] + fn code_section_with_linked_vulns() { + let hints = vec![make_code_hint( + "POST /login", + "src/auth.rs", + vec!["[critical] semgrep: SQL Injection (line 15)".into()], + )]; + let section = build_code_section(&hints); + assert!(section.contains("1 entry points identified (1 with linked SAST findings)")); + assert!(section.contains("SAST: [critical] semgrep: SQL Injection (line 15)")); + } + + #[test] + fn code_section_counts_entries_with_vulns() { + let hints = vec![ + make_code_hint("GET /a", "a.rs", vec!["vuln1".into()]), + make_code_hint("GET /b", "b.rs", vec![]), + make_code_hint("GET /c", "c.rs", vec!["vuln2".into(), "vuln3".into()]), + ]; + let section = build_code_section(&hints); + assert!(section.contains("3 entry points identified (2 with linked SAST findings)")); + } + + #[test] + fn code_section_truncates_at_20() { + let hints: Vec = (0..25) + .map(|i| make_code_hint(&format!("GET /ep{i}"), &format!("f{i}.rs"), vec![])) + .collect(); + let section = build_code_section(&hints); + assert!(section.contains("GET /ep19")); + assert!(!section.contains("GET /ep20")); + } +} diff --git a/compliance-agent/src/pentest/report/archive.rs b/compliance-agent/src/pentest/report/archive.rs new file mode 100644 index 0000000..4a3bb4c --- /dev/null +++ 
b/compliance-agent/src/pentest/report/archive.rs @@ -0,0 +1,43 @@ +use std::io::{Cursor, Write}; + +use zip::write::SimpleFileOptions; +use zip::AesMode; + +use super::ReportContext; + +pub(super) fn build_zip( + ctx: &ReportContext, + password: &str, + html: &str, + pdf: &[u8], +) -> Result, zip::result::ZipError> { + let buf = Cursor::new(Vec::new()); + let mut zip = zip::ZipWriter::new(buf); + + let options = SimpleFileOptions::default() + .compression_method(zip::CompressionMethod::Deflated) + .with_aes_encryption(AesMode::Aes256, password); + + // report.pdf (primary) + zip.start_file("report.pdf", options)?; + zip.write_all(pdf)?; + + // report.html (fallback) + zip.start_file("report.html", options)?; + zip.write_all(html.as_bytes())?; + + // findings.json + let findings_json = + serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string()); + zip.start_file("findings.json", options)?; + zip.write_all(findings_json.as_bytes())?; + + // attack-chain.json + let chain_json = + serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string()); + zip.start_file("attack-chain.json", options)?; + zip.write_all(chain_json.as_bytes())?; + + let cursor = zip.finish()?; + Ok(cursor.into_inner()) +} diff --git a/compliance-agent/src/pentest/report.rs b/compliance-agent/src/pentest/report/html.rs similarity index 73% rename from compliance-agent/src/pentest/report.rs rename to compliance-agent/src/pentest/report/html.rs index 6997ded..3882f76 100644 --- a/compliance-agent/src/pentest/report.rs +++ b/compliance-agent/src/pentest/report/html.rs @@ -1,193 +1,50 @@ -use std::io::{Cursor, Write}; - use compliance_core::models::dast::DastFinding; -use compliance_core::models::pentest::{AttackChainNode, PentestSession}; -use sha2::{Digest, Sha256}; -use zip::write::SimpleFileOptions; -use zip::AesMode; +use compliance_core::models::pentest::AttackChainNode; -/// Report archive with metadata -pub struct ReportArchive { - /// The 
password-protected ZIP bytes - pub archive: Vec, - /// SHA-256 hex digest of the archive - pub sha256: String, -} +use super::ReportContext; -/// Report context gathered from the database -pub struct ReportContext { - pub session: PentestSession, - pub target_name: String, - pub target_url: String, - pub findings: Vec, - pub attack_chain: Vec, - pub requester_name: String, - pub requester_email: String, -} - -/// Generate a password-protected ZIP archive containing the pentest report. -/// -/// The archive contains: -/// - `report.pdf` — Professional pentest report (PDF) -/// - `report.html` — HTML source (fallback) -/// - `findings.json` — Raw findings data -/// - `attack-chain.json` — Attack chain timeline -/// -/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format, -/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.). -pub async fn generate_encrypted_report( - ctx: &ReportContext, - password: &str, -) -> Result { - let html = build_html_report(ctx); - - // Convert HTML to PDF via headless Chrome - let pdf_bytes = html_to_pdf(&html).await?; - - let zip_bytes = build_zip(ctx, password, &html, &pdf_bytes) - .map_err(|e| format!("Failed to create archive: {e}"))?; - - let mut hasher = Sha256::new(); - hasher.update(&zip_bytes); - let sha256 = hex::encode(hasher.finalize()); - - Ok(ReportArchive { archive: zip_bytes, sha256 }) -} - -/// Convert HTML string to PDF bytes using headless Chrome/Chromium. -async fn html_to_pdf(html: &str) -> Result, String> { - let tmp_dir = std::env::temp_dir(); - let run_id = uuid::Uuid::new_v4().to_string(); - let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html")); - let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf")); - - // Write HTML to temp file - std::fs::write(&html_path, html) - .map_err(|e| format!("Failed to write temp HTML: {e}"))?; - - // Find Chrome/Chromium binary - let chrome_bin = find_chrome_binary() - .ok_or_else(|| "Chrome/Chromium not found. 
Install google-chrome or chromium to generate PDF reports.".to_string())?; - - tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome"); - - let html_url = format!("file://{}", html_path.display()); - - let output = tokio::process::Command::new(&chrome_bin) - .args([ - "--headless", - "--disable-gpu", - "--no-sandbox", - "--disable-software-rasterizer", - "--run-all-compositor-stages-before-draw", - "--disable-dev-shm-usage", - &format!("--print-to-pdf={}", pdf_path.display()), - "--no-pdf-header-footer", - &html_url, - ]) - .output() - .await - .map_err(|e| format!("Failed to run Chrome: {e}"))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - // Clean up temp files - let _ = std::fs::remove_file(&html_path); - let _ = std::fs::remove_file(&pdf_path); - return Err(format!("Chrome PDF generation failed: {stderr}")); - } - - let pdf_bytes = std::fs::read(&pdf_path) - .map_err(|e| format!("Failed to read generated PDF: {e}"))?; - - // Clean up temp files - let _ = std::fs::remove_file(&html_path); - let _ = std::fs::remove_file(&pdf_path); - - if pdf_bytes.is_empty() { - return Err("Chrome produced an empty PDF".to_string()); - } - - tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated"); - Ok(pdf_bytes) -} - -/// Search for Chrome/Chromium binary on the system. 
-fn find_chrome_binary() -> Option { - let candidates = [ - "google-chrome-stable", - "google-chrome", - "chromium-browser", - "chromium", - ]; - for name in &candidates { - if let Ok(output) = std::process::Command::new("which").arg(name).output() { - if output.status.success() { - let path = String::from_utf8_lossy(&output.stdout).trim().to_string(); - if !path.is_empty() { - return Some(path); - } - } - } - } - None -} - -fn build_zip( - ctx: &ReportContext, - password: &str, - html: &str, - pdf: &[u8], -) -> Result, zip::result::ZipError> { - let buf = Cursor::new(Vec::new()); - let mut zip = zip::ZipWriter::new(buf); - - let options = SimpleFileOptions::default() - .compression_method(zip::CompressionMethod::Deflated) - .with_aes_encryption(AesMode::Aes256, password); - - // report.pdf (primary) - zip.start_file("report.pdf", options.clone())?; - zip.write_all(pdf)?; - - // report.html (fallback) - zip.start_file("report.html", options.clone())?; - zip.write_all(html.as_bytes())?; - - // findings.json - let findings_json = - serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string()); - zip.start_file("findings.json", options.clone())?; - zip.write_all(findings_json.as_bytes())?; - - // attack-chain.json - let chain_json = - serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string()); - zip.start_file("attack-chain.json", options)?; - zip.write_all(chain_json.as_bytes())?; - - let cursor = zip.finish()?; - Ok(cursor.into_inner()) -} - -fn build_html_report(ctx: &ReportContext) -> String { +#[allow(clippy::format_in_format_args)] +pub(super) fn build_html_report(ctx: &ReportContext) -> String { let session = &ctx.session; let session_id = session .id .map(|oid| oid.to_hex()) .unwrap_or_else(|| "-".to_string()); - let date_str = session.started_at.format("%B %d, %Y at %H:%M UTC").to_string(); + let date_str = session + .started_at + .format("%B %d, %Y at %H:%M UTC") + .to_string(); let date_short = 
session.started_at.format("%B %d, %Y").to_string(); let completed_str = session .completed_at .map(|d| d.format("%B %d, %Y at %H:%M UTC").to_string()) .unwrap_or_else(|| "In Progress".to_string()); - let critical = ctx.findings.iter().filter(|f| f.severity.to_string() == "critical").count(); - let high = ctx.findings.iter().filter(|f| f.severity.to_string() == "high").count(); - let medium = ctx.findings.iter().filter(|f| f.severity.to_string() == "medium").count(); - let low = ctx.findings.iter().filter(|f| f.severity.to_string() == "low").count(); - let info = ctx.findings.iter().filter(|f| f.severity.to_string() == "info").count(); + let critical = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == "critical") + .count(); + let high = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == "high") + .count(); + let medium = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == "medium") + .count(); + let low = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == "low") + .count(); + let info = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == "info") + .count(); let exploitable = ctx.findings.iter().filter(|f| f.exploitable).count(); let total = ctx.findings.len(); @@ -212,10 +69,8 @@ fn build_html_report(ctx: &ReportContext) -> String { }; // Risk score 0-100 - let risk_score: usize = std::cmp::min( - 100, - critical * 25 + high * 15 + medium * 8 + low * 3 + info * 1, - ); + let risk_score: usize = + std::cmp::min(100, critical * 25 + high * 15 + medium * 8 + low * 3 + info); // Collect unique tool names used let tool_names: Vec = { @@ -247,7 +102,8 @@ fn build_html_report(ctx: &ReportContext) -> String { if high > 0 { bar.push_str(&format!( r#"
{}
"#, - std::cmp::max(high_pct, 4), high + std::cmp::max(high_pct, 4), + high )); } if medium > 0 { @@ -259,22 +115,38 @@ fn build_html_report(ctx: &ReportContext) -> String { if low > 0 { bar.push_str(&format!( r#"
{}
"#, - std::cmp::max(low_pct, 4), low + std::cmp::max(low_pct, 4), + low )); } if info > 0 { bar.push_str(&format!( r#"
{}
"#, - std::cmp::max(info_pct, 4), info + std::cmp::max(info_pct, 4), + info )); } bar.push_str(""); bar.push_str(r#"
"#); - if critical > 0 { bar.push_str(r#" Critical"#); } - if high > 0 { bar.push_str(r#" High"#); } - if medium > 0 { bar.push_str(r#" Medium"#); } - if low > 0 { bar.push_str(r#" Low"#); } - if info > 0 { bar.push_str(r#" Info"#); } + if critical > 0 { + bar.push_str( + r#" Critical"#, + ); + } + if high > 0 { + bar.push_str(r#" High"#); + } + if medium > 0 { + bar.push_str( + r#" Medium"#, + ); + } + if low > 0 { + bar.push_str(r#" Low"#); + } + if info > 0 { + bar.push_str(r#" Info"#); + } bar.push_str("
"); bar } else { @@ -322,7 +194,12 @@ fn build_html_report(ctx: &ReportContext) -> String { let param_row = f .parameter .as_deref() - .map(|p| format!("Parameter{}", html_escape(p))) + .map(|p| { + format!( + "Parameter{}", + html_escape(p) + ) + }) .unwrap_or_default(); let remediation = f .remediation @@ -332,7 +209,9 @@ fn build_html_report(ctx: &ReportContext) -> String { let evidence_html = if f.evidence.is_empty() { String::new() } else { - let mut eh = String::from(r#"
Evidence
"#); + let mut eh = String::from( + r#"
Evidence
RequestStatusDetails
"#, + ); for ev in &f.evidence { let payload_info = ev .payload @@ -346,7 +225,7 @@ fn build_html_report(ctx: &ReportContext) -> String { ev.response_status, ev.response_snippet .as_deref() - .map(|s| html_escape(s)) + .map(html_escape) .unwrap_or_default(), payload_info, )); @@ -402,7 +281,8 @@ fn build_html_report(ctx: &ReportContext) -> String { let mut chain_html = String::new(); if !ctx.attack_chain.is_empty() { // Compute phases via BFS from root nodes - let mut phase_map: std::collections::HashMap = std::collections::HashMap::new(); + let mut phase_map: std::collections::HashMap = + std::collections::HashMap::new(); let mut queue: std::collections::VecDeque = std::collections::VecDeque::new(); for node in &ctx.attack_chain { @@ -438,7 +318,13 @@ fn build_html_report(ctx: &ReportContext) -> String { // Group nodes by phase let max_phase = phase_map.values().copied().max().unwrap_or(0); - let phase_labels = ["Reconnaissance", "Enumeration", "Exploitation", "Validation", "Post-Exploitation"]; + let phase_labels = [ + "Reconnaissance", + "Enumeration", + "Exploitation", + "Validation", + "Post-Exploitation", + ]; for phase_idx in 0..=max_phase { let phase_nodes: Vec<&AttackChainNode> = ctx @@ -485,15 +371,28 @@ fn build_html_report(ctx: &ReportContext) -> String { format!( r#"{} finding{}"#, node.findings_produced.len(), - if node.findings_produced.len() == 1 { "" } else { "s" }, + if node.findings_produced.len() == 1 { + "" + } else { + "s" + }, ) } else { String::new() }; - let risk_badge = node.risk_score.map(|r| { - let risk_class = if r >= 70 { "risk-high" } else if r >= 40 { "risk-med" } else { "risk-low" }; - format!(r#"Risk: {r}"#) - }).unwrap_or_default(); + let risk_badge = node + .risk_score + .map(|r| { + let risk_class = if r >= 70 { + "risk-high" + } else if r >= 40 { + "risk-med" + } else { + "risk-low" + }; + format!(r#"Risk: {r}"#) + }) + .unwrap_or_default(); let reasoning_html = if node.llm_reasoning.is_empty() { String::new() @@ -547,10 
+446,20 @@ fn build_html_report(ctx: &ReportContext) -> String { let toc_findings_sub = if !ctx.findings.is_empty() { let mut sub = String::new(); let mut fnum = 0usize; - for (si, &sev_key) in severity_order.iter().enumerate() { - let count = ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key).count(); - if count == 0 { continue; } - for f in ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key) { + for &sev_key in severity_order.iter() { + let count = ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == sev_key) + .count(); + if count == 0 { + continue; + } + for f in ctx + .findings + .iter() + .filter(|f| f.severity.to_string() == sev_key) + { fnum += 1; sub.push_str(&format!( r#"
F-{:03} — {}
"#, @@ -1577,19 +1486,49 @@ table.tools-table td:first-child {{ fn tool_category(tool_name: &str) -> &'static str { let name = tool_name.to_lowercase(); - if name.contains("nmap") || name.contains("port") { return "Network Reconnaissance"; } - if name.contains("nikto") || name.contains("header") { return "Web Server Analysis"; } - if name.contains("zap") || name.contains("spider") || name.contains("crawl") { return "Web Application Scanning"; } - if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { return "SQL Injection Testing"; } - if name.contains("xss") || name.contains("cross-site") { return "Cross-Site Scripting Testing"; } - if name.contains("dir") || name.contains("brute") || name.contains("fuzz") || name.contains("gobuster") { return "Directory Enumeration"; } - if name.contains("ssl") || name.contains("tls") || name.contains("cert") { return "SSL/TLS Analysis"; } - if name.contains("api") || name.contains("endpoint") { return "API Security Testing"; } - if name.contains("auth") || name.contains("login") || name.contains("credential") { return "Authentication Testing"; } - if name.contains("cors") { return "CORS Testing"; } - if name.contains("csrf") { return "CSRF Testing"; } - if name.contains("nuclei") || name.contains("template") { return "Vulnerability Scanning"; } - if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { return "Technology Fingerprinting"; } + if name.contains("nmap") || name.contains("port") { + return "Network Reconnaissance"; + } + if name.contains("nikto") || name.contains("header") { + return "Web Server Analysis"; + } + if name.contains("zap") || name.contains("spider") || name.contains("crawl") { + return "Web Application Scanning"; + } + if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { + return "SQL Injection Testing"; + } + if name.contains("xss") || name.contains("cross-site") { + return "Cross-Site Scripting Testing"; + } + if 
name.contains("dir") + || name.contains("brute") + || name.contains("fuzz") + || name.contains("gobuster") + { + return "Directory Enumeration"; + } + if name.contains("ssl") || name.contains("tls") || name.contains("cert") { + return "SSL/TLS Analysis"; + } + if name.contains("api") || name.contains("endpoint") { + return "API Security Testing"; + } + if name.contains("auth") || name.contains("login") || name.contains("credential") { + return "Authentication Testing"; + } + if name.contains("cors") { + return "CORS Testing"; + } + if name.contains("csrf") { + return "CSRF Testing"; + } + if name.contains("nuclei") || name.contains("template") { + return "Vulnerability Scanning"; + } + if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { + return "Technology Fingerprinting"; + } "Security Testing" } @@ -1599,3 +1538,314 @@ fn html_escape(s: &str) -> String { .replace('>', ">") .replace('"', """) } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::dast::{DastFinding, DastVulnType}; + use compliance_core::models::finding::Severity; + use compliance_core::models::pentest::{ + AttackChainNode, AttackNodeStatus, PentestSession, PentestStrategy, + }; + + // ── html_escape ────────────────────────────────────────────────── + + #[test] + fn html_escape_handles_ampersand() { + assert_eq!(html_escape("a & b"), "a & b"); + } + + #[test] + fn html_escape_handles_angle_brackets() { + assert_eq!(html_escape(" + + + + + "#; + let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com"); + assert_eq!(urls.len(), 3); + assert!(urls.contains(&"https://example.com/static/app.js".to_string())); + assert!(urls.contains(&"https://cdn.example.com/lib.js".to_string())); + assert!(urls.contains(&"https://cdn2.example.com/vendor.js".to_string())); + } + + #[test] + fn extract_js_urls_no_scripts() { + let html = "

Hello

"; + let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com"); + assert!(urls.is_empty()); + } + + #[test] + fn extract_js_urls_filters_non_js() { + let html = r#""#; + let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com"); + // Only .js files should be extracted + assert_eq!(urls.len(), 1); + assert!(urls[0].ends_with("/app.js")); + } + + #[test] + fn scan_js_content_finds_console_log() { + let js = r#" + function init() { + console.log("debug info"); + doStuff(); + } + "#; + let matches = ConsoleLogDetectorTool::scan_js_content(js, "https://example.com/app.js"); + assert_eq!(matches.len(), 1); + assert_eq!(matches[0].pattern, "console.log"); + assert_eq!(matches[0].line_number, Some(3)); + } + + #[test] + fn scan_js_content_finds_multiple_patterns() { + let js = + "console.log('a');\nconsole.debug('b');\nconsole.error('c');\ndebugger;\nalert('x');"; + let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js"); + assert_eq!(matches.len(), 5); + } + + #[test] + fn scan_js_content_skips_comments() { + let js = "// console.log('commented out');\n* console.log('also comment');\n/* console.log('block comment') */"; + let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js"); + assert!(matches.is_empty()); + } + + #[test] + fn scan_js_content_one_match_per_line() { + let js = "console.log('a'); console.debug('b');"; + let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js"); + // Only one match per line + assert_eq!(matches.len(), 1); + } + + #[test] + fn scan_js_content_empty_input() { + let matches = ConsoleLogDetectorTool::scan_js_content("", "test.js"); + assert!(matches.is_empty()); + } + + #[test] + fn patterns_list_is_not_empty() { + let patterns = ConsoleLogDetectorTool::patterns(); + assert!(patterns.len() >= 8); + assert!(patterns.contains(&"console.log(")); + assert!(patterns.contains(&"debugger;")); + } +} + impl PentestTool for ConsoleLogDetectorTool { fn name(&self) -> &str { 
"console_log_detector" @@ -154,173 +244,180 @@ impl PentestTool for ConsoleLogDetectorTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let additional_js: Vec = input - .get("additional_js_urls") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }) - .unwrap_or_default(); - - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); - - // Fetch the main page - let response = self - .http - .get(url) - .send() - .await - .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; - - let html = response.text().await.unwrap_or_default(); - - // Scan inline scripts in the HTML - let mut all_matches = Vec::new(); - let inline_matches = Self::scan_js_content(&html, url); - all_matches.extend(inline_matches); - - // Extract JS file URLs from the HTML - let mut js_urls = Self::extract_js_urls(&html, url); - js_urls.extend(additional_js); - js_urls.dedup(); - - // Fetch and scan each JS file - for js_url in &js_urls { - match self.http.get(js_url).send().await { - Ok(resp) => { - if resp.status().is_success() { - let js_content = resp.text().await.unwrap_or_default(); - // Only scan non-minified-looking files or files where we can still - // find patterns (minifiers typically strip console calls, but not always) - let file_matches = Self::scan_js_content(&js_content, js_url); - all_matches.extend(file_matches); - } - } - Err(_) => continue, - } - } - - let mut findings = Vec::new(); - let match_data: Vec = all_matches - 
.iter() - .map(|m| { - json!({ - "pattern": m.pattern, - "file": m.file_url, - "line": m.line_number, - "snippet": m.line_snippet, + let additional_js: Vec = input + .get("additional_js_urls") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() }) - }) - .collect(); + .unwrap_or_default(); - if !all_matches.is_empty() { - // Group by file for the finding - let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> = - std::collections::HashMap::new(); - for m in &all_matches { - by_file.entry(&m.file_url).or_default().push(m); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); + + // Fetch the main page + let response = self + .http + .get(url) + .send() + .await + .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; + + let html = response.text().await.unwrap_or_default(); + + // Scan inline scripts in the HTML + let mut all_matches = Vec::new(); + let inline_matches = Self::scan_js_content(&html, url); + all_matches.extend(inline_matches); + + // Extract JS file URLs from the HTML + let mut js_urls = Self::extract_js_urls(&html, url); + js_urls.extend(additional_js); + js_urls.dedup(); + + // Fetch and scan each JS file + for js_url in &js_urls { + match self.http.get(js_url).send().await { + Ok(resp) => { + if resp.status().is_success() { + let js_content = resp.text().await.unwrap_or_default(); + // Only scan non-minified-looking files or files where we can still + // find patterns (minifiers typically strip console calls, but not always) + let file_matches = Self::scan_js_content(&js_content, js_url); + all_matches.extend(file_matches); + } + } + Err(_) => continue, + } } - for (file_url, matches) in &by_file { - let pattern_summary: Vec = matches - .iter() - .take(5) - .map(|m| { - format!( - " Line {}: {} - {}", - m.line_number.unwrap_or(0), - m.pattern, - if m.line_snippet.len() > 80 { - 
format!("{}...", &m.line_snippet[..80]) - } else { - m.line_snippet.clone() - } - ) + let mut findings = Vec::new(); + let match_data: Vec = all_matches + .iter() + .map(|m| { + json!({ + "pattern": m.pattern, + "file": m.file_url, + "line": m.line_number, + "snippet": m.line_snippet, }) - .collect(); + }) + .collect(); - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: file_url.to_string(), - request_headers: None, - request_body: None, - response_status: 200, - response_headers: None, - response_snippet: Some(pattern_summary.join("\n")), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + if !all_matches.is_empty() { + // Group by file for the finding + let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> = + std::collections::HashMap::new(); + for m in &all_matches { + by_file.entry(&m.file_url).or_default().push(m); + } - let total = matches.len(); - let extra = if total > 5 { - format!(" (and {} more)", total - 5) - } else { - String::new() - }; + for (file_url, matches) in &by_file { + let pattern_summary: Vec = matches + .iter() + .take(5) + .map(|m| { + format!( + " Line {}: {} - {}", + m.line_number.unwrap_or(0), + m.pattern, + if m.line_snippet.len() > 80 { + format!("{}...", &m.line_snippet[..80]) + } else { + m.line_snippet.clone() + } + ) + }) + .collect(); - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::ConsoleLogLeakage, - format!("Console/debug statements in {}", file_url), - format!( - "Found {total} console/debug statements in {file_url}{extra}. 
\ + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: file_url.to_string(), + request_headers: None, + request_body: None, + response_status: 200, + response_headers: None, + response_snippet: Some(pattern_summary.join("\n")), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let total = matches.len(); + let extra = if total > 5 { + format!(" (and {} more)", total - 5) + } else { + String::new() + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::ConsoleLogLeakage, + format!("Console/debug statements in {}", file_url), + format!( + "Found {total} console/debug statements in {file_url}{extra}. \ These can leak sensitive information such as API responses, user data, \ or internal state to anyone with browser developer tools open." - ), - Severity::Low, - file_url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-532".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Remove console.log/debug/error statements from production code. \ + ), + Severity::Low, + file_url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-532".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Remove console.log/debug/error statements from production code. \ Use a build step (e.g., babel plugin, terser) to strip console calls \ during the production build." 
- .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } } - } - let total_matches = all_matches.len(); - let count = findings.len(); - info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete"); + let total_matches = all_matches.len(); + let count = findings.len(); + info!( + url, + js_files = js_urls.len(), + total_matches, + "Console log detection complete" + ); - Ok(PentestToolResult { - summary: if total_matches > 0 { - format!( - "Found {total_matches} console/debug statements across {} files.", - count - ) - } else { - format!( - "No console/debug statements found in HTML or {} JS files.", - js_urls.len() - ) - }, - findings, - data: json!({ - "total_matches": total_matches, - "js_files_scanned": js_urls.len(), - "matches": match_data, - }), - }) + Ok(PentestToolResult { + summary: if total_matches > 0 { + format!( + "Found {total_matches} console/debug statements across {} files.", + count + ) + } else { + format!( + "No console/debug statements found in HTML or {} JS files.", + js_urls.len() + ) + }, + findings, + data: json!({ + "total_matches": total_matches, + "js_files_scanned": js_urls.len(), + "matches": match_data, + }), + }) }) } } diff --git a/compliance-dast/src/tools/cookie_analyzer.rs b/compliance-dast/src/tools/cookie_analyzer.rs index 7563889..9985e8b 100644 --- a/compliance-dast/src/tools/cookie_analyzer.rs +++ b/compliance-dast/src/tools/cookie_analyzer.rs @@ -14,6 +14,7 @@ pub struct CookieAnalyzerTool { #[derive(Debug)] struct ParsedCookie { name: String, + #[allow(dead_code)] value: String, secure: bool, http_only: bool, @@ -92,6 +93,81 @@ impl CookieAnalyzerTool { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_simple_cookie() { + let cookie = CookieAnalyzerTool::parse_set_cookie("session_id=abc123"); + assert_eq!(cookie.name, "session_id"); + assert_eq!(cookie.value, "abc123"); + assert!(!cookie.secure); + assert!(!cookie.http_only); + 
assert!(cookie.same_site.is_none()); + assert!(cookie.domain.is_none()); + assert!(cookie.path.is_none()); + } + + #[test] + fn parse_cookie_with_all_attributes() { + let raw = "token=xyz; Secure; HttpOnly; SameSite=Strict; Domain=.example.com; Path=/api"; + let cookie = CookieAnalyzerTool::parse_set_cookie(raw); + assert_eq!(cookie.name, "token"); + assert_eq!(cookie.value, "xyz"); + assert!(cookie.secure); + assert!(cookie.http_only); + assert_eq!(cookie.same_site.as_deref(), Some("strict")); + assert_eq!(cookie.domain.as_deref(), Some(".example.com")); + assert_eq!(cookie.path.as_deref(), Some("/api")); + assert_eq!(cookie.raw, raw); + } + + #[test] + fn parse_cookie_samesite_none() { + let cookie = CookieAnalyzerTool::parse_set_cookie("id=1; SameSite=None; Secure"); + assert_eq!(cookie.same_site.as_deref(), Some("none")); + assert!(cookie.secure); + } + + #[test] + fn parse_cookie_with_equals_in_value() { + let cookie = CookieAnalyzerTool::parse_set_cookie("data=a=b=c; HttpOnly"); + assert_eq!(cookie.name, "data"); + assert_eq!(cookie.value, "a=b=c"); + assert!(cookie.http_only); + } + + #[test] + fn is_sensitive_cookie_known_names() { + assert!(CookieAnalyzerTool::is_sensitive_cookie("session_id")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("PHPSESSID")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("JSESSIONID")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("connect.sid")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("asp.net_sessionid")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("auth_token")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("jwt_access")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("csrf_token")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("my_sess_cookie")); + assert!(CookieAnalyzerTool::is_sensitive_cookie("SID")); + } + + #[test] + fn is_sensitive_cookie_non_sensitive() { + assert!(!CookieAnalyzerTool::is_sensitive_cookie("theme")); + 
assert!(!CookieAnalyzerTool::is_sensitive_cookie("language")); + assert!(!CookieAnalyzerTool::is_sensitive_cookie("_ga")); + assert!(!CookieAnalyzerTool::is_sensitive_cookie("tracking")); + } + + #[test] + fn parse_empty_cookie_header() { + let cookie = CookieAnalyzerTool::parse_set_cookie(""); + assert_eq!(cookie.name, ""); + assert_eq!(cookie.value, ""); + } +} + impl PentestTool for CookieAnalyzerTool { fn name(&self) -> &str { "cookie_analyzer" @@ -123,96 +199,96 @@ impl PentestTool for CookieAnalyzerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let login_url = input.get("login_url").and_then(|v| v.as_str()); + let login_url = input.get("login_url").and_then(|v| v.as_str()); - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let mut findings = Vec::new(); - let mut cookie_data = Vec::new(); + let mut findings = Vec::new(); + let mut cookie_data = Vec::new(); - // Collect Set-Cookie headers from the main URL and optional login URL - let urls_to_check: Vec<&str> = std::iter::once(url) - .chain(login_url.into_iter()) - .collect(); + // Collect Set-Cookie headers from the main URL and optional login URL + let urls_to_check: Vec<&str> = std::iter::once(url).chain(login_url).collect(); - for check_url in &urls_to_check { - // Use a client that does NOT follow redirects so we catch cookies on redirect responses - let no_redirect_client = reqwest::Client::builder() - 
.danger_accept_invalid_certs(true) - .redirect(reqwest::redirect::Policy::none()) - .timeout(std::time::Duration::from_secs(15)) - .build() - .map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?; + for check_url in &urls_to_check { + // Use a client that does NOT follow redirects so we catch cookies on redirect responses + let no_redirect_client = reqwest::Client::builder() + .danger_accept_invalid_certs(true) + .redirect(reqwest::redirect::Policy::none()) + .timeout(std::time::Duration::from_secs(15)) + .build() + .map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?; - let response = match no_redirect_client.get(*check_url).send().await { - Ok(r) => r, - Err(e) => { - // Try with the main client that follows redirects - match self.http.get(*check_url).send().await { - Ok(r) => r, - Err(_) => continue, + let response = match no_redirect_client.get(*check_url).send().await { + Ok(r) => r, + Err(_e) => { + // Try with the main client that follows redirects + match self.http.get(*check_url).send().await { + Ok(r) => r, + Err(_) => continue, + } } - } - }; + }; - let status = response.status().as_u16(); - let set_cookie_headers: Vec = response - .headers() - .get_all("set-cookie") - .iter() - .filter_map(|v| v.to_str().ok().map(String::from)) - .collect(); + let status = response.status().as_u16(); + let set_cookie_headers: Vec = response + .headers() + .get_all("set-cookie") + .iter() + .filter_map(|v| v.to_str().ok().map(String::from)) + .collect(); - for raw_cookie in &set_cookie_headers { - let cookie = Self::parse_set_cookie(raw_cookie); - let is_sensitive = Self::is_sensitive_cookie(&cookie.name); - let is_https = check_url.starts_with("https://"); + for raw_cookie in &set_cookie_headers { + let cookie = Self::parse_set_cookie(raw_cookie); + let is_sensitive = Self::is_sensitive_cookie(&cookie.name); + let is_https = check_url.starts_with("https://"); - let cookie_info = json!({ - "name": cookie.name, - "secure": cookie.secure, - 
"http_only": cookie.http_only, - "same_site": cookie.same_site, - "domain": cookie.domain, - "path": cookie.path, - "is_sensitive": is_sensitive, - "url": check_url, - }); - cookie_data.push(cookie_info); + let cookie_info = json!({ + "name": cookie.name, + "secure": cookie.secure, + "http_only": cookie.http_only, + "same_site": cookie.same_site, + "domain": cookie.domain, + "path": cookie.path, + "is_sensitive": is_sensitive, + "url": check_url, + }); + cookie_data.push(cookie_info); - // Check: missing Secure flag - if !cookie.secure && (is_https || is_sensitive) { - let severity = if is_sensitive { - Severity::High - } else { - Severity::Medium - }; + // Check: missing Secure flag + if !cookie.secure && (is_https || is_sensitive) { + let severity = if is_sensitive { + Severity::High + } else { + Severity::Medium + }; - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: check_url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(cookie.raw.clone()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: check_url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(cookie.raw.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CookieSecurity, @@ -226,32 +302,32 @@ impl PentestTool for CookieAnalyzerTool { check_url.to_string(), "GET".to_string(), ); - finding.cwe = Some("CWE-614".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Add the 'Secure' attribute to the Set-Cookie header to ensure the \ + finding.cwe = Some("CWE-614".to_string()); + finding.evidence = vec![evidence]; + 
finding.remediation = Some( + "Add the 'Secure' attribute to the Set-Cookie header to ensure the \ cookie is only sent over HTTPS connections." - .to_string(), - ); - findings.push(finding); - } + .to_string(), + ); + findings.push(finding); + } - // Check: missing HttpOnly flag on sensitive cookies - if !cookie.http_only && is_sensitive { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: check_url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(cookie.raw.clone()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + // Check: missing HttpOnly flag on sensitive cookies + if !cookie.http_only && is_sensitive { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: check_url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(cookie.raw.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CookieSecurity, @@ -265,137 +341,137 @@ impl PentestTool for CookieAnalyzerTool { check_url.to_string(), "GET".to_string(), ); - finding.cwe = Some("CWE-1004".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \ + finding.cwe = Some("CWE-1004".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \ JavaScript access to the cookie." 
- .to_string(), - ); - findings.push(finding); - } + .to_string(), + ); + findings.push(finding); + } - // Check: missing or weak SameSite - if is_sensitive { - let weak_same_site = match &cookie.same_site { - None => true, - Some(ss) => ss == "none", - }; - - if weak_same_site { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: check_url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(cookie.raw.clone()), - screenshot_path: None, - payload: None, - response_time_ms: None, + // Check: missing or weak SameSite + if is_sensitive { + let weak_same_site = match &cookie.same_site { + None => true, + Some(ss) => ss == "none", }; - let desc = if cookie.same_site.is_none() { - format!( + if weak_same_site { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: check_url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(cookie.raw.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let desc = if cookie.same_site.is_none() { + format!( "The session/auth cookie '{}' does not have a SameSite attribute. 
\ This may allow cross-site request forgery (CSRF) attacks.", cookie.name ) - } else { - format!( + } else { + format!( "The session/auth cookie '{}' has SameSite=None, which allows it \ to be sent in cross-site requests, enabling CSRF attacks.", cookie.name ) - }; + }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CookieSecurity, - format!("Cookie '{}' missing or weak SameSite", cookie.name), - desc, - Severity::Medium, - check_url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-1275".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \ + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CookieSecurity, + format!("Cookie '{}' missing or weak SameSite", cookie.name), + desc, + Severity::Medium, + check_url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-1275".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \ to prevent cross-site request inclusion." 
- .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } } - } - // Check: overly broad domain - if let Some(ref domain) = cookie.domain { - // A domain starting with a dot applies to all subdomains - let dot_domain = domain.starts_with('.'); - // Count domain parts - if only 2 parts (e.g., .example.com), it's broad - let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect(); - if dot_domain && parts.len() <= 2 && is_sensitive { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: check_url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(cookie.raw.clone()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + // Check: overly broad domain + if let Some(ref domain) = cookie.domain { + // A domain starting with a dot applies to all subdomains + let dot_domain = domain.starts_with('.'); + // Count domain parts - if only 2 parts (e.g., .example.com), it's broad + let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect(); + if dot_domain && parts.len() <= 2 && is_sensitive { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: check_url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(cookie.raw.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CookieSecurity, - format!("Cookie '{}' has overly broad domain", cookie.name), - format!( - "The cookie '{}' is scoped to domain '{}' which includes all \ + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CookieSecurity, + format!("Cookie '{}' has overly broad domain", cookie.name), + format!( + "The cookie '{}' is scoped to domain 
'{}' which includes all \ subdomains. If any subdomain is compromised, the attacker can \ access this cookie.", - cookie.name, domain - ), - Severity::Low, - check_url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-1004".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + cookie.name, domain + ), + Severity::Low, + check_url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-1004".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Restrict the cookie domain to the specific subdomain that needs it \ rather than the entire parent domain." .to_string(), ); - findings.push(finding); + findings.push(finding); + } } } } - } - let count = findings.len(); - info!(url, findings = count, "Cookie analysis complete"); + let count = findings.len(); + info!(url, findings = count, "Cookie analysis complete"); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} cookie security issues.") - } else if cookie_data.is_empty() { - "No cookies were set by the target.".to_string() - } else { - "All cookies have proper security attributes.".to_string() - }, - findings, - data: json!({ - "cookies": cookie_data, - "total_cookies": cookie_data.len(), - }), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} cookie security issues.") + } else if cookie_data.is_empty() { + "No cookies were set by the target.".to_string() + } else { + "All cookies have proper security attributes.".to_string() + }, + findings, + data: json!({ + "cookies": cookie_data, + "total_cookies": cookie_data.len(), + }), + }) }) } } diff --git a/compliance-dast/src/tools/cors_checker.rs b/compliance-dast/src/tools/cors_checker.rs index 736a59f..8dcb5e1 100644 --- a/compliance-dast/src/tools/cors_checker.rs +++ b/compliance-dast/src/tools/cors_checker.rs @@ -22,19 +22,60 @@ impl CorsCheckerTool { vec![ ("null_origin", "null".to_string()), ("evil_domain", 
"https://evil.com".to_string()), - ( - "subdomain_spoof", - format!("https://{target_host}.evil.com"), - ), - ( - "prefix_spoof", - format!("https://evil-{target_host}"), - ), + ("subdomain_spoof", format!("https://{target_host}.evil.com")), + ("prefix_spoof", format!("https://evil-{target_host}")), ("http_downgrade", format!("http://{target_host}")), ] } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_origins_contains_expected_variants() { + let origins = CorsCheckerTool::test_origins("example.com"); + assert_eq!(origins.len(), 5); + + let names: Vec<&str> = origins.iter().map(|(name, _)| *name).collect(); + assert!(names.contains(&"null_origin")); + assert!(names.contains(&"evil_domain")); + assert!(names.contains(&"subdomain_spoof")); + assert!(names.contains(&"prefix_spoof")); + assert!(names.contains(&"http_downgrade")); + } + + #[test] + fn test_origins_uses_target_host() { + let origins = CorsCheckerTool::test_origins("myapp.io"); + + let subdomain = origins + .iter() + .find(|(n, _)| *n == "subdomain_spoof") + .unwrap(); + assert_eq!(subdomain.1, "https://myapp.io.evil.com"); + + let prefix = origins.iter().find(|(n, _)| *n == "prefix_spoof").unwrap(); + assert_eq!(prefix.1, "https://evil-myapp.io"); + + let http_downgrade = origins + .iter() + .find(|(n, _)| *n == "http_downgrade") + .unwrap(); + assert_eq!(http_downgrade.1, "http://myapp.io"); + } + + #[test] + fn test_origins_null_and_evil_are_static() { + let origins = CorsCheckerTool::test_origins("anything.com"); + let null_origin = origins.iter().find(|(n, _)| *n == "null_origin").unwrap(); + assert_eq!(null_origin.1, "null"); + let evil = origins.iter().find(|(n, _)| *n == "evil_domain").unwrap(); + assert_eq!(evil.1, "https://evil.com"); + } +} + impl PentestTool for CorsCheckerTool { fn name(&self) -> &str { "cors_checker" @@ -68,82 +109,82 @@ impl PentestTool for CorsCheckerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send 
+ 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let additional_origins: Vec = input - .get("additional_origins") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }) - .unwrap_or_default(); + let additional_origins: Vec = input + .get("additional_origins") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let target_host = url::Url::parse(url) - .ok() - .and_then(|u| u.host_str().map(String::from)) - .unwrap_or_else(|| url.to_string()); + let target_host = url::Url::parse(url) + .ok() + .and_then(|u| u.host_str().map(String::from)) + .unwrap_or_else(|| url.to_string()); - let mut findings = Vec::new(); - let mut cors_data: Vec = Vec::new(); + let mut findings = Vec::new(); + let mut cors_data: Vec = Vec::new(); - // First, send a request without Origin to get baseline - let baseline = self - .http - .get(url) - .send() - .await - .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; + // First, send a request without Origin to get baseline + let baseline = self + .http + .get(url) + .send() + .await + .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; - let baseline_acao = baseline - .headers() - .get("access-control-allow-origin") - .and_then(|v| v.to_str().ok()) - .map(String::from); + let baseline_acao = baseline + 
.headers() + .get("access-control-allow-origin") + .and_then(|v| v.to_str().ok()) + .map(String::from); - cors_data.push(json!({ - "origin": null, - "acao": baseline_acao, - })); + cors_data.push(json!({ + "origin": null, + "acao": baseline_acao, + })); - // Check for wildcard + credentials (dangerous combo) - if let Some(ref acao) = baseline_acao { - if acao == "*" { - let acac = baseline - .headers() - .get("access-control-allow-credentials") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); + // Check for wildcard + credentials (dangerous combo) + if let Some(ref acao) = baseline_acao { + if acao == "*" { + let acac = baseline + .headers() + .get("access-control-allow-credentials") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); - if acac.to_lowercase() == "true" { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: None, - request_body: None, - response_status: baseline.status().as_u16(), - response_headers: None, - response_snippet: Some(format!( - "Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true" - )), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + if acac.to_lowercase() == "true" { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: None, + request_body: None, + response_status: baseline.status().as_u16(), + response_headers: None, + response_snippet: Some("Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true".to_string()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CorsMisconfiguration, @@ -157,254 +198,251 @@ impl PentestTool for CorsCheckerTool { url.to_string(), "GET".to_string(), ); - finding.cwe = Some("CWE-942".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - 
"Never combine Access-Control-Allow-Origin: * with \ + finding.cwe = Some("CWE-942".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Never combine Access-Control-Allow-Origin: * with \ Access-Control-Allow-Credentials: true. Specify explicit allowed origins." - .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } } } - } - // Test with various Origin headers - let mut test_origins = Self::test_origins(&target_host); - for origin in &additional_origins { - test_origins.push(("custom", origin.clone())); - } + // Test with various Origin headers + let mut test_origins = Self::test_origins(&target_host); + for origin in &additional_origins { + test_origins.push(("custom", origin.clone())); + } - for (test_name, origin) in &test_origins { - let resp = match self + for (test_name, origin) in &test_origins { + let resp = match self + .http + .get(url) + .header("Origin", origin.as_str()) + .send() + .await + { + Ok(r) => r, + Err(_) => continue, + }; + + let status = resp.status().as_u16(); + let acao = resp + .headers() + .get("access-control-allow-origin") + .and_then(|v| v.to_str().ok()) + .map(String::from); + + let acac = resp + .headers() + .get("access-control-allow-credentials") + .and_then(|v| v.to_str().ok()) + .map(String::from); + + let acam = resp + .headers() + .get("access-control-allow-methods") + .and_then(|v| v.to_str().ok()) + .map(String::from); + + cors_data.push(json!({ + "test": test_name, + "origin": origin, + "acao": acao, + "acac": acac, + "acam": acam, + "status": status, + })); + + // Check if the origin was reflected back + if let Some(ref acao_val) = acao { + let origin_reflected = acao_val == origin; + let credentials_allowed = acac + .as_ref() + .map(|v| v.to_lowercase() == "true") + .unwrap_or(false); + + if origin_reflected && *test_name != "http_downgrade" { + let severity = if credentials_allowed { + Severity::Critical + } else { + Severity::High + }; + + let 
resp_headers: HashMap = resp + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: Some( + [("Origin".to_string(), origin.clone())] + .into_iter() + .collect(), + ), + request_body: None, + response_status: status, + response_headers: Some(resp_headers), + response_snippet: Some(format!( + "Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\ + Access-Control-Allow-Credentials: {}", + acac.as_deref().unwrap_or("not set") + )), + screenshot_path: None, + payload: Some(origin.clone()), + response_time_ms: None, + }; + + let title = match *test_name { + "null_origin" => "CORS accepts null origin".to_string(), + "evil_domain" => "CORS reflects arbitrary origin".to_string(), + "subdomain_spoof" => { + "CORS vulnerable to subdomain spoofing".to_string() + } + "prefix_spoof" => "CORS vulnerable to prefix spoofing".to_string(), + _ => format!("CORS reflects untrusted origin ({test_name})"), + }; + + let cred_note = if credentials_allowed { + " Combined with Access-Control-Allow-Credentials: true, this allows \ + the attacker to steal authenticated data." + } else { + "" + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CorsMisconfiguration, + title, + format!( + "The endpoint {url} reflects the Origin header '{origin}' back in \ + Access-Control-Allow-Origin, allowing cross-origin requests from \ + untrusted domains.{cred_note}" + ), + severity, + url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-942".to_string()); + finding.exploitable = credentials_allowed; + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Validate the Origin header against a whitelist of trusted origins. \ + Do not reflect the Origin header value directly. Use specific allowed \ + origins instead of wildcards." 
+ .to_string(), + ); + findings.push(finding); + + warn!( + url, + test_name, + origin, + credentials = credentials_allowed, + "CORS misconfiguration detected" + ); + } + + // Special case: HTTP downgrade + if *test_name == "http_downgrade" && origin_reflected && credentials_allowed { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: Some( + [("Origin".to_string(), origin.clone())] + .into_iter() + .collect(), + ), + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(format!( + "HTTP origin accepted: {origin} -> ACAO: {acao_val}" + )), + screenshot_path: None, + payload: Some(origin.clone()), + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CorsMisconfiguration, + "CORS allows HTTP origin with credentials".to_string(), + format!( + "The HTTPS endpoint {url} accepts the HTTP origin {origin} with \ + credentials. An attacker performing a man-in-the-middle attack on \ + the HTTP version could steal authenticated data." + ), + Severity::High, + url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-942".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \ + origin validation enforces the https:// scheme." 
+ .to_string(), + ); + findings.push(finding); + } + } + } + + // Also send a preflight OPTIONS request + if let Ok(resp) = self .http - .get(url) - .header("Origin", origin.as_str()) + .request(reqwest::Method::OPTIONS, url) + .header("Origin", "https://evil.com") + .header("Access-Control-Request-Method", "POST") + .header( + "Access-Control-Request-Headers", + "Authorization, Content-Type", + ) .send() .await { - Ok(r) => r, - Err(_) => continue, - }; + let acam = resp + .headers() + .get("access-control-allow-methods") + .and_then(|v| v.to_str().ok()) + .map(String::from); - let status = resp.status().as_u16(); - let acao = resp - .headers() - .get("access-control-allow-origin") - .and_then(|v| v.to_str().ok()) - .map(String::from); + let acah = resp + .headers() + .get("access-control-allow-headers") + .and_then(|v| v.to_str().ok()) + .map(String::from); - let acac = resp - .headers() - .get("access-control-allow-credentials") - .and_then(|v| v.to_str().ok()) - .map(String::from); - - let acam = resp - .headers() - .get("access-control-allow-methods") - .and_then(|v| v.to_str().ok()) - .map(String::from); - - cors_data.push(json!({ - "test": test_name, - "origin": origin, - "acao": acao, - "acac": acac, - "acam": acam, - "status": status, - })); - - // Check if the origin was reflected back - if let Some(ref acao_val) = acao { - let origin_reflected = acao_val == origin; - let credentials_allowed = acac - .as_ref() - .map(|v| v.to_lowercase() == "true") - .unwrap_or(false); - - if origin_reflected && *test_name != "http_downgrade" { - let severity = if credentials_allowed { - Severity::Critical - } else { - Severity::High - }; - - let resp_headers: HashMap = resp - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); - - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: Some( - [("Origin".to_string(), origin.clone())] - .into_iter() - 
.collect(), - ), - request_body: None, - response_status: status, - response_headers: Some(resp_headers), - response_snippet: Some(format!( - "Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\ - Access-Control-Allow-Credentials: {}", - acac.as_deref().unwrap_or("not set") - )), - screenshot_path: None, - payload: Some(origin.clone()), - response_time_ms: None, - }; - - let title = match *test_name { - "null_origin" => { - "CORS accepts null origin".to_string() - } - "evil_domain" => { - "CORS reflects arbitrary origin".to_string() - } - "subdomain_spoof" => { - "CORS vulnerable to subdomain spoofing".to_string() - } - "prefix_spoof" => { - "CORS vulnerable to prefix spoofing".to_string() - } - _ => format!("CORS reflects untrusted origin ({test_name})"), - }; - - let cred_note = if credentials_allowed { - " Combined with Access-Control-Allow-Credentials: true, this allows \ - the attacker to steal authenticated data." - } else { - "" - }; - - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CorsMisconfiguration, - title, - format!( - "The endpoint {url} reflects the Origin header '{origin}' back in \ - Access-Control-Allow-Origin, allowing cross-origin requests from \ - untrusted domains.{cred_note}" - ), - severity, - url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-942".to_string()); - finding.exploitable = credentials_allowed; - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Validate the Origin header against a whitelist of trusted origins. \ - Do not reflect the Origin header value directly. Use specific allowed \ - origins instead of wildcards." 
- .to_string(), - ); - findings.push(finding); - - warn!( - url, - test_name, - origin, - credentials = credentials_allowed, - "CORS misconfiguration detected" - ); - } - - // Special case: HTTP downgrade - if *test_name == "http_downgrade" && origin_reflected && credentials_allowed { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: Some( - [("Origin".to_string(), origin.clone())] - .into_iter() - .collect(), - ), - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(format!( - "HTTP origin accepted: {origin} -> ACAO: {acao_val}" - )), - screenshot_path: None, - payload: Some(origin.clone()), - response_time_ms: None, - }; - - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CorsMisconfiguration, - "CORS allows HTTP origin with credentials".to_string(), - format!( - "The HTTPS endpoint {url} accepts the HTTP origin {origin} with \ - credentials. An attacker performing a man-in-the-middle attack on \ - the HTTP version could steal authenticated data." - ), - Severity::High, - url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-942".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \ - origin validation enforces the https:// scheme." 
- .to_string(), - ); - findings.push(finding); - } + cors_data.push(json!({ + "test": "preflight", + "status": resp.status().as_u16(), + "allow_methods": acam, + "allow_headers": acah, + })); } - } - // Also send a preflight OPTIONS request - if let Ok(resp) = self - .http - .request(reqwest::Method::OPTIONS, url) - .header("Origin", "https://evil.com") - .header("Access-Control-Request-Method", "POST") - .header("Access-Control-Request-Headers", "Authorization, Content-Type") - .send() - .await - { - let acam = resp - .headers() - .get("access-control-allow-methods") - .and_then(|v| v.to_str().ok()) - .map(String::from); + let count = findings.len(); + info!(url, findings = count, "CORS check complete"); - let acah = resp - .headers() - .get("access-control-allow-headers") - .and_then(|v| v.to_str().ok()) - .map(String::from); - - cors_data.push(json!({ - "test": "preflight", - "status": resp.status().as_u16(), - "allow_methods": acam, - "allow_headers": acah, - })); - } - - let count = findings.len(); - info!(url, findings = count, "CORS check complete"); - - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} CORS misconfiguration issues for {url}.") - } else { - format!("CORS configuration appears secure for {url}.") - }, - findings, - data: json!({ - "tests": cors_data, - }), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} CORS misconfiguration issues for {url}.") + } else { + format!("CORS configuration appears secure for {url}.") + }, + findings, + data: json!({ + "tests": cors_data, + }), + }) }) } } diff --git a/compliance-dast/src/tools/csp_analyzer.rs b/compliance-dast/src/tools/csp_analyzer.rs index 75c6a92..c950067 100644 --- a/compliance-dast/src/tools/csp_analyzer.rs +++ b/compliance-dast/src/tools/csp_analyzer.rs @@ -47,7 +47,7 @@ impl CspAnalyzerTool { url: &str, target_id: &str, status: u16, - csp_raw: &str, + _csp_raw: &str, ) -> Vec { let mut findings = Vec::new(); @@ -216,12 +216,18 @@ impl 
CspAnalyzerTool { ("object-src", "Controls plugins like Flash"), ("base-uri", "Controls the base URL for relative URLs"), ("form-action", "Controls where forms can submit to"), - ("frame-ancestors", "Controls who can embed this page in iframes"), + ( + "frame-ancestors", + "Controls who can embed this page in iframes", + ), ]; for (dir_name, desc) in &important_directives { if !directive_names.contains(dir_name) - && !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action") + && (!has_default_src + || *dir_name == "frame-ancestors" + || *dir_name == "base-uri" + || *dir_name == "form-action") { let evidence = make_evidence(format!("CSP missing directive: {dir_name}")); let mut finding = DastFinding::new( @@ -258,6 +264,125 @@ impl CspAnalyzerTool { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_csp_basic() { + let directives = CspAnalyzerTool::parse_csp( + "default-src 'self'; script-src 'self' https://cdn.example.com", + ); + assert_eq!(directives.len(), 2); + assert_eq!(directives[0].name, "default-src"); + assert_eq!(directives[0].values, vec!["'self'"]); + assert_eq!(directives[1].name, "script-src"); + assert_eq!( + directives[1].values, + vec!["'self'", "https://cdn.example.com"] + ); + } + + #[test] + fn parse_csp_empty() { + let directives = CspAnalyzerTool::parse_csp(""); + assert!(directives.is_empty()); + } + + #[test] + fn parse_csp_trailing_semicolons() { + let directives = CspAnalyzerTool::parse_csp("default-src 'none';;;"); + assert_eq!(directives.len(), 1); + assert_eq!(directives[0].name, "default-src"); + assert_eq!(directives[0].values, vec!["'none'"]); + } + + #[test] + fn parse_csp_directive_without_value() { + let directives = CspAnalyzerTool::parse_csp("upgrade-insecure-requests"); + assert_eq!(directives.len(), 1); + assert_eq!(directives[0].name, "upgrade-insecure-requests"); + assert!(directives[0].values.is_empty()); + } + + #[test] + fn 
parse_csp_names_are_lowercased() { + let directives = CspAnalyzerTool::parse_csp("Script-Src 'self'"); + assert_eq!(directives[0].name, "script-src"); + } + + #[test] + fn analyze_detects_unsafe_inline() { + let directives = CspAnalyzerTool::parse_csp("script-src 'self' 'unsafe-inline'"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + assert!(findings.iter().any(|f| f.title.contains("unsafe-inline"))); + } + + #[test] + fn analyze_detects_unsafe_eval() { + let directives = CspAnalyzerTool::parse_csp("script-src 'self' 'unsafe-eval'"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + assert!(findings.iter().any(|f| f.title.contains("unsafe-eval"))); + } + + #[test] + fn analyze_detects_wildcard() { + let directives = CspAnalyzerTool::parse_csp("img-src *"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + assert!(findings.iter().any(|f| f.title.contains("wildcard"))); + } + + #[test] + fn analyze_detects_data_uri_in_script_src() { + let directives = CspAnalyzerTool::parse_csp("script-src 'self' data:"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + assert!(findings.iter().any(|f| f.title.contains("data:"))); + } + + #[test] + fn analyze_detects_http_sources() { + let directives = CspAnalyzerTool::parse_csp("script-src http:"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + assert!(findings.iter().any(|f| f.title.contains("HTTP sources"))); + } + + #[test] + fn analyze_detects_missing_directives_without_default_src() { + let directives = CspAnalyzerTool::parse_csp("img-src 'self'"); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + let missing_names: Vec<&str> = findings + .iter() + .filter(|f| 
f.title.contains("missing")) + .map(|f| f.title.as_str()) + .collect(); + // Should flag script-src, object-src, base-uri, form-action, frame-ancestors + assert!(missing_names.len() >= 4); + } + + #[test] + fn analyze_good_csp_no_unsafe_findings() { + let directives = CspAnalyzerTool::parse_csp( + "default-src 'none'; script-src 'self'; style-src 'self'; \ + img-src 'self'; object-src 'none'; base-uri 'self'; \ + form-action 'self'; frame-ancestors 'none'", + ); + let findings = + CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, ""); + // A well-configured CSP should not produce unsafe-inline/eval/wildcard findings + assert!(findings.iter().all(|f| { + !f.title.contains("unsafe-inline") + && !f.title.contains("unsafe-eval") + && !f.title.contains("wildcard") + })); + } +} + impl PentestTool for CspAnalyzerTool { fn name(&self) -> &str { "csp_analyzer" @@ -285,163 +410,167 @@ impl PentestTool for CspAnalyzerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let response = self - .http - .get(url) - .send() - .await - .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; + let response = self + .http + .get(url) + .send() + .await + .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; - let status = response.status().as_u16(); + let status = 
response.status().as_u16(); - // Check for CSP header - let csp_header = response - .headers() - .get("content-security-policy") - .and_then(|v| v.to_str().ok()) - .map(String::from); + // Check for CSP header + let csp_header = response + .headers() + .get("content-security-policy") + .and_then(|v| v.to_str().ok()) + .map(String::from); - // Also check for report-only variant - let csp_report_only = response - .headers() - .get("content-security-policy-report-only") - .and_then(|v| v.to_str().ok()) - .map(String::from); + // Also check for report-only variant + let csp_report_only = response + .headers() + .get("content-security-policy-report-only") + .and_then(|v| v.to_str().ok()) + .map(String::from); - let mut findings = Vec::new(); - let mut csp_data = json!({}); + let mut findings = Vec::new(); + let mut csp_data = json!({}); - match &csp_header { - Some(csp) => { - let directives = Self::parse_csp(csp); - let directive_map: serde_json::Value = directives - .iter() - .map(|d| (d.name.clone(), json!(d.values))) - .collect::>() - .into(); + match &csp_header { + Some(csp) => { + let directives = Self::parse_csp(csp); + let directive_map: serde_json::Value = directives + .iter() + .map(|d| (d.name.clone(), json!(d.values))) + .collect::>() + .into(); - csp_data["csp_header"] = json!(csp); - csp_data["directives"] = directive_map; + csp_data["csp_header"] = json!(csp); + csp_data["directives"] = directive_map; - findings.extend(Self::analyze_directives( - &directives, - url, - &target_id, - status, - csp, - )); - } - None => { - csp_data["csp_header"] = json!(null); + findings.extend(Self::analyze_directives( + &directives, + url, + &target_id, + status, + csp, + )); + } + None => { + csp_data["csp_header"] = json!(null); - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: 
Some("Content-Security-Policy header is missing".to_string()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some( + "Content-Security-Policy header is missing".to_string(), + ), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CspIssue, - "Missing Content-Security-Policy header".to_string(), - format!( - "No Content-Security-Policy header is present on {url}. \ + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CspIssue, + "Missing Content-Security-Policy header".to_string(), + format!( + "No Content-Security-Policy header is present on {url}. \ Without CSP, the browser has no instructions on which sources are \ trusted, making XSS exploitation much easier." - ), - Severity::Medium, - url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-16".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + ), + Severity::Medium, + url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-16".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Add a Content-Security-Policy header. Start with a restrictive policy like \ \"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \ object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"." 
.to_string(), ); - findings.push(finding); + findings.push(finding); + } } - } - if let Some(ref report_only) = csp_report_only { - csp_data["csp_report_only"] = json!(report_only); - // If ONLY report-only exists (no enforcing CSP), warn - if csp_header.is_none() { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: None, - response_snippet: Some(format!( - "Content-Security-Policy-Report-Only: {}", - report_only - )), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + if let Some(ref report_only) = csp_report_only { + csp_data["csp_report_only"] = json!(report_only); + // If ONLY report-only exists (no enforcing CSP), warn + if csp_header.is_none() { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: None, + response_snippet: Some(format!( + "Content-Security-Policy-Report-Only: {}", + report_only + )), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::CspIssue, - "CSP is report-only, not enforcing".to_string(), - "A Content-Security-Policy-Report-Only header is present but no enforcing \ + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::CspIssue, + "CSP is report-only, not enforcing".to_string(), + "A Content-Security-Policy-Report-Only header is present but no enforcing \ Content-Security-Policy header exists. Report-only mode only logs violations \ but does not block them." 
- .to_string(), - Severity::Low, - url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-16".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + .to_string(), + Severity::Low, + url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-16".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Once you have verified the CSP policy works correctly in report-only mode, \ deploy it as an enforcing Content-Security-Policy header." .to_string(), ); - findings.push(finding); + findings.push(finding); + } } - } - let count = findings.len(); - info!(url, findings = count, "CSP analysis complete"); + let count = findings.len(); + info!(url, findings = count, "CSP analysis complete"); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} CSP issues for {url}.") - } else { - format!("Content-Security-Policy looks good for {url}.") - }, - findings, - data: csp_data, - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} CSP issues for {url}.") + } else { + format!("Content-Security-Policy looks good for {url}.") + }, + findings, + data: csp_data, + }) }) } } diff --git a/compliance-dast/src/tools/dmarc_checker.rs b/compliance-dast/src/tools/dmarc_checker.rs index 39564a5..b312df8 100644 --- a/compliance-dast/src/tools/dmarc_checker.rs +++ b/compliance-dast/src/tools/dmarc_checker.rs @@ -8,6 +8,12 @@ use tracing::{info, warn}; /// Tool that checks email security configuration (DMARC and SPF records). 
pub struct DmarcCheckerTool; +impl Default for DmarcCheckerTool { + fn default() -> Self { + Self::new() + } +} + impl DmarcCheckerTool { pub fn new() -> Self { Self @@ -78,6 +84,105 @@ impl DmarcCheckerTool { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_dmarc_policy_reject() { + let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com"; + assert_eq!( + DmarcCheckerTool::parse_dmarc_policy(record), + Some("reject".to_string()) + ); + } + + #[test] + fn parse_dmarc_policy_none() { + let record = "v=DMARC1; p=none"; + assert_eq!( + DmarcCheckerTool::parse_dmarc_policy(record), + Some("none".to_string()) + ); + } + + #[test] + fn parse_dmarc_policy_quarantine() { + let record = "v=DMARC1; p=quarantine; sp=none"; + assert_eq!( + DmarcCheckerTool::parse_dmarc_policy(record), + Some("quarantine".to_string()) + ); + } + + #[test] + fn parse_dmarc_policy_missing() { + let record = "v=DMARC1; rua=mailto:test@example.com"; + assert_eq!(DmarcCheckerTool::parse_dmarc_policy(record), None); + } + + #[test] + fn parse_dmarc_subdomain_policy() { + let record = "v=DMARC1; p=reject; sp=quarantine"; + assert_eq!( + DmarcCheckerTool::parse_dmarc_subdomain_policy(record), + Some("quarantine".to_string()) + ); + } + + #[test] + fn parse_dmarc_subdomain_policy_missing() { + let record = "v=DMARC1; p=reject"; + assert_eq!(DmarcCheckerTool::parse_dmarc_subdomain_policy(record), None); + } + + #[test] + fn parse_dmarc_rua_present() { + let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com"; + assert_eq!( + DmarcCheckerTool::parse_dmarc_rua(record), + Some("mailto:dmarc@example.com".to_string()) + ); + } + + #[test] + fn parse_dmarc_rua_missing() { + let record = "v=DMARC1; p=none"; + assert_eq!(DmarcCheckerTool::parse_dmarc_rua(record), None); + } + + #[test] + fn is_spf_record_valid() { + assert!(DmarcCheckerTool::is_spf_record( + "v=spf1 include:_spf.google.com -all" + )); + assert!(DmarcCheckerTool::is_spf_record("v=spf1 -all")); + } + + #[test] + 
fn is_spf_record_invalid() { + assert!(!DmarcCheckerTool::is_spf_record("v=DMARC1; p=reject")); + assert!(!DmarcCheckerTool::is_spf_record("some random txt record")); + } + + #[test] + fn spf_soft_fail_detection() { + assert!(DmarcCheckerTool::spf_uses_soft_fail( + "v=spf1 include:_spf.google.com ~all" + )); + assert!(!DmarcCheckerTool::spf_uses_soft_fail( + "v=spf1 include:_spf.google.com -all" + )); + } + + #[test] + fn spf_allows_all_detection() { + assert!(DmarcCheckerTool::spf_allows_all("v=spf1 +all")); + assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 -all")); + assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 ~all")); + } +} + impl PentestTool for DmarcCheckerTool { fn name(&self) -> &str { "dmarc_checker" @@ -105,43 +210,89 @@ impl PentestTool for DmarcCheckerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let domain = input - .get("domain") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?; + let domain = input + .get("domain") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + CoreError::Dast("Missing required 'domain' parameter".to_string()) + })?; - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let mut findings = Vec::new(); - let mut email_data = json!({}); + let mut findings = Vec::new(); + let mut email_data = json!({}); - // ---- DMARC check ---- - let dmarc_domain = format!("_dmarc.{domain}"); - let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default(); + // ---- DMARC check ---- + let dmarc_domain = format!("_dmarc.{domain}"); + let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default(); - let dmarc_record = 
dmarc_records.iter().find(|r| r.starts_with("v=DMARC1")); + let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1")); - match dmarc_record { - Some(record) => { - email_data["dmarc_record"] = json!(record); + match dmarc_record { + Some(record) => { + email_data["dmarc_record"] = json!(record); - let policy = Self::parse_dmarc_policy(record); - let sp = Self::parse_dmarc_subdomain_policy(record); - let rua = Self::parse_dmarc_rua(record); + let policy = Self::parse_dmarc_policy(record); + let sp = Self::parse_dmarc_subdomain_policy(record); + let rua = Self::parse_dmarc_rua(record); - email_data["dmarc_policy"] = json!(policy); - email_data["dmarc_subdomain_policy"] = json!(sp); - email_data["dmarc_rua"] = json!(rua); + email_data["dmarc_policy"] = json!(policy); + email_data["dmarc_subdomain_policy"] = json!(sp); + email_data["dmarc_rua"] = json!(rua); - // Warn on weak policy - if let Some(ref p) = policy { - if p == "none" { + // Warn on weak policy + if let Some(ref p) = policy { + if p == "none" { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: dmarc_domain.clone(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some(record.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::EmailSecurity, + format!("Weak DMARC policy for {domain}"), + format!( + "The DMARC policy for {domain} is set to 'none', which only monitors \ + but does not enforce email authentication. Attackers can spoof emails \ + from this domain." 
+ ), + Severity::Medium, + domain.to_string(), + "DNS".to_string(), + ); + finding.cwe = Some("CWE-290".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \ + 'p=reject' after verifying legitimate email flows are properly \ + authenticated." + .to_string(), + ); + findings.push(finding); + warn!(domain, "DMARC policy is 'none'"); + } + } + + // Warn if no reporting URI + if rua.is_none() { let evidence = DastEvidence { request_method: "DNS".to_string(), request_url: dmarc_domain.clone(), @@ -156,48 +307,6 @@ impl PentestTool for DmarcCheckerTool { }; let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::EmailSecurity, - format!("Weak DMARC policy for {domain}"), - format!( - "The DMARC policy for {domain} is set to 'none', which only monitors \ - but does not enforce email authentication. Attackers can spoof emails \ - from this domain." - ), - Severity::Medium, - domain.to_string(), - "DNS".to_string(), - ); - finding.cwe = Some("CWE-290".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \ - 'p=reject' after verifying legitimate email flows are properly \ - authenticated." 
- .to_string(), - ); - findings.push(finding); - warn!(domain, "DMARC policy is 'none'"); - } - } - - // Warn if no reporting URI - if rua.is_none() { - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: dmarc_domain.clone(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some(record.clone()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::EmailSecurity, @@ -211,80 +320,80 @@ impl PentestTool for DmarcCheckerTool { domain.to_string(), "DNS".to_string(), ); - finding.cwe = Some("CWE-778".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Add a 'rua=' tag to your DMARC record to receive aggregate reports. \ + finding.cwe = Some("CWE-778".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Add a 'rua=' tag to your DMARC record to receive aggregate reports. \ Example: 'rua=mailto:dmarc-reports@example.com'." - .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } } - } - None => { - email_data["dmarc_record"] = json!(null); + None => { + email_data["dmarc_record"] = json!(null); - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: dmarc_domain.clone(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some("No DMARC record found".to_string()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::EmailSecurity, - format!("Missing DMARC record for {domain}"), - format!( - "No DMARC record was found for {domain}. Without DMARC, there is no \ - policy to prevent email spoofing and phishing attacks using this domain." 
- ), - Severity::High, - domain.to_string(), - "DNS".to_string(), - ); - finding.cwe = Some("CWE-290".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Create a DMARC TXT record at _dmarc.. Start with 'v=DMARC1; p=none; \ - rua=mailto:dmarc@example.com' and gradually move to 'p=reject'." - .to_string(), - ); - findings.push(finding); - warn!(domain, "No DMARC record found"); - } - } - - // ---- SPF check ---- - let txt_records = Self::query_txt(domain).await.unwrap_or_default(); - let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r)); - - match spf_record { - Some(record) => { - email_data["spf_record"] = json!(record); - - if Self::spf_allows_all(record) { let evidence = DastEvidence { request_method: "DNS".to_string(), - request_url: domain.to_string(), + request_url: dmarc_domain.clone(), request_headers: None, request_body: None, response_status: 0, response_headers: None, - response_snippet: Some(record.clone()), + response_snippet: Some("No DMARC record found".to_string()), screenshot_path: None, payload: None, response_time_ms: None, }; let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::EmailSecurity, + format!("Missing DMARC record for {domain}"), + format!( + "No DMARC record was found for {domain}. Without DMARC, there is no \ + policy to prevent email spoofing and phishing attacks using this domain." + ), + Severity::High, + domain.to_string(), + "DNS".to_string(), + ); + finding.cwe = Some("CWE-290".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Create a DMARC TXT record at _dmarc.. Start with 'v=DMARC1; p=none; \ + rua=mailto:dmarc@example.com' and gradually move to 'p=reject'." 
+ .to_string(), + ); + findings.push(finding); + warn!(domain, "No DMARC record found"); + } + } + + // ---- SPF check ---- + let txt_records = Self::query_txt(domain).await.unwrap_or_default(); + let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r)); + + match spf_record { + Some(record) => { + email_data["spf_record"] = json!(record); + + if Self::spf_allows_all(record) { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: domain.to_string(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some(record.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::EmailSecurity, @@ -297,15 +406,55 @@ impl PentestTool for DmarcCheckerTool { domain.to_string(), "DNS".to_string(), ); - finding.cwe = Some("CWE-290".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + finding.cwe = Some("CWE-290".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \ Only list authorized mail servers." .to_string(), ); - findings.push(finding); - } else if Self::spf_uses_soft_fail(record) { + findings.push(finding); + } else if Self::spf_uses_soft_fail(record) { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: domain.to_string(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some(record.clone()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::EmailSecurity, + format!("SPF soft fail for {domain}"), + format!( + "The SPF record for {domain} uses '~all' (soft fail) instead of \ + '-all' (hard fail). 
Soft fail marks unauthorized emails as suspicious \ + but does not reject them." + ), + Severity::Low, + domain.to_string(), + "DNS".to_string(), + ); + finding.cwe = Some("CWE-290".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Consider changing '~all' to '-all' in your SPF record once you have \ + confirmed all legitimate mail sources are listed." + .to_string(), + ); + findings.push(finding); + } + } + None => { + email_data["spf_record"] = json!(null); + let evidence = DastEvidence { request_method: "DNS".to_string(), request_url: domain.to_string(), @@ -313,7 +462,7 @@ impl PentestTool for DmarcCheckerTool { request_body: None, response_status: 0, response_headers: None, - response_snippet: Some(record.clone()), + response_snippet: Some("No SPF record found".to_string()), screenshot_path: None, payload: None, response_time_ms: None, @@ -323,79 +472,39 @@ impl PentestTool for DmarcCheckerTool { String::new(), target_id.clone(), DastVulnType::EmailSecurity, - format!("SPF soft fail for {domain}"), + format!("Missing SPF record for {domain}"), format!( - "The SPF record for {domain} uses '~all' (soft fail) instead of \ - '-all' (hard fail). Soft fail marks unauthorized emails as suspicious \ - but does not reject them." - ), - Severity::Low, + "No SPF record was found for {domain}. Without SPF, any server can claim \ + to send email on behalf of this domain." + ), + Severity::High, domain.to_string(), "DNS".to_string(), ); finding.cwe = Some("CWE-290".to_string()); finding.evidence = vec![evidence]; finding.remediation = Some( - "Consider changing '~all' to '-all' in your SPF record once you have \ - confirmed all legitimate mail sources are listed." + "Create an SPF TXT record for your domain. Example: \ + 'v=spf1 include:_spf.google.com -all'." 
.to_string(), ); findings.push(finding); + warn!(domain, "No SPF record found"); } } - None => { - email_data["spf_record"] = json!(null); - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: domain.to_string(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some("No SPF record found".to_string()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + let count = findings.len(); + info!(domain, findings = count, "DMARC/SPF check complete"); - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::EmailSecurity, - format!("Missing SPF record for {domain}"), - format!( - "No SPF record was found for {domain}. Without SPF, any server can claim \ - to send email on behalf of this domain." - ), - Severity::High, - domain.to_string(), - "DNS".to_string(), - ); - finding.cwe = Some("CWE-290".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Create an SPF TXT record for your domain. Example: \ - 'v=spf1 include:_spf.google.com -all'." 
- .to_string(), - ); - findings.push(finding); - warn!(domain, "No SPF record found"); - } - } - - let count = findings.len(); - info!(domain, findings = count, "DMARC/SPF check complete"); - - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} email security issues for {domain}.") - } else { - format!("Email security configuration looks good for {domain}.") - }, - findings, - data: email_data, - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} email security issues for {domain}.") + } else { + format!("Email security configuration looks good for {domain}.") + }, + findings, + data: email_data, + }) }) } } diff --git a/compliance-dast/src/tools/dns_checker.rs b/compliance-dast/src/tools/dns_checker.rs index 07d9c21..b9dd90d 100644 --- a/compliance-dast/src/tools/dns_checker.rs +++ b/compliance-dast/src/tools/dns_checker.rs @@ -16,6 +16,12 @@ use tracing::{info, warn}; /// `tokio::process::Command` wrapper around `dig` where available. 
pub struct DnsCheckerTool; +impl Default for DnsCheckerTool { + fn default() -> Self { + Self::new() + } +} + impl DnsCheckerTool { pub fn new() -> Self { Self @@ -54,7 +60,9 @@ impl DnsCheckerTool { } } Err(e) => { - return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}"))); + return Err(CoreError::Dast(format!( + "DNS resolution failed for {domain}: {e}" + ))); } } @@ -94,107 +102,111 @@ impl PentestTool for DnsCheckerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let domain = input - .get("domain") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?; + let domain = input + .get("domain") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + CoreError::Dast("Missing required 'domain' parameter".to_string()) + })?; - let subdomains: Vec = input - .get("subdomains") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }) - .unwrap_or_default(); + let subdomains: Vec = input + .get("subdomains") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); - let mut findings = Vec::new(); - let mut dns_data: HashMap = HashMap::new(); + let mut findings = Vec::new(); + let mut dns_data: HashMap = HashMap::new(); - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - // --- A / AAAA records --- - match Self::resolve_addresses(domain).await { - Ok((ipv4, ipv6)) => { - dns_data.insert("a_records".to_string(), json!(ipv4)); - dns_data.insert("aaaa_records".to_string(), json!(ipv6)); + // --- A / AAAA records --- + match 
Self::resolve_addresses(domain).await { + Ok((ipv4, ipv6)) => { + dns_data.insert("a_records".to_string(), json!(ipv4)); + dns_data.insert("aaaa_records".to_string(), json!(ipv6)); + } + Err(e) => { + dns_data.insert("a_records_error".to_string(), json!(e.to_string())); + } } - Err(e) => { - dns_data.insert("a_records_error".to_string(), json!(e.to_string())); - } - } - // --- MX records --- - match Self::dig_query(domain, "MX").await { - Ok(mx) => { - dns_data.insert("mx_records".to_string(), json!(mx)); + // --- MX records --- + match Self::dig_query(domain, "MX").await { + Ok(mx) => { + dns_data.insert("mx_records".to_string(), json!(mx)); + } + Err(e) => { + dns_data.insert("mx_records_error".to_string(), json!(e.to_string())); + } } - Err(e) => { - dns_data.insert("mx_records_error".to_string(), json!(e.to_string())); - } - } - // --- NS records --- - let ns_records = match Self::dig_query(domain, "NS").await { - Ok(ns) => { - dns_data.insert("ns_records".to_string(), json!(ns)); - ns - } - Err(e) => { - dns_data.insert("ns_records_error".to_string(), json!(e.to_string())); - Vec::new() - } - }; + // --- NS records --- + let ns_records = match Self::dig_query(domain, "NS").await { + Ok(ns) => { + dns_data.insert("ns_records".to_string(), json!(ns)); + ns + } + Err(e) => { + dns_data.insert("ns_records_error".to_string(), json!(e.to_string())); + Vec::new() + } + }; - // --- TXT records --- - match Self::dig_query(domain, "TXT").await { - Ok(txt) => { - dns_data.insert("txt_records".to_string(), json!(txt)); + // --- TXT records --- + match Self::dig_query(domain, "TXT").await { + Ok(txt) => { + dns_data.insert("txt_records".to_string(), json!(txt)); + } + Err(e) => { + dns_data.insert("txt_records_error".to_string(), json!(e.to_string())); + } } - Err(e) => { - dns_data.insert("txt_records_error".to_string(), json!(e.to_string())); + + // --- CNAME records (for subdomains) --- + let mut cname_data: HashMap> = HashMap::new(); + let mut domains_to_check = 
vec![domain.to_string()]; + for sub in &subdomains { + domains_to_check.push(format!("{sub}.{domain}")); } - } - // --- CNAME records (for subdomains) --- - let mut cname_data: HashMap> = HashMap::new(); - let mut domains_to_check = vec![domain.to_string()]; - for sub in &subdomains { - domains_to_check.push(format!("{sub}.{domain}")); - } + for fqdn in &domains_to_check { + match Self::dig_query(fqdn, "CNAME").await { + Ok(cnames) if !cnames.is_empty() => { + // Check for dangling CNAME + for cname in &cnames { + let cname_clean = cname.trim_end_matches('.'); + let check_addr = format!("{cname_clean}:443"); + let is_dangling = lookup_host(&check_addr).await.is_err(); + if is_dangling { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: fqdn.clone(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some(format!( + "CNAME {fqdn} -> {cname} (target does not resolve)" + )), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - for fqdn in &domains_to_check { - match Self::dig_query(fqdn, "CNAME").await { - Ok(cnames) if !cnames.is_empty() => { - // Check for dangling CNAME - for cname in &cnames { - let cname_clean = cname.trim_end_matches('.'); - let check_addr = format!("{cname_clean}:443"); - let is_dangling = lookup_host(&check_addr).await.is_err(); - if is_dangling { - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: fqdn.clone(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some(format!( - "CNAME {fqdn} -> {cname} (target does not resolve)" - )), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::DnsMisconfiguration, @@ -207,44 +219,47 @@ impl PentestTool for DnsCheckerTool { fqdn.clone(), 
"DNS".to_string(), ); - finding.cwe = Some("CWE-923".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + finding.cwe = Some("CWE-923".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Remove dangling CNAME records or ensure the target hostname is \ properly configured and resolvable." .to_string(), ); - findings.push(finding); - warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover"); + findings.push(finding); + warn!( + fqdn, + cname, "Dangling CNAME detected - potential subdomain takeover" + ); + } } + cname_data.insert(fqdn.clone(), cnames); } - cname_data.insert(fqdn.clone(), cnames); + _ => {} } - _ => {} } - } - if !cname_data.is_empty() { - dns_data.insert("cname_records".to_string(), json!(cname_data)); - } + if !cname_data.is_empty() { + dns_data.insert("cname_records".to_string(), json!(cname_data)); + } - // --- CAA records --- - match Self::dig_query(domain, "CAA").await { - Ok(caa) => { - if caa.is_empty() { - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: domain.to_string(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some("No CAA records found".to_string()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + // --- CAA records --- + match Self::dig_query(domain, "CAA").await { + Ok(caa) => { + if caa.is_empty() { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: domain.to_string(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some("No CAA records found".to_string()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::DnsMisconfiguration, @@ -257,35 +272,83 @@ impl PentestTool for DnsCheckerTool 
{ domain.to_string(), "DNS".to_string(), ); - finding.cwe = Some("CWE-295".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + finding.cwe = Some("CWE-295".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Add CAA DNS records to restrict which certificate authorities can issue \ certificates for your domain. Example: '0 issue \"letsencrypt.org\"'." .to_string(), ); - findings.push(finding); + findings.push(finding); + } + dns_data.insert("caa_records".to_string(), json!(caa)); + } + Err(e) => { + dns_data.insert("caa_records_error".to_string(), json!(e.to_string())); } - dns_data.insert("caa_records".to_string(), json!(caa)); } - Err(e) => { - dns_data.insert("caa_records_error".to_string(), json!(e.to_string())); + + // --- DNSSEC check --- + let dnssec_output = tokio::process::Command::new("dig") + .args(["+dnssec", "+short", "DNSKEY", domain]) + .output() + .await; + + match dnssec_output { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout); + let has_dnssec = !stdout.trim().is_empty(); + dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec)); + + if !has_dnssec { + let evidence = DastEvidence { + request_method: "DNS".to_string(), + request_url: domain.to_string(), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some( + "No DNSKEY records found - DNSSEC not enabled".to_string(), + ), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::DnsMisconfiguration, + format!("DNSSEC not enabled for {domain}"), + format!( + "DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \ + can be spoofed, allowing man-in-the-middle attacks." 
+ ), + Severity::Medium, + domain.to_string(), + "DNS".to_string(), + ); + finding.cwe = Some("CWE-350".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Enable DNSSEC for your domain by configuring DNSKEY and DS records \ + with your DNS provider and domain registrar." + .to_string(), + ); + findings.push(finding); + } + } + Err(_) => { + dns_data.insert("dnssec_check_error".to_string(), json!("dig not available")); + } } - } - // --- DNSSEC check --- - let dnssec_output = tokio::process::Command::new("dig") - .args(["+dnssec", "+short", "DNSKEY", domain]) - .output() - .await; - - match dnssec_output { - Ok(output) => { - let stdout = String::from_utf8_lossy(&output.stdout); - let has_dnssec = !stdout.trim().is_empty(); - dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec)); - - if !has_dnssec { + // --- Check NS records for dangling --- + for ns in &ns_records { + let ns_clean = ns.trim_end_matches('.'); + let check_addr = format!("{ns_clean}:53"); + if lookup_host(&check_addr).await.is_err() { let evidence = DastEvidence { request_method: "DNS".to_string(), request_url: domain.to_string(), @@ -293,61 +356,13 @@ impl PentestTool for DnsCheckerTool { request_body: None, response_status: 0, response_headers: None, - response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()), + response_snippet: Some(format!("NS record {ns} does not resolve")), screenshot_path: None, payload: None, response_time_ms: None, }; let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::DnsMisconfiguration, - format!("DNSSEC not enabled for {domain}"), - format!( - "DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \ - can be spoofed, allowing man-in-the-middle attacks." 
- ), - Severity::Medium, - domain.to_string(), - "DNS".to_string(), - ); - finding.cwe = Some("CWE-350".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Enable DNSSEC for your domain by configuring DNSKEY and DS records \ - with your DNS provider and domain registrar." - .to_string(), - ); - findings.push(finding); - } - } - Err(_) => { - dns_data.insert("dnssec_check_error".to_string(), json!("dig not available")); - } - } - - // --- Check NS records for dangling --- - for ns in &ns_records { - let ns_clean = ns.trim_end_matches('.'); - let check_addr = format!("{ns_clean}:53"); - if lookup_host(&check_addr).await.is_err() { - let evidence = DastEvidence { - request_method: "DNS".to_string(), - request_url: domain.to_string(), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some(format!( - "NS record {ns} does not resolve" - )), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::DnsMisconfiguration, @@ -360,30 +375,33 @@ impl PentestTool for DnsCheckerTool { domain.to_string(), "DNS".to_string(), ); - finding.cwe = Some("CWE-923".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Remove dangling NS records or ensure the nameserver hostname is properly \ + finding.cwe = Some("CWE-923".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Remove dangling NS records or ensure the nameserver hostname is properly \ configured. Dangling NS records can lead to full domain takeover." 
- .to_string(), - ); - findings.push(finding); - warn!(domain, ns, "Dangling NS record detected - potential domain takeover"); + .to_string(), + ); + findings.push(finding); + warn!( + domain, + ns, "Dangling NS record detected - potential domain takeover" + ); + } } - } - let count = findings.len(); - info!(domain, findings = count, "DNS check complete"); + let count = findings.len(); + info!(domain, findings = count, "DNS check complete"); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} DNS configuration issues for {domain}.") - } else { - format!("No DNS configuration issues found for {domain}.") - }, - findings, - data: json!(dns_data), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} DNS configuration issues for {domain}.") + } else { + format!("No DNS configuration issues found for {domain}.") + }, + findings, + data: json!(dns_data), + }) }) } } diff --git a/compliance-dast/src/tools/mod.rs b/compliance-dast/src/tools/mod.rs index 5bff516..318052f 100644 --- a/compliance-dast/src/tools/mod.rs +++ b/compliance-dast/src/tools/mod.rs @@ -33,8 +33,15 @@ pub struct ToolRegistry { tools: HashMap>, } +impl Default for ToolRegistry { + fn default() -> Self { + Self::new() + } +} + impl ToolRegistry { /// Create a new registry with all built-in tools pre-registered. 
+ #[allow(clippy::expect_used)] pub fn new() -> Self { let http = reqwest::Client::builder() .danger_accept_invalid_certs(true) @@ -67,13 +74,10 @@ impl ToolRegistry { ); // New infrastructure / analysis tools + register(&mut tools, Box::::default()); register( &mut tools, - Box::new(dns_checker::DnsCheckerTool::new()), - ); - register( - &mut tools, - Box::new(dmarc_checker::DmarcCheckerTool::new()), + Box::::default(), ); register( &mut tools, @@ -109,10 +113,7 @@ impl ToolRegistry { &mut tools, Box::new(openapi_parser::OpenApiParserTool::new(http.clone())), ); - register( - &mut tools, - Box::new(recon::ReconTool::new(http)), - ); + register(&mut tools, Box::new(recon::ReconTool::new(http))); Self { tools } } diff --git a/compliance-dast/src/tools/openapi_parser.rs b/compliance-dast/src/tools/openapi_parser.rs index 044fc16..bb6f52d 100644 --- a/compliance-dast/src/tools/openapi_parser.rs +++ b/compliance-dast/src/tools/openapi_parser.rs @@ -92,7 +92,10 @@ impl OpenApiParserTool { // If content type suggests YAML, we can't easily parse without a YAML dep, // so just report the URL as found - if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") { + if content_type.contains("yaml") + || body.starts_with("openapi:") + || body.starts_with("swagger:") + { // Return a minimal JSON indicating YAML was found return Some(( url.to_string(), @@ -107,7 +110,7 @@ impl OpenApiParserTool { } /// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints. 
- fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec { + fn parse_spec(spec: &serde_json::Value, _base_url: &str) -> Vec { let mut endpoints = Vec::new(); // Determine base path @@ -258,6 +261,166 @@ impl OpenApiParserTool { } } +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn common_spec_paths_not_empty() { + let paths = OpenApiParserTool::common_spec_paths(); + assert!(paths.len() >= 5); + assert!(paths.contains(&"/openapi.json")); + assert!(paths.contains(&"/swagger.json")); + } + + #[test] + fn parse_spec_openapi3_basic() { + let spec = json!({ + "openapi": "3.0.0", + "info": { "title": "Test API", "version": "1.0" }, + "paths": { + "/users": { + "get": { + "operationId": "listUsers", + "summary": "List all users", + "parameters": [ + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer" } + } + ], + "responses": { + "200": { "description": "OK" }, + "401": { "description": "Unauthorized" } + }, + "tags": ["users"] + }, + "post": { + "operationId": "createUser", + "requestBody": { + "content": { + "application/json": {} + } + }, + "responses": { "201": {} }, + "security": [{ "bearerAuth": [] }] + } + } + } + }); + + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com"); + assert_eq!(endpoints.len(), 2); + + let get_ep = endpoints.iter().find(|e| e.method == "GET").unwrap(); + assert_eq!(get_ep.path, "/users"); + assert_eq!(get_ep.operation_id.as_deref(), Some("listUsers")); + assert_eq!(get_ep.summary.as_deref(), Some("List all users")); + assert_eq!(get_ep.parameters.len(), 1); + assert_eq!(get_ep.parameters[0].name, "limit"); + assert_eq!(get_ep.parameters[0].location, "query"); + assert!(!get_ep.parameters[0].required); + assert_eq!(get_ep.parameters[0].param_type.as_deref(), Some("integer")); + assert_eq!(get_ep.response_codes.len(), 2); + assert_eq!(get_ep.tags, vec!["users"]); + + let post_ep = endpoints.iter().find(|e| e.method == 
"POST").unwrap(); + assert_eq!( + post_ep.request_body_content_type.as_deref(), + Some("application/json") + ); + assert_eq!(post_ep.security, vec!["bearerAuth"]); + } + + #[test] + fn parse_spec_swagger2_with_base_path() { + let spec = json!({ + "swagger": "2.0", + "basePath": "/api/v1", + "paths": { + "/items": { + "get": { + "parameters": [ + { "name": "id", "in": "path", "required": true, "type": "string" } + ], + "responses": { "200": {} } + } + } + } + }); + + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com"); + assert_eq!(endpoints.len(), 1); + assert_eq!(endpoints[0].path, "/api/v1/items"); + assert_eq!( + endpoints[0].parameters[0].param_type.as_deref(), + Some("string") + ); + } + + #[test] + fn parse_spec_empty_paths() { + let spec = json!({ "openapi": "3.0.0", "paths": {} }); + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com"); + assert!(endpoints.is_empty()); + } + + #[test] + fn parse_spec_no_paths_key() { + let spec = json!({ "openapi": "3.0.0" }); + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com"); + assert!(endpoints.is_empty()); + } + + #[test] + fn parse_spec_servers_base_url() { + let spec = json!({ + "openapi": "3.0.0", + "servers": [{ "url": "/api/v2" }], + "paths": { + "/health": { + "get": { "responses": { "200": {} } } + } + } + }); + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com"); + assert_eq!(endpoints[0].path, "/api/v2/health"); + } + + #[test] + fn parse_spec_path_level_parameters_merged() { + let spec = json!({ + "openapi": "3.0.0", + "paths": { + "/items/{id}": { + "parameters": [ + { "name": "id", "in": "path", "required": true, "schema": { "type": "string" } } + ], + "get": { + "parameters": [ + { "name": "fields", "in": "query", "schema": { "type": "string" } } + ], + "responses": { "200": {} } + } + } + } + }); + let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com"); + 
assert_eq!(endpoints[0].parameters.len(), 2); + assert!(endpoints[0] + .parameters + .iter() + .any(|p| p.name == "id" && p.location == "path")); + assert!(endpoints[0] + .parameters + .iter() + .any(|p| p.name == "fields" && p.location == "query")); + } +} + impl PentestTool for OpenApiParserTool { fn name(&self) -> &str { "openapi_parser" @@ -289,134 +452,138 @@ impl PentestTool for OpenApiParserTool { fn execute<'a>( &'a self, input: serde_json::Value, - context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + _context: &'a PentestToolContext, + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let base_url = input - .get("base_url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?; + let base_url = input + .get("base_url") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + CoreError::Dast("Missing required 'base_url' parameter".to_string()) + })?; - let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str()); + let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str()); - let base_url_trimmed = base_url.trim_end_matches('/'); + let base_url_trimmed = base_url.trim_end_matches('/'); - // If an explicit spec URL is provided, try it first - let mut spec_result: Option<(String, serde_json::Value)> = None; - if let Some(spec_url) = explicit_spec_url { - spec_result = Self::try_fetch_spec(&self.http, spec_url).await; - } + // If an explicit spec URL is provided, try it first + let mut spec_result: Option<(String, serde_json::Value)> = None; + if let Some(spec_url) = explicit_spec_url { + spec_result = Self::try_fetch_spec(&self.http, spec_url).await; + } - // If no explicit URL or it failed, try common paths - if spec_result.is_none() { - for path in Self::common_spec_paths() { - let url = format!("{base_url_trimmed}{path}"); - if let Some(result) = Self::try_fetch_spec(&self.http, &url).await { - spec_result = Some(result); - break; + // If no 
explicit URL or it failed, try common paths + if spec_result.is_none() { + for path in Self::common_spec_paths() { + let url = format!("{base_url_trimmed}{path}"); + if let Some(result) = Self::try_fetch_spec(&self.http, &url).await { + spec_result = Some(result); + break; + } } } - } - match spec_result { - Some((spec_url, spec)) => { - let spec_version = spec - .get("openapi") - .or_else(|| spec.get("swagger")) - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); + match spec_result { + Some((spec_url, spec)) => { + let spec_version = spec + .get("openapi") + .or_else(|| spec.get("swagger")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); - let api_title = spec - .get("info") - .and_then(|i| i.get("title")) - .and_then(|t| t.as_str()) - .unwrap_or("Unknown API"); + let api_title = spec + .get("info") + .and_then(|i| i.get("title")) + .and_then(|t| t.as_str()) + .unwrap_or("Unknown API"); - let api_version = spec - .get("info") - .and_then(|i| i.get("version")) - .and_then(|v| v.as_str()) - .unwrap_or("unknown"); + let api_version = spec + .get("info") + .and_then(|i| i.get("version")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); - let endpoints = Self::parse_spec(&spec, base_url_trimmed); + let endpoints = Self::parse_spec(&spec, base_url_trimmed); - let endpoint_data: Vec = endpoints - .iter() - .map(|ep| { - let params: Vec = ep - .parameters - .iter() - .map(|p| { - json!({ - "name": p.name, - "in": p.location, - "required": p.required, - "type": p.param_type, - "description": p.description, + let endpoint_data: Vec = endpoints + .iter() + .map(|ep| { + let params: Vec = ep + .parameters + .iter() + .map(|p| { + json!({ + "name": p.name, + "in": p.location, + "required": p.required, + "type": p.param_type, + "description": p.description, + }) }) + .collect(); + + json!({ + "path": ep.path, + "method": ep.method, + "operation_id": ep.operation_id, + "summary": ep.summary, + "parameters": params, + "request_body_content_type": 
ep.request_body_content_type, + "response_codes": ep.response_codes, + "security": ep.security, + "tags": ep.tags, }) - .collect(); - - json!({ - "path": ep.path, - "method": ep.method, - "operation_id": ep.operation_id, - "summary": ep.summary, - "parameters": params, - "request_body_content_type": ep.request_body_content_type, - "response_codes": ep.response_codes, - "security": ep.security, - "tags": ep.tags, }) - }) - .collect(); + .collect(); - let endpoint_count = endpoints.len(); - info!( - spec_url = %spec_url, - spec_version, - api_title, - endpoints = endpoint_count, - "OpenAPI spec parsed" - ); + let endpoint_count = endpoints.len(); + info!( + spec_url = %spec_url, + spec_version, + api_title, + endpoints = endpoint_count, + "OpenAPI spec parsed" + ); - Ok(PentestToolResult { - summary: format!( - "Found OpenAPI spec ({spec_version}) at {spec_url}. \ + Ok(PentestToolResult { + summary: format!( + "Found OpenAPI spec ({spec_version}) at {spec_url}. \ API: {api_title} v{api_version}. \ Parsed {endpoint_count} endpoints." 
- ), - findings: Vec::new(), // This tool produces data, not findings - data: json!({ - "spec_url": spec_url, - "spec_version": spec_version, - "api_title": api_title, - "api_version": api_version, - "endpoint_count": endpoint_count, - "endpoints": endpoint_data, - "security_schemes": spec.get("components") - .and_then(|c| c.get("securitySchemes")) - .or_else(|| spec.get("securityDefinitions")), - }), - }) - } - None => { - info!(base_url, "No OpenAPI spec found"); + ), + findings: Vec::new(), // This tool produces data, not findings + data: json!({ + "spec_url": spec_url, + "spec_version": spec_version, + "api_title": api_title, + "api_version": api_version, + "endpoint_count": endpoint_count, + "endpoints": endpoint_data, + "security_schemes": spec.get("components") + .and_then(|c| c.get("securitySchemes")) + .or_else(|| spec.get("securityDefinitions")), + }), + }) + } + None => { + info!(base_url, "No OpenAPI spec found"); - Ok(PentestToolResult { - summary: format!( - "No OpenAPI/Swagger specification found for {base_url}. \ + Ok(PentestToolResult { + summary: format!( + "No OpenAPI/Swagger specification found for {base_url}. 
\ Tried {} common paths.", - Self::common_spec_paths().len() - ), - findings: Vec::new(), - data: json!({ - "spec_found": false, - "paths_tried": Self::common_spec_paths(), - }), - }) + Self::common_spec_paths().len() + ), + findings: Vec::new(), + data: json!({ + "spec_found": false, + "paths_tried": Self::common_spec_paths(), + }), + }) + } } - } }) } } diff --git a/compliance-dast/src/tools/rate_limit_tester.rs b/compliance-dast/src/tools/rate_limit_tester.rs index 961d914..2bb7c17 100644 --- a/compliance-dast/src/tools/rate_limit_tester.rs +++ b/compliance-dast/src/tools/rate_limit_tester.rs @@ -62,224 +62,229 @@ impl PentestTool for RateLimitTesterTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let method = input - .get("method") - .and_then(|v| v.as_str()) - .unwrap_or("GET"); + let method = input + .get("method") + .and_then(|v| v.as_str()) + .unwrap_or("GET"); - let request_count = input - .get("request_count") - .and_then(|v| v.as_u64()) - .unwrap_or(50) - .min(200) as usize; + let request_count = input + .get("request_count") + .and_then(|v| v.as_u64()) + .unwrap_or(50) + .min(200) as usize; - let body = input.get("body").and_then(|v| v.as_str()); + let body = input.get("body").and_then(|v| v.as_str()); - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - // Respect the context rate limit if set - let max_requests = if context.rate_limit > 0 { - 
request_count.min(context.rate_limit as usize * 5) - } else { - request_count - }; - - let mut status_codes: Vec = Vec::with_capacity(max_requests); - let mut response_times: Vec = Vec::with_capacity(max_requests); - let mut got_429 = false; - let mut rate_limit_at_request: Option = None; - - for i in 0..max_requests { - let start = Instant::now(); - - let request = match method { - "POST" => { - let mut req = self.http.post(url); - if let Some(b) = body { - req = req.body(b.to_string()); - } - req - } - "PUT" => { - let mut req = self.http.put(url); - if let Some(b) = body { - req = req.body(b.to_string()); - } - req - } - "PATCH" => { - let mut req = self.http.patch(url); - if let Some(b) = body { - req = req.body(b.to_string()); - } - req - } - "DELETE" => self.http.delete(url), - _ => self.http.get(url), + // Respect the context rate limit if set + let max_requests = if context.rate_limit > 0 { + request_count.min(context.rate_limit as usize * 5) + } else { + request_count }; - match request.send().await { - Ok(resp) => { - let elapsed = start.elapsed().as_millis(); - let status = resp.status().as_u16(); - status_codes.push(status); - response_times.push(elapsed); + let mut status_codes: Vec = Vec::with_capacity(max_requests); + let mut response_times: Vec = Vec::with_capacity(max_requests); + let mut got_429 = false; + let mut rate_limit_at_request: Option = None; - if status == 429 && !got_429 { - got_429 = true; - rate_limit_at_request = Some(i + 1); - info!(url, request_num = i + 1, "Rate limit triggered (429)"); + for i in 0..max_requests { + let start = Instant::now(); + + let request = match method { + "POST" => { + let mut req = self.http.post(url); + if let Some(b) = body { + req = req.body(b.to_string()); + } + req } + "PUT" => { + let mut req = self.http.put(url); + if let Some(b) = body { + req = req.body(b.to_string()); + } + req + } + "PATCH" => { + let mut req = self.http.patch(url); + if let Some(b) = body { + req = req.body(b.to_string()); + } 
+ req + } + "DELETE" => self.http.delete(url), + _ => self.http.get(url), + }; - // Check for rate limit headers even on 200 - if !got_429 { - let headers = resp.headers(); - let has_rate_headers = headers.contains_key("x-ratelimit-limit") - || headers.contains_key("x-ratelimit-remaining") - || headers.contains_key("ratelimit-limit") - || headers.contains_key("ratelimit-remaining") - || headers.contains_key("retry-after"); + match request.send().await { + Ok(resp) => { + let elapsed = start.elapsed().as_millis(); + let status = resp.status().as_u16(); + status_codes.push(status); + response_times.push(elapsed); - if has_rate_headers && rate_limit_at_request.is_none() { - // Server has rate limit headers but hasn't blocked yet + if status == 429 && !got_429 { + got_429 = true; + rate_limit_at_request = Some(i + 1); + info!(url, request_num = i + 1, "Rate limit triggered (429)"); + } + + // Check for rate limit headers even on 200 + if !got_429 { + let headers = resp.headers(); + let has_rate_headers = headers.contains_key("x-ratelimit-limit") + || headers.contains_key("x-ratelimit-remaining") + || headers.contains_key("ratelimit-limit") + || headers.contains_key("ratelimit-remaining") + || headers.contains_key("retry-after"); + + if has_rate_headers && rate_limit_at_request.is_none() { + // Server has rate limit headers but hasn't blocked yet + } } } - } - Err(e) => { - let elapsed = start.elapsed().as_millis(); - status_codes.push(0); - response_times.push(elapsed); + Err(_e) => { + let elapsed = start.elapsed().as_millis(); + status_codes.push(0); + response_times.push(elapsed); + } } } - } - let mut findings = Vec::new(); - let total_sent = status_codes.len(); - let count_429 = status_codes.iter().filter(|&&s| s == 429).count(); - let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count(); + let mut findings = Vec::new(); + let total_sent = status_codes.len(); + let count_429 = status_codes.iter().filter(|&&s| s == 429).count(); + let 
count_success = status_codes + .iter() + .filter(|&&s| (200..300).contains(&s)) + .count(); - // Calculate response time statistics - let avg_time = if !response_times.is_empty() { - response_times.iter().sum::() / response_times.len() as u128 - } else { - 0 - }; - - let first_half_avg = if response_times.len() >= 4 { - let half = response_times.len() / 2; - response_times[..half].iter().sum::() / half as u128 - } else { - avg_time - }; - - let second_half_avg = if response_times.len() >= 4 { - let half = response_times.len() / 2; - response_times[half..].iter().sum::() / (response_times.len() - half) as u128 - } else { - avg_time - }; - - // Significant time degradation suggests possible (weak) rate limiting - let time_degradation = if first_half_avg > 0 { - (second_half_avg as f64 / first_half_avg as f64) - 1.0 - } else { - 0.0 - }; - - let rate_data = json!({ - "total_requests_sent": total_sent, - "status_429_count": count_429, - "success_count": count_success, - "rate_limit_at_request": rate_limit_at_request, - "avg_response_time_ms": avg_time, - "first_half_avg_ms": first_half_avg, - "second_half_avg_ms": second_half_avg, - "time_degradation_pct": (time_degradation * 100.0).round(), - }); - - if !got_429 && count_success == total_sent { - // No rate limiting detected at all - let evidence = DastEvidence { - request_method: method.to_string(), - request_url: url.to_string(), - request_headers: None, - request_body: body.map(String::from), - response_status: 200, - response_headers: None, - response_snippet: Some(format!( - "Sent {total_sent} rapid requests. All returned success (2xx). \ - No 429 responses received. Avg response time: {avg_time}ms." 
- )), - screenshot_path: None, - payload: None, - response_time_ms: Some(avg_time as u64), + // Calculate response time statistics + let avg_time = if !response_times.is_empty() { + response_times.iter().sum::() / response_times.len() as u128 + } else { + 0 }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::RateLimitAbsent, - format!("No rate limiting on {} {}", method, url), - format!( - "The endpoint {} {} does not enforce rate limiting. \ + let first_half_avg = if response_times.len() >= 4 { + let half = response_times.len() / 2; + response_times[..half].iter().sum::() / half as u128 + } else { + avg_time + }; + + let second_half_avg = if response_times.len() >= 4 { + let half = response_times.len() / 2; + response_times[half..].iter().sum::() / (response_times.len() - half) as u128 + } else { + avg_time + }; + + // Significant time degradation suggests possible (weak) rate limiting + let time_degradation = if first_half_avg > 0 { + (second_half_avg as f64 / first_half_avg as f64) - 1.0 + } else { + 0.0 + }; + + let rate_data = json!({ + "total_requests_sent": total_sent, + "status_429_count": count_429, + "success_count": count_success, + "rate_limit_at_request": rate_limit_at_request, + "avg_response_time_ms": avg_time, + "first_half_avg_ms": first_half_avg, + "second_half_avg_ms": second_half_avg, + "time_degradation_pct": (time_degradation * 100.0).round(), + }); + + if !got_429 && count_success == total_sent { + // No rate limiting detected at all + let evidence = DastEvidence { + request_method: method.to_string(), + request_url: url.to_string(), + request_headers: None, + request_body: body.map(String::from), + response_status: 200, + response_headers: None, + response_snippet: Some(format!( + "Sent {total_sent} rapid requests. All returned success (2xx). \ + No 429 responses received. Avg response time: {avg_time}ms." 
+ )), + screenshot_path: None, + payload: None, + response_time_ms: Some(avg_time as u64), + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::RateLimitAbsent, + format!("No rate limiting on {} {}", method, url), + format!( + "The endpoint {} {} does not enforce rate limiting. \ {total_sent} rapid requests were all accepted with no 429 responses \ or noticeable degradation. This makes the endpoint vulnerable to \ brute force attacks and abuse.", - method, url - ), - Severity::Medium, - url.to_string(), - method.to_string(), - ); - finding.cwe = Some("CWE-770".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Implement rate limiting on this endpoint. Use token bucket or sliding window \ + method, url + ), + Severity::Medium, + url.to_string(), + method.to_string(), + ); + finding.cwe = Some("CWE-770".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Implement rate limiting on this endpoint. Use token bucket or sliding window \ algorithms. Return 429 Too Many Requests with a Retry-After header when the \ limit is exceeded." - .to_string(), - ); - findings.push(finding); - warn!(url, method, total_sent, "No rate limiting detected"); - } else if got_429 { - info!( - url, - method, - rate_limit_at = ?rate_limit_at_request, - "Rate limiting is enforced" - ); - } + .to_string(), + ); + findings.push(finding); + warn!(url, method, total_sent, "No rate limiting detected"); + } else if got_429 { + info!( + url, + method, + rate_limit_at = ?rate_limit_at_request, + "Rate limiting is enforced" + ); + } - let count = findings.len(); + let count = findings.len(); - Ok(PentestToolResult { - summary: if got_429 { - format!( - "Rate limiting is enforced on {method} {url}. \ + Ok(PentestToolResult { + summary: if got_429 { + format!( + "Rate limiting is enforced on {method} {url}. 
\ 429 response received after {} requests.", - rate_limit_at_request.unwrap_or(0) - ) - } else if count > 0 { - format!( - "No rate limiting detected on {method} {url} after {total_sent} requests." - ) - } else { - format!("Rate limit testing complete for {method} {url}.") - }, - findings, - data: rate_data, - }) + rate_limit_at_request.unwrap_or(0) + ) + } else if count > 0 { + format!( + "No rate limiting detected on {method} {url} after {total_sent} requests." + ) + } else { + format!("Rate limit testing complete for {method} {url}.") + }, + findings, + data: rate_data, + }) }) } } diff --git a/compliance-dast/src/tools/recon.rs b/compliance-dast/src/tools/recon.rs index 1cefd73..386c352 100644 --- a/compliance-dast/src/tools/recon.rs +++ b/compliance-dast/src/tools/recon.rs @@ -54,72 +54,75 @@ impl PentestTool for ReconTool { fn execute<'a>( &'a self, input: serde_json::Value, - context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + _context: &'a PentestToolContext, + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let additional_paths: Vec = input - .get("additional_paths") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(String::from)) - .collect() - }) - .unwrap_or_default(); + let additional_paths: Vec = input + .get("additional_paths") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(String::from)) + .collect() + }) + .unwrap_or_default(); - let result = self.agent.scan(url).await?; + let result = self.agent.scan(url).await?; - // Scan additional paths for more technology signals - let mut extra_technologies: Vec = Vec::new(); - let mut extra_headers: 
HashMap = HashMap::new(); + // Scan additional paths for more technology signals + let mut extra_technologies: Vec = Vec::new(); + let mut extra_headers: HashMap = HashMap::new(); - let base_url = url.trim_end_matches('/'); - for path in &additional_paths { - let probe_url = format!("{base_url}/{}", path.trim_start_matches('/')); - if let Ok(resp) = self.http.get(&probe_url).send().await { - for (key, value) in resp.headers() { - let k = key.to_string().to_lowercase(); - let v = value.to_str().unwrap_or("").to_string(); + let base_url = url.trim_end_matches('/'); + for path in &additional_paths { + let probe_url = format!("{base_url}/{}", path.trim_start_matches('/')); + if let Ok(resp) = self.http.get(&probe_url).send().await { + for (key, value) in resp.headers() { + let k = key.to_string().to_lowercase(); + let v = value.to_str().unwrap_or("").to_string(); - // Look for technology indicators - if k == "x-powered-by" || k == "server" || k == "x-generator" { - if !result.technologies.contains(&v) && !extra_technologies.contains(&v) { + // Look for technology indicators + if (k == "x-powered-by" || k == "server" || k == "x-generator") + && !result.technologies.contains(&v) + && !extra_technologies.contains(&v) + { extra_technologies.push(v.clone()); } + extra_headers.insert(format!("{probe_url} -> {k}"), v); } - extra_headers.insert(format!("{probe_url} -> {k}"), v); } } - } - let mut all_technologies = result.technologies.clone(); - all_technologies.extend(extra_technologies); - all_technologies.dedup(); + let mut all_technologies = result.technologies.clone(); + all_technologies.extend(extra_technologies); + all_technologies.dedup(); - let tech_count = all_technologies.len(); - info!(url, technologies = tech_count, "Recon complete"); + let tech_count = all_technologies.len(); + info!(url, technologies = tech_count, "Recon complete"); - Ok(PentestToolResult { - summary: format!( - "Recon complete for {url}. Detected {} technologies. 
Server: {}.", - tech_count, - result.server.as_deref().unwrap_or("unknown") - ), - findings: Vec::new(), // Recon produces data, not findings - data: json!({ - "base_url": url, - "server": result.server, - "technologies": all_technologies, - "interesting_headers": result.interesting_headers, - "extra_headers": extra_headers, - "open_ports": result.open_ports, - }), - }) + Ok(PentestToolResult { + summary: format!( + "Recon complete for {url}. Detected {} technologies. Server: {}.", + tech_count, + result.server.as_deref().unwrap_or("unknown") + ), + findings: Vec::new(), // Recon produces data, not findings + data: json!({ + "base_url": url, + "server": result.server, + "technologies": all_technologies, + "interesting_headers": result.interesting_headers, + "extra_headers": extra_headers, + "open_ports": result.open_ports, + }), + }) }) } } diff --git a/compliance-dast/src/tools/security_headers.rs b/compliance-dast/src/tools/security_headers.rs index d9b95a0..b288322 100644 --- a/compliance-dast/src/tools/security_headers.rs +++ b/compliance-dast/src/tools/security_headers.rs @@ -111,57 +111,107 @@ impl PentestTool for SecurityHeadersTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let response = self - .http - .get(url) - .send() - .await - .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; + let 
response = self + .http + .get(url) + .send() + .await + .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; - let status = response.status().as_u16(); - let response_headers: HashMap = response - .headers() - .iter() - .map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string())) - .collect(); + let status = response.status().as_u16(); + let response_headers: HashMap = response + .headers() + .iter() + .map(|(k, v)| { + ( + k.to_string().to_lowercase(), + v.to_str().unwrap_or("").to_string(), + ) + }) + .collect(); - let mut findings = Vec::new(); - let mut header_results: HashMap = HashMap::new(); + let mut findings = Vec::new(); + let mut header_results: HashMap = HashMap::new(); - for expected in Self::expected_headers() { - let header_value = response_headers.get(expected.name); + for expected in Self::expected_headers() { + let header_value = response_headers.get(expected.name); - match header_value { - Some(value) => { - let mut is_valid = true; - if let Some(ref valid) = expected.valid_values { - let lower = value.to_lowercase(); - is_valid = valid.iter().any(|v| lower.contains(v)); + match header_value { + Some(value) => { + let mut is_valid = true; + if let Some(ref valid) = expected.valid_values { + let lower = value.to_lowercase(); + is_valid = valid.iter().any(|v| lower.contains(v)); + } + + header_results.insert( + expected.name.to_string(), + json!({ + "present": true, + "value": value, + "valid": is_valid, + }), + ); + + if !is_valid { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: url.to_string(), + request_headers: None, + request_body: None, + response_status: status, + response_headers: Some(response_headers.clone()), + response_snippet: Some(format!("{}: {}", expected.name, value)), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + 
DastVulnType::SecurityHeaderMissing, + format!("Invalid {} header value", expected.name), + format!( + "The {} header is present but has an invalid or weak value: '{}'. \ + {}", + expected.name, value, expected.description + ), + expected.severity.clone(), + url.to_string(), + "GET".to_string(), + ); + finding.cwe = Some(expected.cwe.to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some(expected.remediation.to_string()); + findings.push(finding); + } } + None => { + header_results.insert( + expected.name.to_string(), + json!({ + "present": false, + "value": null, + "valid": false, + }), + ); - header_results.insert( - expected.name.to_string(), - json!({ - "present": true, - "value": value, - "valid": is_valid, - }), - ); - - if !is_valid { let evidence = DastEvidence { request_method: "GET".to_string(), request_url: url.to_string(), @@ -169,7 +219,7 @@ impl PentestTool for SecurityHeadersTool { request_body: None, response_status: status, response_headers: Some(response_headers.clone()), - response_snippet: Some(format!("{}: {}", expected.name, value)), + response_snippet: Some(format!("{} header is missing", expected.name)), screenshot_path: None, payload: None, response_time_ms: None, @@ -179,11 +229,10 @@ impl PentestTool for SecurityHeadersTool { String::new(), target_id.clone(), DastVulnType::SecurityHeaderMissing, - format!("Invalid {} header value", expected.name), + format!("Missing {} header", expected.name), format!( - "The {} header is present but has an invalid or weak value: '{}'. \ - {}", - expected.name, value, expected.description + "The {} header is not present in the response. 
{}", + expected.name, expected.description ), expected.severity.clone(), url.to_string(), @@ -195,14 +244,20 @@ impl PentestTool for SecurityHeadersTool { findings.push(finding); } } - None => { + } + + // Also check for information disclosure headers + let disclosure_headers = [ + "server", + "x-powered-by", + "x-aspnet-version", + "x-aspnetmvc-version", + ]; + for h in &disclosure_headers { + if let Some(value) = response_headers.get(*h) { header_results.insert( - expected.name.to_string(), - json!({ - "present": false, - "value": null, - "valid": false, - }), + format!("{h}_disclosure"), + json!({ "present": true, "value": value }), ); let evidence = DastEvidence { @@ -212,56 +267,13 @@ impl PentestTool for SecurityHeadersTool { request_body: None, response_status: status, response_headers: Some(response_headers.clone()), - response_snippet: Some(format!("{} header is missing", expected.name)), + response_snippet: Some(format!("{h}: {value}")), screenshot_path: None, payload: None, response_time_ms: None, }; let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::SecurityHeaderMissing, - format!("Missing {} header", expected.name), - format!( - "The {} header is not present in the response. 
{}", - expected.name, expected.description - ), - expected.severity.clone(), - url.to_string(), - "GET".to_string(), - ); - finding.cwe = Some(expected.cwe.to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some(expected.remediation.to_string()); - findings.push(finding); - } - } - } - - // Also check for information disclosure headers - let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"]; - for h in &disclosure_headers { - if let Some(value) = response_headers.get(*h) { - header_results.insert( - format!("{h}_disclosure"), - json!({ "present": true, "value": value }), - ); - - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: url.to_string(), - request_headers: None, - request_body: None, - response_status: status, - response_headers: Some(response_headers.clone()), - response_snippet: Some(format!("{h}: {value}")), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::SecurityHeaderMissing, @@ -274,27 +286,27 @@ impl PentestTool for SecurityHeadersTool { url.to_string(), "GET".to_string(), ); - finding.cwe = Some("CWE-200".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some(format!( - "Remove or suppress the {h} header in your server configuration." - )); - findings.push(finding); + finding.cwe = Some("CWE-200".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some(format!( + "Remove or suppress the {h} header in your server configuration." 
+ )); + findings.push(finding); + } } - } - let count = findings.len(); - info!(url, findings = count, "Security headers check complete"); + let count = findings.len(); + info!(url, findings = count, "Security headers check complete"); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} security header issues for {url}.") - } else { - format!("All checked security headers are present and valid for {url}.") - }, - findings, - data: json!(header_results), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} security header issues for {url}.") + } else { + format!("All checked security headers are present and valid for {url}.") + }, + findings, + data: json!(header_results), + }) }) } } diff --git a/compliance-dast/src/tools/sql_injection.rs b/compliance-dast/src/tools/sql_injection.rs index 0ef260c..7bfed01 100644 --- a/compliance-dast/src/tools/sql_injection.rs +++ b/compliance-dast/src/tools/sql_injection.rs @@ -1,5 +1,7 @@ use compliance_core::error::CoreError; -use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; +use compliance_core::traits::dast_agent::{ + DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter, +}; use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; use serde_json::json; @@ -7,29 +9,51 @@ use crate::agents::injection::SqlInjectionAgent; /// PentestTool wrapper around the existing SqlInjectionAgent. 
pub struct SqlInjectionTool { - http: reqwest::Client, + _http: reqwest::Client, agent: SqlInjectionAgent, } impl SqlInjectionTool { pub fn new(http: reqwest::Client) -> Self { let agent = SqlInjectionAgent::new(http.clone()); - Self { http, agent } + Self { _http: http, agent } } fn parse_endpoints(input: &serde_json::Value) -> Vec { let mut endpoints = Vec::new(); if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { for ep in arr { - let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string(); - let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string(); + let url = ep + .get("url") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let method = ep + .get("method") + .and_then(|v| v.as_str()) + .unwrap_or("GET") + .to_string(); let mut parameters = Vec::new(); if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { for p in params { - let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(); - let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(); - let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from); - let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from); + let name = p + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let location = p + .get("location") + .and_then(|v| v.as_str()) + .unwrap_or("query") + .to_string(); + let param_type = p + .get("param_type") + .and_then(|v| v.as_str()) + .map(String::from); + let example_value = p + .get("example_value") + .and_then(|v| v.as_str()) + .map(String::from); parameters.push(EndpointParameter { name, location, @@ -42,8 +66,14 @@ impl SqlInjectionTool { url, method, parameters, - content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from), - requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false), + content_type: ep + 
.get("content_type") + .and_then(|v| v.as_str()) + .map(String::from), + requires_auth: ep + .get("requires_auth") + .and_then(|v| v.as_bool()) + .unwrap_or(false), }); } } @@ -51,6 +81,62 @@ impl SqlInjectionTool { } } +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_endpoints_basic() { + let input = json!({ + "endpoints": [{ + "url": "https://example.com/api/users", + "method": "POST", + "parameters": [ + { "name": "id", "location": "body", "param_type": "integer" } + ] + }] + }); + let endpoints = SqlInjectionTool::parse_endpoints(&input); + assert_eq!(endpoints.len(), 1); + assert_eq!(endpoints[0].url, "https://example.com/api/users"); + assert_eq!(endpoints[0].method, "POST"); + assert_eq!(endpoints[0].parameters[0].name, "id"); + assert_eq!(endpoints[0].parameters[0].location, "body"); + assert_eq!( + endpoints[0].parameters[0].param_type.as_deref(), + Some("integer") + ); + } + + #[test] + fn parse_endpoints_empty_input() { + assert!(SqlInjectionTool::parse_endpoints(&json!({})).is_empty()); + assert!(SqlInjectionTool::parse_endpoints(&json!({ "endpoints": [] })).is_empty()); + } + + #[test] + fn parse_endpoints_multiple() { + let input = json!({ + "endpoints": [ + { "url": "https://a.com/1", "method": "GET", "parameters": [] }, + { "url": "https://b.com/2", "method": "DELETE", "parameters": [] } + ] + }); + let endpoints = SqlInjectionTool::parse_endpoints(&input); + assert_eq!(endpoints.len(), 2); + assert_eq!(endpoints[0].url, "https://a.com/1"); + assert_eq!(endpoints[1].method, "DELETE"); + } + + #[test] + fn parse_endpoints_default_method() { + let input = json!({ "endpoints": [{ "url": "https://x.com", "parameters": [] }] }); + let endpoints = SqlInjectionTool::parse_endpoints(&input); + assert_eq!(endpoints[0].method, "GET"); + } +} + impl PentestTool for SqlInjectionTool { fn name(&self) -> &str { "sql_injection_scanner" @@ -104,35 +190,37 @@ impl PentestTool for SqlInjectionTool { &'a self, input: 
serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let endpoints = Self::parse_endpoints(&input); - if endpoints.is_empty() { - return Ok(PentestToolResult { - summary: "No endpoints provided to test.".to_string(), - findings: Vec::new(), - data: json!({}), - }); - } + let endpoints = Self::parse_endpoints(&input); + if endpoints.is_empty() { + return Ok(PentestToolResult { + summary: "No endpoints provided to test.".to_string(), + findings: Vec::new(), + data: json!({}), + }); + } - let dast_context = DastContext { - endpoints, - technologies: Vec::new(), - sast_hints: Vec::new(), - }; + let dast_context = DastContext { + endpoints, + technologies: Vec::new(), + sast_hints: Vec::new(), + }; - let findings = self.agent.run(&context.target, &dast_context).await?; - let count = findings.len(); + let findings = self.agent.run(&context.target, &dast_context).await?; + let count = findings.len(); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} SQL injection vulnerabilities.") - } else { - "No SQL injection vulnerabilities detected.".to_string() - }, - findings, - data: json!({ "endpoints_tested": dast_context.endpoints.len() }), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} SQL injection vulnerabilities.") + } else { + "No SQL injection vulnerabilities detected.".to_string() + }, + findings, + data: json!({ "endpoints_tested": dast_context.endpoints.len() }), + }) }) } } diff --git a/compliance-dast/src/tools/ssrf.rs b/compliance-dast/src/tools/ssrf.rs index 76af017..f7742f8 100644 --- a/compliance-dast/src/tools/ssrf.rs +++ b/compliance-dast/src/tools/ssrf.rs @@ -1,5 +1,7 @@ use compliance_core::error::CoreError; -use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; +use compliance_core::traits::dast_agent::{ + DastAgent, DastContext, 
DiscoveredEndpoint, EndpointParameter, +}; use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; use serde_json::json; @@ -7,30 +9,52 @@ use crate::agents::ssrf::SsrfAgent; /// PentestTool wrapper around the existing SsrfAgent. pub struct SsrfTool { - http: reqwest::Client, + _http: reqwest::Client, agent: SsrfAgent, } impl SsrfTool { pub fn new(http: reqwest::Client) -> Self { let agent = SsrfAgent::new(http.clone()); - Self { http, agent } + Self { _http: http, agent } } fn parse_endpoints(input: &serde_json::Value) -> Vec { let mut endpoints = Vec::new(); if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { for ep in arr { - let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string(); - let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string(); + let url = ep + .get("url") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let method = ep + .get("method") + .and_then(|v| v.as_str()) + .unwrap_or("GET") + .to_string(); let mut parameters = Vec::new(); if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { for p in params { parameters.push(EndpointParameter { - name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(), - location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(), - param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from), - example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from), + name: p + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(), + location: p + .get("location") + .and_then(|v| v.as_str()) + .unwrap_or("query") + .to_string(), + param_type: p + .get("param_type") + .and_then(|v| v.as_str()) + .map(String::from), + example_value: p + .get("example_value") + .and_then(|v| v.as_str()) + .map(String::from), }); } } @@ -38,8 +62,14 @@ impl SsrfTool { url, method, parameters, - 
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from), - requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false), + content_type: ep + .get("content_type") + .and_then(|v| v.as_str()) + .map(String::from), + requires_auth: ep + .get("requires_auth") + .and_then(|v| v.as_bool()) + .unwrap_or(false), }); } } @@ -100,35 +130,37 @@ impl PentestTool for SsrfTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let endpoints = Self::parse_endpoints(&input); - if endpoints.is_empty() { - return Ok(PentestToolResult { - summary: "No endpoints provided to test.".to_string(), - findings: Vec::new(), - data: json!({}), - }); - } + let endpoints = Self::parse_endpoints(&input); + if endpoints.is_empty() { + return Ok(PentestToolResult { + summary: "No endpoints provided to test.".to_string(), + findings: Vec::new(), + data: json!({}), + }); + } - let dast_context = DastContext { - endpoints, - technologies: Vec::new(), - sast_hints: Vec::new(), - }; + let dast_context = DastContext { + endpoints, + technologies: Vec::new(), + sast_hints: Vec::new(), + }; - let findings = self.agent.run(&context.target, &dast_context).await?; - let count = findings.len(); + let findings = self.agent.run(&context.target, &dast_context).await?; + let count = findings.len(); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} SSRF vulnerabilities.") - } else { - "No SSRF vulnerabilities detected.".to_string() - }, - findings, - data: json!({ "endpoints_tested": dast_context.endpoints.len() }), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} SSRF vulnerabilities.") + } else { + "No SSRF vulnerabilities detected.".to_string() + }, + findings, + data: json!({ "endpoints_tested": dast_context.endpoints.len() }), + }) }) } } diff --git 
a/compliance-dast/src/tools/tls_analyzer.rs b/compliance-dast/src/tools/tls_analyzer.rs index 30ef5e0..af9fbf9 100644 --- a/compliance-dast/src/tools/tls_analyzer.rs +++ b/compliance-dast/src/tools/tls_analyzer.rs @@ -39,10 +39,7 @@ impl TlsAnalyzerTool { /// TLS client hello. We test SSLv3 / old protocol support by attempting /// connection with the system's native-tls which typically negotiates the /// best available, then inspect what was negotiated. - async fn check_tls( - host: &str, - port: u16, - ) -> Result { + async fn check_tls(host: &str, port: u16) -> Result { let addr = format!("{host}:{port}"); let tcp = TcpStream::connect(&addr) @@ -62,11 +59,13 @@ impl TlsAnalyzerTool { .await .map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?; - let peer_cert = tls_stream.get_ref().peer_certificate() + let peer_cert = tls_stream + .get_ref() + .peer_certificate() .map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?; let mut tls_info = TlsInfo { - protocol_version: String::new(), + _protocol_version: String::new(), cert_subject: String::new(), cert_issuer: String::new(), cert_not_before: String::new(), @@ -78,7 +77,8 @@ impl TlsAnalyzerTool { }; if let Some(cert) = peer_cert { - let der = cert.to_der() + let der = cert + .to_der() .map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?; // native_tls doesn't give rich access, so we parse what we can @@ -93,7 +93,7 @@ impl TlsAnalyzerTool { /// Best-effort parse of DER-encoded X.509 certificate for dates and subject. /// This is a simplified parser; in production you would use a proper x509 crate. 
- fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo { + fn parse_cert_der(_der: &[u8], mut info: TlsInfo) -> TlsInfo { // We rely on the native_tls debug output stored in cert_subject // and just mark fields as "see certificate details" if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") { @@ -104,7 +104,7 @@ impl TlsAnalyzerTool { } struct TlsInfo { - protocol_version: String, + _protocol_version: String, cert_subject: String, cert_issuer: String, cert_not_before: String, @@ -152,111 +152,194 @@ impl PentestTool for TlsAnalyzerTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let url = input - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; + let url = input + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; - let host = Self::extract_host(url) - .unwrap_or_else(|| url.to_string()); + let host = Self::extract_host(url).unwrap_or_else(|| url.to_string()); - let port = input - .get("port") - .and_then(|v| v.as_u64()) - .map(|p| p as u16) - .unwrap_or_else(|| Self::extract_port(url)); + let port = input + .get("port") + .and_then(|v| v.as_u64()) + .map(|p| p as u16) + .unwrap_or_else(|| Self::extract_port(url)); - let target_id = context - .target - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "unknown".to_string()); + let target_id = context + .target + .id + .map(|oid| oid.to_hex()) + .unwrap_or_else(|| "unknown".to_string()); - let mut findings = Vec::new(); - let mut tls_data = json!({}); + let mut findings = Vec::new(); + let mut tls_data = json!({}); - // First check: does the server even support HTTPS? 
- let https_url = if url.starts_with("https://") { - url.to_string() - } else if url.starts_with("http://") { - url.replace("http://", "https://") - } else { - format!("https://{url}") - }; + // First check: does the server even support HTTPS? + let https_url = if url.starts_with("https://") { + url.to_string() + } else if url.starts_with("http://") { + url.replace("http://", "https://") + } else { + format!("https://{url}") + }; - // Check if HTTP redirects to HTTPS - let http_url = if url.starts_with("http://") { - url.to_string() - } else if url.starts_with("https://") { - url.replace("https://", "http://") - } else { - format!("http://{url}") - }; + // Check if HTTP redirects to HTTPS + let http_url = if url.starts_with("http://") { + url.to_string() + } else if url.starts_with("https://") { + url.replace("https://", "http://") + } else { + format!("http://{url}") + }; - match self.http.get(&http_url).send().await { - Ok(resp) => { - let final_url = resp.url().to_string(); - let redirects_to_https = final_url.starts_with("https://"); - tls_data["http_redirects_to_https"] = json!(redirects_to_https); + match self.http.get(&http_url).send().await { + Ok(resp) => { + let final_url = resp.url().to_string(); + let redirects_to_https = final_url.starts_with("https://"); + tls_data["http_redirects_to_https"] = json!(redirects_to_https); - if !redirects_to_https { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: http_url.clone(), - request_headers: None, - request_body: None, - response_status: resp.status().as_u16(), - response_headers: None, - response_snippet: Some(format!("Final URL: {final_url}")), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + if !redirects_to_https { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: http_url.clone(), + request_headers: None, + request_body: None, + response_status: resp.status().as_u16(), + response_headers: None, + response_snippet: 
Some(format!("Final URL: {final_url}")), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::TlsMisconfiguration, - format!("HTTP does not redirect to HTTPS for {host}"), - format!( - "HTTP requests to {host} are not redirected to HTTPS. \ + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::TlsMisconfiguration, + format!("HTTP does not redirect to HTTPS for {host}"), + format!( + "HTTP requests to {host} are not redirected to HTTPS. \ Users accessing the site via HTTP will have their traffic \ transmitted in cleartext." - ), - Severity::Medium, - http_url.clone(), - "GET".to_string(), - ); - finding.cwe = Some("CWE-319".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Configure the web server to redirect all HTTP requests to HTTPS \ + ), + Severity::Medium, + http_url.clone(), + "GET".to_string(), + ); + finding.cwe = Some("CWE-319".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Configure the web server to redirect all HTTP requests to HTTPS \ using a 301 redirect." 
- .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } + } + Err(_) => { + tls_data["http_check_error"] = json!("Could not connect via HTTP"); } } - Err(_) => { - tls_data["http_check_error"] = json!("Could not connect via HTTP"); - } - } - // Perform TLS analysis - match Self::check_tls(&host, port).await { - Ok(tls_info) => { - tls_data["host"] = json!(host); - tls_data["port"] = json!(port); - tls_data["cert_subject"] = json!(tls_info.cert_subject); - tls_data["cert_issuer"] = json!(tls_info.cert_issuer); - tls_data["cert_not_before"] = json!(tls_info.cert_not_before); - tls_data["cert_not_after"] = json!(tls_info.cert_not_after); - tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol); - tls_data["san_names"] = json!(tls_info.san_names); + // Perform TLS analysis + match Self::check_tls(&host, port).await { + Ok(tls_info) => { + tls_data["host"] = json!(host); + tls_data["port"] = json!(port); + tls_data["cert_subject"] = json!(tls_info.cert_subject); + tls_data["cert_issuer"] = json!(tls_info.cert_issuer); + tls_data["cert_not_before"] = json!(tls_info.cert_not_before); + tls_data["cert_not_after"] = json!(tls_info.cert_not_after); + tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol); + tls_data["san_names"] = json!(tls_info.san_names); - if tls_info.cert_expired { + if tls_info.cert_expired { + let evidence = DastEvidence { + request_method: "TLS".to_string(), + request_url: format!("{host}:{port}"), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some(format!( + "Certificate expired. Not After: {}", + tls_info.cert_not_after + )), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::TlsMisconfiguration, + format!("Expired TLS certificate for {host}"), + format!( + "The TLS certificate for {host} has expired. 
\ + Browsers will show security warnings to users." + ), + Severity::High, + format!("https://{host}:{port}"), + "TLS".to_string(), + ); + finding.cwe = Some("CWE-295".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Renew the TLS certificate. Consider using automated certificate \ + management with Let's Encrypt or a similar CA." + .to_string(), + ); + findings.push(finding); + warn!(host, "Expired TLS certificate"); + } + + if tls_info.cert_self_signed { + let evidence = DastEvidence { + request_method: "TLS".to_string(), + request_url: format!("{host}:{port}"), + request_headers: None, + request_body: None, + response_status: 0, + response_headers: None, + response_snippet: Some("Self-signed certificate detected".to_string()), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; + + let mut finding = DastFinding::new( + String::new(), + target_id.clone(), + DastVulnType::TlsMisconfiguration, + format!("Self-signed TLS certificate for {host}"), + format!( + "The TLS certificate for {host} is self-signed and not issued by a \ + trusted certificate authority. Browsers will show security warnings." + ), + Severity::Medium, + format!("https://{host}:{port}"), + "TLS".to_string(), + ); + finding.cwe = Some("CWE-295".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( + "Replace the self-signed certificate with one issued by a trusted \ + certificate authority." + .to_string(), + ); + findings.push(finding); + warn!(host, "Self-signed certificate"); + } + } + Err(e) => { + tls_data["tls_error"] = json!(e.to_string()); + + // TLS handshake failure itself is a finding let evidence = DastEvidence { request_method: "TLS".to_string(), request_url: format!("{host}:{port}"), @@ -264,10 +347,7 @@ impl PentestTool for TlsAnalyzerTool { request_body: None, response_status: 0, response_headers: None, - response_snippet: Some(format!( - "Certificate expired. 
Not After: {}", - tls_info.cert_not_after - )), + response_snippet: Some(format!("TLS error: {e}")), screenshot_path: None, payload: None, response_time_ms: None, @@ -277,10 +357,9 @@ impl PentestTool for TlsAnalyzerTool { String::new(), target_id.clone(), DastVulnType::TlsMisconfiguration, - format!("Expired TLS certificate for {host}"), + format!("TLS handshake failure for {host}"), format!( - "The TLS certificate for {host} has expired. \ - Browsers will show security warnings to users." + "Could not establish a TLS connection to {host}:{port}. Error: {e}" ), Severity::High, format!("https://{host}:{port}"), @@ -289,115 +368,37 @@ impl PentestTool for TlsAnalyzerTool { finding.cwe = Some("CWE-295".to_string()); finding.evidence = vec![evidence]; finding.remediation = Some( - "Renew the TLS certificate. Consider using automated certificate \ - management with Let's Encrypt or a similar CA." - .to_string(), - ); - findings.push(finding); - warn!(host, "Expired TLS certificate"); - } - - if tls_info.cert_self_signed { - let evidence = DastEvidence { - request_method: "TLS".to_string(), - request_url: format!("{host}:{port}"), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some("Self-signed certificate detected".to_string()), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::TlsMisconfiguration, - format!("Self-signed TLS certificate for {host}"), - format!( - "The TLS certificate for {host} is self-signed and not issued by a \ - trusted certificate authority. Browsers will show security warnings." 
- ), - Severity::Medium, - format!("https://{host}:{port}"), - "TLS".to_string(), - ); - finding.cwe = Some("CWE-295".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Replace the self-signed certificate with one issued by a trusted \ - certificate authority." - .to_string(), - ); - findings.push(finding); - warn!(host, "Self-signed certificate"); - } - } - Err(e) => { - tls_data["tls_error"] = json!(e.to_string()); - - // TLS handshake failure itself is a finding - let evidence = DastEvidence { - request_method: "TLS".to_string(), - request_url: format!("{host}:{port}"), - request_headers: None, - request_body: None, - response_status: 0, - response_headers: None, - response_snippet: Some(format!("TLS error: {e}")), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; - - let mut finding = DastFinding::new( - String::new(), - target_id.clone(), - DastVulnType::TlsMisconfiguration, - format!("TLS handshake failure for {host}"), - format!( - "Could not establish a TLS connection to {host}:{port}. Error: {e}" - ), - Severity::High, - format!("https://{host}:{port}"), - "TLS".to_string(), - ); - finding.cwe = Some("CWE-295".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( - "Ensure TLS is properly configured on the server. Check that the \ + "Ensure TLS is properly configured on the server. Check that the \ certificate is valid and the server supports modern TLS versions." 
- .to_string(), - ); - findings.push(finding); + .to_string(), + ); + findings.push(finding); + } } - } - // Check strict transport security via an HTTPS request - match self.http.get(&https_url).send().await { - Ok(resp) => { - let hsts = resp.headers().get("strict-transport-security"); - tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or(""))); + // Check strict transport security via an HTTPS request + match self.http.get(&https_url).send().await { + Ok(resp) => { + let hsts = resp.headers().get("strict-transport-security"); + tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or(""))); - if hsts.is_none() { - let evidence = DastEvidence { - request_method: "GET".to_string(), - request_url: https_url.clone(), - request_headers: None, - request_body: None, - response_status: resp.status().as_u16(), - response_headers: None, - response_snippet: Some( - "Strict-Transport-Security header not present".to_string(), - ), - screenshot_path: None, - payload: None, - response_time_ms: None, - }; + if hsts.is_none() { + let evidence = DastEvidence { + request_method: "GET".to_string(), + request_url: https_url.clone(), + request_headers: None, + request_body: None, + response_status: resp.status().as_u16(), + response_headers: None, + response_snippet: Some( + "Strict-Transport-Security header not present".to_string(), + ), + screenshot_path: None, + payload: None, + response_time_ms: None, + }; - let mut finding = DastFinding::new( + let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::TlsMisconfiguration, @@ -410,33 +411,33 @@ impl PentestTool for TlsAnalyzerTool { https_url.clone(), "GET".to_string(), ); - finding.cwe = Some("CWE-319".to_string()); - finding.evidence = vec![evidence]; - finding.remediation = Some( + finding.cwe = Some("CWE-319".to_string()); + finding.evidence = vec![evidence]; + finding.remediation = Some( "Add the Strict-Transport-Security header with an appropriate max-age. 
\ Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'." .to_string(), ); - findings.push(finding); + findings.push(finding); + } + } + Err(_) => { + tls_data["https_check_error"] = json!("Could not connect via HTTPS"); } } - Err(_) => { - tls_data["https_check_error"] = json!("Could not connect via HTTPS"); - } - } - let count = findings.len(); - info!(host = %host, findings = count, "TLS analysis complete"); + let count = findings.len(); + info!(host = %host, findings = count, "TLS analysis complete"); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} TLS configuration issues for {host}.") - } else { - format!("TLS configuration looks good for {host}.") - }, - findings, - data: tls_data, - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} TLS configuration issues for {host}.") + } else { + format!("TLS configuration looks good for {host}.") + }, + findings, + data: tls_data, + }) }) } } diff --git a/compliance-dast/src/tools/xss.rs b/compliance-dast/src/tools/xss.rs index 2470a7d..fb2b62b 100644 --- a/compliance-dast/src/tools/xss.rs +++ b/compliance-dast/src/tools/xss.rs @@ -1,5 +1,7 @@ use compliance_core::error::CoreError; -use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; +use compliance_core::traits::dast_agent::{ + DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter, +}; use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; use serde_json::json; @@ -7,30 +9,52 @@ use crate::agents::xss::XssAgent; /// PentestTool wrapper around the existing XssAgent. 
pub struct XssTool { - http: reqwest::Client, + _http: reqwest::Client, agent: XssAgent, } impl XssTool { pub fn new(http: reqwest::Client) -> Self { let agent = XssAgent::new(http.clone()); - Self { http, agent } + Self { _http: http, agent } } fn parse_endpoints(input: &serde_json::Value) -> Vec { let mut endpoints = Vec::new(); if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { for ep in arr { - let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string(); - let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string(); + let url = ep + .get("url") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + let method = ep + .get("method") + .and_then(|v| v.as_str()) + .unwrap_or("GET") + .to_string(); let mut parameters = Vec::new(); if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { for p in params { parameters.push(EndpointParameter { - name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(), - location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(), - param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from), - example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from), + name: p + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(), + location: p + .get("location") + .and_then(|v| v.as_str()) + .unwrap_or("query") + .to_string(), + param_type: p + .get("param_type") + .and_then(|v| v.as_str()) + .map(String::from), + example_value: p + .get("example_value") + .and_then(|v| v.as_str()) + .map(String::from), }); } } @@ -38,8 +62,14 @@ impl XssTool { url, method, parameters, - content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from), - requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false), + content_type: ep + .get("content_type") + .and_then(|v| v.as_str()) + .map(String::from), + requires_auth: ep + 
.get("requires_auth") + .and_then(|v| v.as_bool()) + .unwrap_or(false), }); } } @@ -47,6 +77,91 @@ impl XssTool { } } +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_endpoints_basic() { + let input = json!({ + "endpoints": [ + { + "url": "https://example.com/search", + "method": "GET", + "parameters": [ + { "name": "q", "location": "query" } + ] + } + ] + }); + let endpoints = XssTool::parse_endpoints(&input); + assert_eq!(endpoints.len(), 1); + assert_eq!(endpoints[0].url, "https://example.com/search"); + assert_eq!(endpoints[0].method, "GET"); + assert_eq!(endpoints[0].parameters.len(), 1); + assert_eq!(endpoints[0].parameters[0].name, "q"); + assert_eq!(endpoints[0].parameters[0].location, "query"); + } + + #[test] + fn parse_endpoints_empty() { + let input = json!({ "endpoints": [] }); + assert!(XssTool::parse_endpoints(&input).is_empty()); + } + + #[test] + fn parse_endpoints_missing_key() { + let input = json!({}); + assert!(XssTool::parse_endpoints(&input).is_empty()); + } + + #[test] + fn parse_endpoints_defaults() { + let input = json!({ + "endpoints": [ + { "url": "https://example.com/api", "parameters": [] } + ] + }); + let endpoints = XssTool::parse_endpoints(&input); + assert_eq!(endpoints[0].method, "GET"); // default + assert!(!endpoints[0].requires_auth); // default false + } + + #[test] + fn parse_endpoints_full_params() { + let input = json!({ + "endpoints": [{ + "url": "https://example.com", + "method": "POST", + "content_type": "application/json", + "requires_auth": true, + "parameters": [{ + "name": "body", + "location": "body", + "param_type": "string", + "example_value": "test" + }] + }] + }); + let endpoints = XssTool::parse_endpoints(&input); + assert_eq!(endpoints[0].method, "POST"); + assert_eq!( + endpoints[0].content_type.as_deref(), + Some("application/json") + ); + assert!(endpoints[0].requires_auth); + assert_eq!( + endpoints[0].parameters[0].param_type.as_deref(), + Some("string") + ); + 
assert_eq!( + endpoints[0].parameters[0].example_value.as_deref(), + Some("test") + ); + } +} + impl PentestTool for XssTool { fn name(&self) -> &str { "xss_scanner" @@ -100,35 +215,37 @@ impl PentestTool for XssTool { &'a self, input: serde_json::Value, context: &'a PentestToolContext, - ) -> std::pin::Pin> + Send + 'a>> { + ) -> std::pin::Pin< + Box> + Send + 'a>, + > { Box::pin(async move { - let endpoints = Self::parse_endpoints(&input); - if endpoints.is_empty() { - return Ok(PentestToolResult { - summary: "No endpoints provided to test.".to_string(), - findings: Vec::new(), - data: json!({}), - }); - } + let endpoints = Self::parse_endpoints(&input); + if endpoints.is_empty() { + return Ok(PentestToolResult { + summary: "No endpoints provided to test.".to_string(), + findings: Vec::new(), + data: json!({}), + }); + } - let dast_context = DastContext { - endpoints, - technologies: Vec::new(), - sast_hints: Vec::new(), - }; + let dast_context = DastContext { + endpoints, + technologies: Vec::new(), + sast_hints: Vec::new(), + }; - let findings = self.agent.run(&context.target, &dast_context).await?; - let count = findings.len(); + let findings = self.agent.run(&context.target, &dast_context).await?; + let count = findings.len(); - Ok(PentestToolResult { - summary: if count > 0 { - format!("Found {count} XSS vulnerabilities.") - } else { - "No XSS vulnerabilities detected.".to_string() - }, - findings, - data: json!({ "endpoints_tested": dast_context.endpoints.len() }), - }) + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} XSS vulnerabilities.") + } else { + "No XSS vulnerabilities detected.".to_string() + }, + findings, + data: json!({ "endpoints_tested": dast_context.endpoints.len() }), + }) }) } } diff --git a/compliance-dast/tests/agents.rs b/compliance-dast/tests/agents.rs new file mode 100644 index 0000000..5e2483b --- /dev/null +++ b/compliance-dast/tests/agents.rs @@ -0,0 +1,4 @@ +// Integration tests for DAST agents. 
+// +// Test individual security testing agents (XSS, SQLi, SSRF, etc.) +// against controlled test targets. diff --git a/compliance-graph/src/graph/chunking.rs b/compliance-graph/src/graph/chunking.rs index ebbc5a0..9ba2b39 100644 --- a/compliance-graph/src/graph/chunking.rs +++ b/compliance-graph/src/graph/chunking.rs @@ -94,3 +94,64 @@ fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> St format!("// {file_path} | {kind}") } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_context_header_with_parent() { + let result = + build_context_header("src/main.rs", "src/main.rs::MyStruct::my_method", "method"); + assert_eq!(result, "// src/main.rs | method in src/main.rs::MyStruct"); + } + + #[test] + fn test_build_context_header_top_level() { + let result = build_context_header("src/lib.rs", "main", "function"); + assert_eq!(result, "// src/lib.rs | function"); + } + + #[test] + fn test_build_context_header_single_parent() { + let result = build_context_header("src/lib.rs", "src/lib.rs::do_stuff", "function"); + assert_eq!(result, "// src/lib.rs | function in src/lib.rs"); + } + + #[test] + fn test_build_context_header_deep_nesting() { + let result = build_context_header( + "src/mod.rs", + "src/mod.rs::Outer::Inner::deep_fn", + "function", + ); + assert_eq!( + result, + "// src/mod.rs | function in src/mod.rs::Outer::Inner" + ); + } + + #[test] + fn test_build_context_header_empty_strings() { + let result = build_context_header("", "", "function"); + assert_eq!(result, "// | function"); + } + + #[test] + fn test_code_chunk_struct_fields() { + let chunk = CodeChunk { + qualified_name: "main".to_string(), + kind: "function".to_string(), + file_path: "src/main.rs".to_string(), + start_line: 1, + end_line: 10, + language: "rust".to_string(), + content: "fn main() {}".to_string(), + context_header: "// src/main.rs | function".to_string(), + token_estimate: 3, + }; + assert_eq!(chunk.start_line, 1); + 
assert_eq!(chunk.end_line, 10); + assert_eq!(chunk.language, "rust"); + } +} diff --git a/compliance-graph/src/graph/community.rs b/compliance-graph/src/graph/community.rs index 799d140..f7225bd 100644 --- a/compliance-graph/src/graph/community.rs +++ b/compliance-graph/src/graph/community.rs @@ -253,3 +253,215 @@ fn detect_communities_with_assignment(code_graph: &mut CodeGraph) -> u32 { next_id } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind}; + use petgraph::graph::DiGraph; + + fn make_node(qualified_name: &str, graph_index: u32) -> CodeNode { + CodeNode { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + qualified_name: qualified_name.to_string(), + name: qualified_name.to_string(), + kind: CodeNodeKind::Function, + file_path: "test.rs".to_string(), + start_line: 1, + end_line: 10, + language: "rust".to_string(), + community_id: None, + is_entry_point: false, + graph_index: Some(graph_index), + } + } + + fn make_empty_code_graph() -> CodeGraph { + CodeGraph { + graph: DiGraph::new(), + node_map: HashMap::new(), + nodes: Vec::new(), + edges: Vec::new(), + } + } + + #[test] + fn test_detect_communities_empty_graph() { + let cg = make_empty_code_graph(); + assert_eq!(detect_communities(&cg), 0); + } + + #[test] + fn test_detect_communities_single_node_no_edges() { + let mut graph = DiGraph::new(); + let idx = graph.add_node("a".to_string()); + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), idx); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![make_node("a", 0)], + edges: Vec::new(), + }; + // Single node with no edges => 1 community (itself) + assert_eq!(detect_communities(&cg), 1); + } + + #[test] + fn test_detect_communities_isolated_nodes() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + let mut 
node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + node_map.insert("c".to_string(), c); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)], + edges: Vec::new(), + }; + // 3 isolated nodes => 3 communities + assert_eq!(detect_communities(&cg), 3); + } + + #[test] + fn test_detect_communities_fully_connected() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + graph.add_edge(a, b, CodeEdgeKind::Calls); + graph.add_edge(b, c, CodeEdgeKind::Calls); + graph.add_edge(c, a, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + node_map.insert("c".to_string(), c); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)], + edges: Vec::new(), + }; + let num = detect_communities(&cg); + // Fully connected triangle: ideally 1 community, but only the 1..=3 bound is asserted + assert!(num >= 1); + assert!(num <= 3); + } + + #[test] + fn test_detect_communities_two_clusters() { + // Two separate triangles with no edge between them + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + let d = graph.add_node("d".to_string()); + let e = graph.add_node("e".to_string()); + let f = graph.add_node("f".to_string()); + + // Cluster 1: a-b-c fully connected + graph.add_edge(a, b, CodeEdgeKind::Calls); + graph.add_edge(b, a, CodeEdgeKind::Calls); + graph.add_edge(b, c, CodeEdgeKind::Calls); + graph.add_edge(c, b, CodeEdgeKind::Calls); + graph.add_edge(a, c, CodeEdgeKind::Calls); + graph.add_edge(c, a, CodeEdgeKind::Calls); + + // Cluster 2: d-e-f fully connected + graph.add_edge(d, e, CodeEdgeKind::Calls); +
graph.add_edge(e, d, CodeEdgeKind::Calls); + graph.add_edge(e, f, CodeEdgeKind::Calls); + graph.add_edge(f, e, CodeEdgeKind::Calls); + graph.add_edge(d, f, CodeEdgeKind::Calls); + graph.add_edge(f, d, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + node_map.insert("c".to_string(), c); + node_map.insert("d".to_string(), d); + node_map.insert("e".to_string(), e); + node_map.insert("f".to_string(), f); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![ + make_node("a", 0), + make_node("b", 1), + make_node("c", 2), + make_node("d", 3), + make_node("e", 4), + make_node("f", 5), + ], + edges: Vec::new(), + }; + let num = detect_communities(&cg); + // Two disconnected clusters should yield 2 communities + assert_eq!(num, 2); + } + + #[test] + fn test_apply_communities_assigns_ids() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + graph.add_edge(a, b, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + + let mut cg = CodeGraph { + graph, + node_map, + nodes: vec![make_node("a", 0), make_node("b", 1)], + edges: Vec::new(), + }; + let count = apply_communities(&mut cg); + assert!(count >= 1); + // All nodes should have a community_id assigned + for node in &cg.nodes { + assert!(node.community_id.is_some()); + } + } + + #[test] + fn test_apply_communities_empty() { + let mut cg = make_empty_code_graph(); + assert_eq!(apply_communities(&mut cg), 0); + } + + #[test] + fn test_apply_communities_isolated_nodes_get_own_community() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + + let mut cg = CodeGraph { + graph, + node_map, + 
nodes: vec![make_node("a", 0), make_node("b", 1)], + edges: Vec::new(), + }; + let count = apply_communities(&mut cg); + assert_eq!(count, 2); + // Each isolated node should be in a different community + let c0 = cg.nodes[0].community_id.unwrap(); + let c1 = cg.nodes[1].community_id.unwrap(); + assert_ne!(c0, c1); + } +} diff --git a/compliance-graph/src/graph/engine.rs b/compliance-graph/src/graph/engine.rs index 5ec71c3..431a54f 100644 --- a/compliance-graph/src/graph/engine.rs +++ b/compliance-graph/src/graph/engine.rs @@ -172,3 +172,185 @@ impl GraphEngine { ImpactAnalyzer::new(code_graph) } } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind}; + + fn make_node(qualified_name: &str) -> CodeNode { + CodeNode { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + qualified_name: qualified_name.to_string(), + name: qualified_name + .split("::") + .last() + .unwrap_or(qualified_name) + .to_string(), + kind: CodeNodeKind::Function, + file_path: "src/main.rs".to_string(), + start_line: 1, + end_line: 10, + language: "rust".to_string(), + community_id: None, + is_entry_point: false, + graph_index: None, + } + } + + fn build_test_node_map(names: &[&str]) -> HashMap { + let mut graph: DiGraph = DiGraph::new(); + let mut map = HashMap::new(); + for name in names { + let idx = graph.add_node(name.to_string()); + map.insert(name.to_string(), idx); + } + map + } + + #[test] + fn test_resolve_edge_target_direct_match() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/main.rs::foo", "src/main.rs::bar"]); + let result = engine.resolve_edge_target("src/main.rs::foo", &node_map); + assert!(result.is_some()); + assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]); + } + + #[test] + fn test_resolve_edge_target_short_name_match() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/main.rs::foo", 
"src/main.rs::bar"]); + let result = engine.resolve_edge_target("foo", &node_map); + assert!(result.is_some()); + assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]); + } + + #[test] + fn test_resolve_edge_target_method_match() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/main.rs::MyStruct::do_thing"]); + let result = engine.resolve_edge_target("do_thing", &node_map); + assert!(result.is_some()); + } + + #[test] + fn test_resolve_edge_target_self_method() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/main.rs::MyStruct::process"]); + let result = engine.resolve_edge_target("self.process", &node_map); + assert!(result.is_some()); + } + + #[test] + fn test_resolve_edge_target_no_match() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/main.rs::foo"]); + let result = engine.resolve_edge_target("nonexistent", &node_map); + assert!(result.is_none()); + } + + #[test] + fn test_resolve_edge_target_empty_map() { + let engine = GraphEngine::new(1000); + let node_map = HashMap::new(); + let result = engine.resolve_edge_target("anything", &node_map); + assert!(result.is_none()); + } + + #[test] + fn test_resolve_edge_target_dot_notation() { + let engine = GraphEngine::new(1000); + let node_map = build_test_node_map(&["src/app.js.handler"]); + let result = engine.resolve_edge_target("handler", &node_map); + assert!(result.is_some()); + } + + #[test] + fn test_build_petgraph_empty() { + let engine = GraphEngine::new(1000); + let output = ParseOutput::default(); + let code_graph = engine.build_petgraph(output).unwrap(); + assert_eq!(code_graph.nodes.len(), 0); + assert_eq!(code_graph.edges.len(), 0); + assert_eq!(code_graph.graph.node_count(), 0); + } + + #[test] + fn test_build_petgraph_nodes_get_graph_index() { + let engine = GraphEngine::new(1000); + let mut output = ParseOutput::default(); + output.nodes.push(make_node("src/main.rs::foo")); + 
output.nodes.push(make_node("src/main.rs::bar")); + + let code_graph = engine.build_petgraph(output).unwrap(); + assert_eq!(code_graph.nodes.len(), 2); + assert_eq!(code_graph.graph.node_count(), 2); + // All nodes should have a graph_index assigned + for node in &code_graph.nodes { + assert!(node.graph_index.is_some()); + } + } + + #[test] + fn test_build_petgraph_resolves_edges() { + let engine = GraphEngine::new(1000); + let mut output = ParseOutput::default(); + output.nodes.push(make_node("src/main.rs::foo")); + output.nodes.push(make_node("src/main.rs::bar")); + output.edges.push(CodeEdge { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + source: "src/main.rs::foo".to_string(), + target: "bar".to_string(), // short name, should resolve + kind: CodeEdgeKind::Calls, + file_path: "src/main.rs".to_string(), + line_number: Some(5), + }); + + let code_graph = engine.build_petgraph(output).unwrap(); + assert_eq!(code_graph.edges.len(), 1); + assert_eq!(code_graph.graph.edge_count(), 1); + // The resolved edge target should be the full qualified name + assert_eq!(code_graph.edges[0].target, "src/main.rs::bar"); + } + + #[test] + fn test_build_petgraph_skips_unresolved_edges() { + let engine = GraphEngine::new(1000); + let mut output = ParseOutput::default(); + output.nodes.push(make_node("src/main.rs::foo")); + output.edges.push(CodeEdge { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + source: "src/main.rs::foo".to_string(), + target: "external_crate::something".to_string(), + kind: CodeEdgeKind::Calls, + file_path: "src/main.rs".to_string(), + line_number: Some(5), + }); + + let code_graph = engine.build_petgraph(output).unwrap(); + assert_eq!(code_graph.edges.len(), 0); + assert_eq!(code_graph.graph.edge_count(), 0); + } + + #[test] + fn test_code_graph_node_map_consistency() { + let engine = GraphEngine::new(1000); + let mut output = ParseOutput::default(); + 
output.nodes.push(make_node("a::b")); + output.nodes.push(make_node("a::c")); + output.nodes.push(make_node("a::d")); + + let code_graph = engine.build_petgraph(output).unwrap(); + assert_eq!(code_graph.node_map.len(), 3); + assert!(code_graph.node_map.contains_key("a::b")); + assert!(code_graph.node_map.contains_key("a::c")); + assert!(code_graph.node_map.contains_key("a::d")); + } +} diff --git a/compliance-graph/src/graph/impact.rs b/compliance-graph/src/graph/impact.rs index bd14543..7ee2eba 100644 --- a/compliance-graph/src/graph/impact.rs +++ b/compliance-graph/src/graph/impact.rs @@ -222,3 +222,378 @@ impl<'a> ImpactAnalyzer<'a> { .find(|n| n.graph_index == Some(target_gi)) } } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind}; + use petgraph::graph::DiGraph; + use std::collections::HashMap; + + fn make_node( + qualified_name: &str, + file_path: &str, + start: u32, + end: u32, + graph_index: u32, + is_entry: bool, + kind: CodeNodeKind, + ) -> CodeNode { + CodeNode { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + qualified_name: qualified_name.to_string(), + name: qualified_name + .split("::") + .last() + .unwrap_or(qualified_name) + .to_string(), + kind, + file_path: file_path.to_string(), + start_line: start, + end_line: end, + language: "rust".to_string(), + community_id: None, + is_entry_point: is_entry, + graph_index: Some(graph_index), + } + } + + fn make_fn_node( + qualified_name: &str, + file_path: &str, + start: u32, + end: u32, + gi: u32, + ) -> CodeNode { + make_node( + qualified_name, + file_path, + start, + end, + gi, + false, + CodeNodeKind::Function, + ) + } + + /// Build a simple linear graph: A -> B -> C + fn build_linear_graph() -> CodeGraph { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + graph.add_edge(a, b, 
CodeEdgeKind::Calls); + graph.add_edge(b, c, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + node_map.insert("c".to_string(), c); + + CodeGraph { + graph, + node_map, + nodes: vec![ + make_fn_node("a", "src/main.rs", 1, 5, 0), + make_fn_node("b", "src/main.rs", 7, 12, 1), + make_fn_node("c", "src/main.rs", 14, 20, 2), + ], + edges: Vec::new(), + } + } + + #[test] + fn test_bfs_reachable_outgoing_linear() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let start = cg.node_map["a"]; + let reachable = analyzer.bfs_reachable(start, Direction::Outgoing); + // From a, we can reach b and c + assert_eq!(reachable.len(), 2); + assert!(reachable.contains(&cg.node_map["b"])); + assert!(reachable.contains(&cg.node_map["c"])); + } + + #[test] + fn test_bfs_reachable_incoming_linear() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let start = cg.node_map["c"]; + let reachable = analyzer.bfs_reachable(start, Direction::Incoming); + // c is reached by a and b + assert_eq!(reachable.len(), 2); + assert!(reachable.contains(&cg.node_map["a"])); + assert!(reachable.contains(&cg.node_map["b"])); + } + + #[test] + fn test_bfs_reachable_no_neighbors() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let cg = CodeGraph { + graph, + node_map: [("a".to_string(), a)].into_iter().collect(), + nodes: vec![make_fn_node("a", "src/main.rs", 1, 5, 0)], + edges: Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + let reachable = analyzer.bfs_reachable(a, Direction::Outgoing); + assert!(reachable.is_empty()); + } + + #[test] + fn test_bfs_reachable_cycle() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + graph.add_edge(a, b, CodeEdgeKind::Calls); + graph.add_edge(b, a, CodeEdgeKind::Calls); + + let cg = CodeGraph { + graph, + 
node_map: [("a".to_string(), a), ("b".to_string(), b)] + .into_iter() + .collect(), + nodes: vec![ + make_fn_node("a", "f.rs", 1, 5, 0), + make_fn_node("b", "f.rs", 6, 10, 1), + ], + edges: Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + let reachable = analyzer.bfs_reachable(a, Direction::Outgoing); + // Should handle cycle without infinite loop + assert_eq!(reachable.len(), 1); + assert!(reachable.contains(&b)); + } + + #[test] + fn test_find_path_exists() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let path = analyzer.find_path(cg.node_map["a"], cg.node_map["c"], 10); + assert!(path.is_some()); + let names = path.unwrap(); + assert_eq!(names, vec!["a", "b", "c"]); + } + + #[test] + fn test_find_path_direct() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let path = analyzer.find_path(cg.node_map["a"], cg.node_map["b"], 10); + assert!(path.is_some()); + let names = path.unwrap(); + assert_eq!(names, vec!["a", "b"]); + } + + #[test] + fn test_find_path_same_node() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let path = analyzer.find_path(cg.node_map["a"], cg.node_map["a"], 10); + assert!(path.is_some()); + let names = path.unwrap(); + assert_eq!(names, vec!["a"]); + } + + #[test] + fn test_find_path_no_connection() { + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + // No edge between a and b + let cg = CodeGraph { + graph, + node_map: [("a".to_string(), a), ("b".to_string(), b)] + .into_iter() + .collect(), + nodes: vec![ + make_fn_node("a", "f.rs", 1, 5, 0), + make_fn_node("b", "f.rs", 6, 10, 1), + ], + edges: Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + let path = analyzer.find_path(a, b, 10); + assert!(path.is_none()); + } + + #[test] + fn test_find_path_depth_limited() { + // Build a long chain: a -> b -> c -> d -> e + let mut graph = DiGraph::new(); + let 
a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + let d = graph.add_node("d".to_string()); + let e = graph.add_node("e".to_string()); + graph.add_edge(a, b, CodeEdgeKind::Calls); + graph.add_edge(b, c, CodeEdgeKind::Calls); + graph.add_edge(c, d, CodeEdgeKind::Calls); + graph.add_edge(d, e, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + node_map.insert("c".to_string(), c); + node_map.insert("d".to_string(), d); + node_map.insert("e".to_string(), e); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![ + make_fn_node("a", "f.rs", 1, 2, 0), + make_fn_node("b", "f.rs", 3, 4, 1), + make_fn_node("c", "f.rs", 5, 6, 2), + make_fn_node("d", "f.rs", 7, 8, 3), + make_fn_node("e", "f.rs", 9, 10, 4), + ], + edges: Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + // Depth 3 won't reach e from a (path length 5) + let path = analyzer.find_path(a, e, 3); + assert!(path.is_none()); + // Depth 5 should reach + let path = analyzer.find_path(a, e, 5); + assert!(path.is_some()); + } + + #[test] + fn test_find_node_at_location_exact_line() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + // Node "b" is at lines 7-12 + let result = analyzer.find_node_at_location("src/main.rs", Some(9)); + assert!(result.is_some()); + assert_eq!(result.unwrap(), cg.node_map["b"]); + } + + #[test] + fn test_find_node_at_location_narrowest_match() { + // Outer function 1-20, inner nested 5-10 + let mut graph = DiGraph::new(); + let outer = graph.add_node("outer".to_string()); + let inner = graph.add_node("inner".to_string()); + + let cg = CodeGraph { + graph, + node_map: [("outer".to_string(), outer), ("inner".to_string(), inner)] + .into_iter() + .collect(), + nodes: vec![ + make_fn_node("outer", "src/main.rs", 1, 20, 0), + make_fn_node("inner", "src/main.rs", 5, 10, 1), + ], + edges: 
Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + // Line 7 is inside both, but inner is narrower + let result = analyzer.find_node_at_location("src/main.rs", Some(7)); + assert!(result.is_some()); + assert_eq!(result.unwrap(), inner); + } + + #[test] + fn test_find_node_at_location_no_line_returns_file_node() { + let mut graph = DiGraph::new(); + let file_node = graph.add_node("src/main.rs".to_string()); + let fn_node = graph.add_node("src/main.rs::foo".to_string()); + + let cg = CodeGraph { + graph, + node_map: [ + ("src/main.rs".to_string(), file_node), + ("src/main.rs::foo".to_string(), fn_node), + ] + .into_iter() + .collect(), + nodes: vec![ + make_node( + "src/main.rs", + "src/main.rs", + 1, + 100, + 0, + false, + CodeNodeKind::File, + ), + make_fn_node("src/main.rs::foo", "src/main.rs", 5, 10, 1), + ], + edges: Vec::new(), + }; + let analyzer = ImpactAnalyzer::new(&cg); + let result = analyzer.find_node_at_location("src/main.rs", None); + assert!(result.is_some()); + assert_eq!(result.unwrap(), file_node); + } + + #[test] + fn test_find_node_at_location_wrong_file() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let result = analyzer.find_node_at_location("nonexistent.rs", Some(5)); + assert!(result.is_none()); + } + + #[test] + fn test_find_node_at_location_line_out_of_range() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let result = analyzer.find_node_at_location("src/main.rs", Some(999)); + assert!(result.is_none()); + } + + #[test] + fn test_analyze_basic() { + // A (entry) -> B -> C + let mut graph = DiGraph::new(); + let a = graph.add_node("a".to_string()); + let b = graph.add_node("b".to_string()); + let c = graph.add_node("c".to_string()); + graph.add_edge(a, b, CodeEdgeKind::Calls); + graph.add_edge(b, c, CodeEdgeKind::Calls); + + let mut node_map = HashMap::new(); + node_map.insert("a".to_string(), a); + node_map.insert("b".to_string(), b); + 
node_map.insert("c".to_string(), c); + + let cg = CodeGraph { + graph, + node_map, + nodes: vec![ + make_node("a", "src/main.rs", 1, 5, 0, true, CodeNodeKind::Function), + make_fn_node("b", "src/main.rs", 7, 12, 1), + make_fn_node("c", "src/main.rs", 14, 20, 2), + ], + edges: Vec::new(), + }; + + let analyzer = ImpactAnalyzer::new(&cg); + let result = analyzer.analyze("repo1", "finding1", "build1", "src/main.rs", Some(9)); + // B's blast radius: C is reachable forward + assert_eq!(result.blast_radius, 1); + // B has A as direct caller + assert_eq!(result.direct_callers, vec!["a"]); + // B calls C + assert_eq!(result.direct_callees, vec!["c"]); + // A is an entry point that reaches B + assert_eq!(result.affected_entry_points, vec!["a"]); + } + + #[test] + fn test_analyze_no_matching_node() { + let cg = build_linear_graph(); + let analyzer = ImpactAnalyzer::new(&cg); + let result = analyzer.analyze("repo1", "f1", "b1", "nonexistent.rs", Some(1)); + assert_eq!(result.blast_radius, 0); + assert!(result.affected_entry_points.is_empty()); + assert!(result.direct_callers.is_empty()); + assert!(result.direct_callees.is_empty()); + } +} diff --git a/compliance-graph/src/parsers/registry.rs b/compliance-graph/src/parsers/registry.rs index 0d42809..ed834ea 100644 --- a/compliance-graph/src/parsers/registry.rs +++ b/compliance-graph/src/parsers/registry.rs @@ -184,3 +184,115 @@ impl Default for ParserRegistry { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_supports_rust_extension() { + let registry = ParserRegistry::new(); + assert!(registry.supports_extension("rs")); + } + + #[test] + fn test_supports_python_extension() { + let registry = ParserRegistry::new(); + assert!(registry.supports_extension("py")); + } + + #[test] + fn test_supports_javascript_extension() { + let registry = ParserRegistry::new(); + assert!(registry.supports_extension("js")); + } + + #[test] + fn test_supports_typescript_extension() { + let 
registry = ParserRegistry::new(); + assert!(registry.supports_extension("ts")); + } + + #[test] + fn test_does_not_support_unknown_extension() { + let registry = ParserRegistry::new(); + assert!(!registry.supports_extension("go")); + assert!(!registry.supports_extension("java")); + assert!(!registry.supports_extension("cpp")); + assert!(!registry.supports_extension("")); + } + + #[test] + fn test_supported_extensions_includes_all() { + let registry = ParserRegistry::new(); + let exts = registry.supported_extensions(); + assert!(exts.contains(&"rs")); + assert!(exts.contains(&"py")); + assert!(exts.contains(&"js")); + assert!(exts.contains(&"ts")); + } + + #[test] + fn test_supported_extensions_count() { + let registry = ParserRegistry::new(); + let exts = registry.supported_extensions(); + // At least 4 extensions (rs, py, js, ts); could be more if tsx, jsx etc. + assert!(exts.len() >= 4); + } + + #[test] + fn test_parse_file_returns_none_for_unsupported() { + let registry = ParserRegistry::new(); + let path = PathBuf::from("test.go"); + let result = registry.parse_file(&path, "package main", "repo1", "build1"); + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + } + + #[test] + fn test_parse_file_rust_source() { + let registry = ParserRegistry::new(); + let path = PathBuf::from("src/main.rs"); + let source = "fn main() {\n println!(\"hello\");\n}\n"; + let result = registry.parse_file(&path, source, "repo1", "build1"); + assert!(result.is_ok()); + let output = result.unwrap(); + assert!(output.is_some()); + let output = output.unwrap(); + // Should have at least the file node and the main function node + assert!(output.nodes.len() >= 2); + } + + #[test] + fn test_parse_file_python_source() { + let registry = ParserRegistry::new(); + let path = PathBuf::from("app.py"); + let source = "def hello():\n print('hi')\n"; + let result = registry.parse_file(&path, source, "repo1", "build1"); + assert!(result.is_ok()); + let output = result.unwrap(); + 
assert!(output.is_some()); + let output = output.unwrap(); + assert!(!output.nodes.is_empty()); + } + + #[test] + fn test_parse_file_empty_source() { + let registry = ParserRegistry::new(); + let path = PathBuf::from("empty.rs"); + let result = registry.parse_file(&path, "", "repo1", "build1"); + assert!(result.is_ok()); + let output = result.unwrap(); + assert!(output.is_some()); + // At minimum the file node + let output = output.unwrap(); + assert!(!output.nodes.is_empty()); + } + + #[test] + fn test_default_trait() { + let registry = ParserRegistry::default(); + assert!(registry.supports_extension("rs")); + } +} diff --git a/compliance-graph/src/parsers/rust_parser.rs b/compliance-graph/src/parsers/rust_parser.rs index 391a7d4..9936c17 100644 --- a/compliance-graph/src/parsers/rust_parser.rs +++ b/compliance-graph/src/parsers/rust_parser.rs @@ -363,6 +363,214 @@ impl RustParser { } } +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::traits::graph_builder::LanguageParser; + use std::path::PathBuf; + + fn parse_rust(source: &str) -> ParseOutput { + let parser = RustParser::new(); + parser + .parse_file(&PathBuf::from("test.rs"), source, "repo1", "build1") + .unwrap() + } + + #[test] + fn test_extract_use_path_simple() { + let parser = RustParser::new(); + assert_eq!( + parser.extract_use_path("use std::collections::HashMap;"), + Some("std::collections::HashMap".to_string()) + ); + } + + #[test] + fn test_extract_use_path_nested() { + let parser = RustParser::new(); + assert_eq!( + parser.extract_use_path("use crate::models::graph::CodeNode;"), + Some("crate::models::graph::CodeNode".to_string()) + ); + } + + #[test] + fn test_extract_use_path_no_prefix() { + let parser = RustParser::new(); + assert_eq!(parser.extract_use_path("let x = 5;"), None); + } + + #[test] + fn test_extract_use_path_empty() { + let parser = RustParser::new(); + assert_eq!(parser.extract_use_path(""), None); + } + + #[test] + fn test_parse_function() { + let output = 
parse_rust("fn hello() {\n let x = 1;\n}\n"); + let fn_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Function) + .collect(); + assert_eq!(fn_nodes.len(), 1); + assert_eq!(fn_nodes[0].name, "hello"); + assert!(fn_nodes[0].qualified_name.contains("hello")); + } + + #[test] + fn test_parse_struct() { + let output = parse_rust("struct Foo {\n x: i32,\n}\n"); + let struct_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Struct) + .collect(); + assert_eq!(struct_nodes.len(), 1); + assert_eq!(struct_nodes[0].name, "Foo"); + } + + #[test] + fn test_parse_enum() { + let output = parse_rust("enum Color {\n Red,\n Blue,\n}\n"); + let enum_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Enum) + .collect(); + assert_eq!(enum_nodes.len(), 1); + assert_eq!(enum_nodes[0].name, "Color"); + } + + #[test] + fn test_parse_trait() { + let output = parse_rust("trait Drawable {\n fn draw(&self);\n}\n"); + let trait_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Trait) + .collect(); + assert_eq!(trait_nodes.len(), 1); + assert_eq!(trait_nodes[0].name, "Drawable"); + } + + #[test] + fn test_parse_file_node_always_created() { + let output = parse_rust(""); + let file_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::File) + .collect(); + assert_eq!(file_nodes.len(), 1); + assert_eq!(file_nodes[0].language, "rust"); + } + + #[test] + fn test_parse_multiple_functions() { + let source = "fn foo() {}\nfn bar() {}\nfn baz() {}\n"; + let output = parse_rust(source); + let fn_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Function) + .collect(); + assert_eq!(fn_nodes.len(), 3); + } + + #[test] + fn test_parse_main_is_entry_point() { + let output = parse_rust("fn main() {\n println!(\"hi\");\n}\n"); + let main_node = output.nodes.iter().find(|n| n.name == "main").unwrap(); + assert!(main_node.is_entry_point); 
+ } + + #[test] + fn test_parse_pub_fn_is_entry_point() { + let output = parse_rust("pub fn handler() {}\n"); + let node = output.nodes.iter().find(|n| n.name == "handler").unwrap(); + assert!(node.is_entry_point); + } + + #[test] + fn test_parse_private_fn_is_not_entry_point() { + let output = parse_rust("fn helper() {}\n"); + let node = output.nodes.iter().find(|n| n.name == "helper").unwrap(); + assert!(!node.is_entry_point); + } + + #[test] + fn test_parse_function_calls_create_edges() { + let source = "fn caller() {\n callee();\n}\nfn callee() {}\n"; + let output = parse_rust(source); + let call_edges: Vec<_> = output + .edges + .iter() + .filter(|e| e.kind == CodeEdgeKind::Calls) + .collect(); + assert!(!call_edges.is_empty()); + assert!(call_edges.iter().any(|e| e.target.contains("callee"))); + } + + #[test] + fn test_parse_use_declaration_creates_import_edge() { + let source = "use std::collections::HashMap;\nfn foo() {}\n"; + let output = parse_rust(source); + let import_edges: Vec<_> = output + .edges + .iter() + .filter(|e| e.kind == CodeEdgeKind::Imports) + .collect(); + assert!(!import_edges.is_empty()); + } + + #[test] + fn test_parse_impl_methods() { + let source = "struct Foo {}\nimpl Foo {\n fn do_thing(&self) {}\n}\n"; + let output = parse_rust(source); + let fn_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Function) + .collect(); + assert_eq!(fn_nodes.len(), 1); + assert_eq!(fn_nodes[0].name, "do_thing"); + // Method should be qualified under the impl type + assert!(fn_nodes[0].qualified_name.contains("Foo")); + } + + #[test] + fn test_parse_mod_item() { + let source = "mod inner {\n fn nested() {}\n}\n"; + let output = parse_rust(source); + let mod_nodes: Vec<_> = output + .nodes + .iter() + .filter(|n| n.kind == CodeNodeKind::Module) + .collect(); + assert_eq!(mod_nodes.len(), 1); + assert_eq!(mod_nodes[0].name, "inner"); + } + + #[test] + fn test_parse_line_numbers() { + let source = "fn first() {}\n\n\nfn 
second() {}\n"; + let output = parse_rust(source); + let first = output.nodes.iter().find(|n| n.name == "first").unwrap(); + let second = output.nodes.iter().find(|n| n.name == "second").unwrap(); + assert_eq!(first.start_line, 1); + assert!(second.start_line > first.start_line); + } + + #[test] + fn test_language_and_extensions() { + let parser = RustParser::new(); + assert_eq!(parser.language(), "rust"); + assert_eq!(parser.extensions(), &["rs"]); + } +} + impl LanguageParser for RustParser { fn language(&self) -> &str { "rust" diff --git a/compliance-graph/src/search/index.rs b/compliance-graph/src/search/index.rs index c5e6534..31157ff 100644 --- a/compliance-graph/src/search/index.rs +++ b/compliance-graph/src/search/index.rs @@ -128,3 +128,186 @@ impl SymbolIndex { Ok(results) } } + +#[cfg(test)] +mod tests { + use super::*; + use compliance_core::models::graph::CodeNodeKind; + + fn make_node( + qualified_name: &str, + name: &str, + kind: CodeNodeKind, + file_path: &str, + language: &str, + ) -> CodeNode { + CodeNode { + id: None, + repo_id: "test".to_string(), + graph_build_id: "build1".to_string(), + qualified_name: qualified_name.to_string(), + name: name.to_string(), + kind, + file_path: file_path.to_string(), + start_line: 1, + end_line: 10, + language: language.to_string(), + community_id: None, + is_entry_point: false, + graph_index: None, + } + } + + #[test] + fn test_new_creates_index() { + let index = SymbolIndex::new(); + assert!(index.is_ok()); + } + + #[test] + fn test_index_empty_nodes() { + let index = SymbolIndex::new().unwrap(); + let result = index.index_nodes(&[]); + assert!(result.is_ok()); + } + + #[test] + fn test_index_and_search_single_node() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![make_node( + "src/main.rs::main", + "main", + CodeNodeKind::Function, + "src/main.rs", + "rust", + )]; + index.index_nodes(&nodes).unwrap(); + + let results = index.search("main", 10).unwrap(); + assert!(!results.is_empty()); + 
assert_eq!(results[0].name, "main"); + assert_eq!(results[0].qualified_name, "src/main.rs::main"); + } + + #[test] + fn test_search_no_results() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![make_node( + "src/main.rs::foo", + "foo", + CodeNodeKind::Function, + "src/main.rs", + "rust", + )]; + index.index_nodes(&nodes).unwrap(); + + let results = index.search("zzzznonexistent", 10).unwrap(); + assert!(results.is_empty()); + } + + #[test] + fn test_search_multiple_nodes() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![ + make_node( + "a.rs::handle_request", + "handle_request", + CodeNodeKind::Function, + "a.rs", + "rust", + ), + make_node( + "b.rs::handle_response", + "handle_response", + CodeNodeKind::Function, + "b.rs", + "rust", + ), + make_node( + "c.rs::process_data", + "process_data", + CodeNodeKind::Function, + "c.rs", + "rust", + ), + ]; + index.index_nodes(&nodes).unwrap(); + + let results = index.search("handle", 10).unwrap(); + assert!(results.len() >= 2); + } + + #[test] + fn test_search_limit() { + let index = SymbolIndex::new().unwrap(); + let mut nodes = Vec::new(); + for i in 0..20 { + nodes.push(make_node( + &format!("mod::func_{i}"), + &format!("func_{i}"), + CodeNodeKind::Function, + "mod.rs", + "rust", + )); + } + index.index_nodes(&nodes).unwrap(); + + let results = index.search("func", 5).unwrap(); + assert!(results.len() <= 5); + } + + #[test] + fn test_search_result_has_score() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![make_node( + "src/lib.rs::compute", + "compute", + CodeNodeKind::Function, + "src/lib.rs", + "rust", + )]; + index.index_nodes(&nodes).unwrap(); + + let results = index.search("compute", 10).unwrap(); + assert!(!results.is_empty()); + assert!(results[0].score > 0.0); + } + + #[test] + fn test_search_result_fields() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![make_node( + "src/app.py::MyClass", + "MyClass", + CodeNodeKind::Class, + "src/app.py", + 
"python", + )]; + index.index_nodes(&nodes).unwrap(); + + let results = index.search("MyClass", 10).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].name, "MyClass"); + assert_eq!(results[0].kind, "class"); + assert_eq!(results[0].file_path, "src/app.py"); + assert_eq!(results[0].language, "python"); + } + + #[test] + fn test_search_empty_query() { + let index = SymbolIndex::new().unwrap(); + let nodes = vec![make_node( + "src/lib.rs::foo", + "foo", + CodeNodeKind::Function, + "src/lib.rs", + "rust", + )]; + index.index_nodes(&nodes).unwrap(); + + // Empty query may parse error or return empty - both acceptable + let result = index.search("", 10); + // Just verify it doesn't panic + let _ = result; + } +} diff --git a/compliance-graph/tests/parsers.rs b/compliance-graph/tests/parsers.rs new file mode 100644 index 0000000..1b7aa78 --- /dev/null +++ b/compliance-graph/tests/parsers.rs @@ -0,0 +1,4 @@ +// Tests for language parsers (Rust, TypeScript, JavaScript, Python). +// +// Test AST parsing, symbol extraction, and dependency graph construction +// using fixture source files. 
diff --git a/compliance-mcp/src/tools/dast.rs b/compliance-mcp/src/tools/dast.rs index bc5b5b9..e87d7ad 100644 --- a/compliance-mcp/src/tools/dast.rs +++ b/compliance-mcp/src/tools/dast.rs @@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 { limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cap_limit_default() { + assert_eq!(cap_limit(None), DEFAULT_LIMIT); + } + + #[test] + fn cap_limit_clamps_high() { + assert_eq!(cap_limit(Some(300)), MAX_LIMIT); + } + + #[test] + fn cap_limit_clamps_low() { + assert_eq!(cap_limit(Some(0)), 1); + } + + #[test] + fn list_dast_findings_params_deserialize() { + let json = serde_json::json!({ + "target_id": "t1", + "scan_run_id": "sr1", + "severity": "critical", + "exploitable": true, + "vuln_type": "sql_injection", + "limit": 10 + }); + let params: ListDastFindingsParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.target_id.as_deref(), Some("t1")); + assert_eq!(params.scan_run_id.as_deref(), Some("sr1")); + assert_eq!(params.severity.as_deref(), Some("critical")); + assert_eq!(params.exploitable, Some(true)); + assert_eq!(params.vuln_type.as_deref(), Some("sql_injection")); + assert_eq!(params.limit, Some(10)); + } + + #[test] + fn list_dast_findings_params_all_optional() { + let params: ListDastFindingsParams = serde_json::from_value(serde_json::json!({})).unwrap(); + assert!(params.target_id.is_none()); + assert!(params.scan_run_id.is_none()); + assert!(params.severity.is_none()); + assert!(params.exploitable.is_none()); + assert!(params.vuln_type.is_none()); + assert!(params.limit.is_none()); + } + + #[test] + fn dast_scan_summary_params_deserialize() { + let params: DastScanSummaryParams = + serde_json::from_value(serde_json::json!({ "target_id": "abc" })).unwrap(); + assert_eq!(params.target_id.as_deref(), Some("abc")); + + let params2: DastScanSummaryParams = serde_json::from_value(serde_json::json!({})).unwrap(); + 
assert!(params2.target_id.is_none()); + } +} + #[derive(Debug, Deserialize, JsonSchema)] pub struct ListDastFindingsParams { /// Filter by DAST target ID diff --git a/compliance-mcp/src/tools/findings.rs b/compliance-mcp/src/tools/findings.rs index 70489a5..366f71e 100644 --- a/compliance-mcp/src/tools/findings.rs +++ b/compliance-mcp/src/tools/findings.rs @@ -12,6 +12,89 @@ fn cap_limit(limit: Option<i64>) -> i64 { limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cap_limit_default() { + assert_eq!(cap_limit(None), DEFAULT_LIMIT); + } + + #[test] + fn cap_limit_normal_value() { + assert_eq!(cap_limit(Some(100)), 100); + } + + #[test] + fn cap_limit_exceeds_max() { + assert_eq!(cap_limit(Some(500)), MAX_LIMIT); + assert_eq!(cap_limit(Some(201)), MAX_LIMIT); + } + + #[test] + fn cap_limit_zero_clamped_to_one() { + assert_eq!(cap_limit(Some(0)), 1); + } + + #[test] + fn cap_limit_negative_clamped_to_one() { + assert_eq!(cap_limit(Some(-10)), 1); + } + + #[test] + fn cap_limit_boundary_values() { + assert_eq!(cap_limit(Some(1)), 1); + assert_eq!(cap_limit(Some(MAX_LIMIT)), MAX_LIMIT); + } + + #[test] + fn list_findings_params_deserialize() { + let json = serde_json::json!({ + "repo_id": "abc", + "severity": "high", + "status": "open", + "scan_type": "sast", + "limit": 25 + }); + let params: ListFindingsParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.repo_id.as_deref(), Some("abc")); + assert_eq!(params.severity.as_deref(), Some("high")); + assert_eq!(params.status.as_deref(), Some("open")); + assert_eq!(params.scan_type.as_deref(), Some("sast")); + assert_eq!(params.limit, Some(25)); + } + + #[test] + fn list_findings_params_all_optional() { + let json = serde_json::json!({}); + let params: ListFindingsParams = serde_json::from_value(json).unwrap(); + assert!(params.repo_id.is_none()); + assert!(params.severity.is_none()); + assert!(params.status.is_none()); + 
assert!(params.scan_type.is_none()); + assert!(params.limit.is_none()); + } + + #[test] + fn get_finding_params_deserialize() { + let json = serde_json::json!({ "id": "507f1f77bcf86cd799439011" }); + let params: GetFindingParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.id, "507f1f77bcf86cd799439011"); + } + + #[test] + fn findings_summary_params_deserialize() { + let json = serde_json::json!({ "repo_id": "r1" }); + let params: FindingsSummaryParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.repo_id.as_deref(), Some("r1")); + + let json2 = serde_json::json!({}); + let params2: FindingsSummaryParams = serde_json::from_value(json2).unwrap(); + assert!(params2.repo_id.is_none()); + } +} + #[derive(Debug, Deserialize, JsonSchema)] pub struct ListFindingsParams { /// Filter by repository ID diff --git a/compliance-mcp/src/tools/pentest.rs b/compliance-mcp/src/tools/pentest.rs index f6c7db9..740d0d4 100644 --- a/compliance-mcp/src/tools/pentest.rs +++ b/compliance-mcp/src/tools/pentest.rs @@ -12,6 +12,90 @@ fn cap_limit(limit: Option<i64>) -> i64 { limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cap_limit_default() { + assert_eq!(cap_limit(None), DEFAULT_LIMIT); + } + + #[test] + fn cap_limit_clamps_high() { + assert_eq!(cap_limit(Some(1000)), MAX_LIMIT); + } + + #[test] + fn cap_limit_clamps_low() { + assert_eq!(cap_limit(Some(-100)), 1); + assert_eq!(cap_limit(Some(0)), 1); + } + + #[test] + fn cap_limit_normal() { + assert_eq!(cap_limit(Some(42)), 42); + } + + #[test] + fn list_pentest_sessions_params_deserialize() { + let json = serde_json::json!({ + "target_id": "tgt", + "status": "running", + "strategy": "aggressive", + "limit": 5 + }); + let params: ListPentestSessionsParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.target_id.as_deref(), Some("tgt")); + assert_eq!(params.status.as_deref(), Some("running")); + assert_eq!(params.strategy.as_deref(), 
Some("aggressive")); + assert_eq!(params.limit, Some(5)); + } + + #[test] + fn list_pentest_sessions_params_all_optional() { + let params: ListPentestSessionsParams = + serde_json::from_value(serde_json::json!({})).unwrap(); + assert!(params.target_id.is_none()); + assert!(params.status.is_none()); + assert!(params.strategy.is_none()); + assert!(params.limit.is_none()); + } + + #[test] + fn get_pentest_session_params_deserialize() { + let params: GetPentestSessionParams = + serde_json::from_value(serde_json::json!({ "id": "abc123" })).unwrap(); + assert_eq!(params.id, "abc123"); + } + + #[test] + fn get_attack_chain_params_deserialize() { + let params: GetAttackChainParams = + serde_json::from_value(serde_json::json!({ "session_id": "s1", "limit": 20 })).unwrap(); + assert_eq!(params.session_id, "s1"); + assert_eq!(params.limit, Some(20)); + } + + #[test] + fn get_pentest_messages_params_deserialize() { + let params: GetPentestMessagesParams = + serde_json::from_value(serde_json::json!({ "session_id": "s2" })).unwrap(); + assert_eq!(params.session_id, "s2"); + assert!(params.limit.is_none()); + } + + #[test] + fn pentest_stats_params_deserialize() { + let params: PentestStatsParams = + serde_json::from_value(serde_json::json!({ "target_id": "t1" })).unwrap(); + assert_eq!(params.target_id.as_deref(), Some("t1")); + + let params2: PentestStatsParams = serde_json::from_value(serde_json::json!({})).unwrap(); + assert!(params2.target_id.is_none()); + } +} + // ── List Pentest Sessions ────────────────────────────────────── #[derive(Debug, Deserialize, JsonSchema)] diff --git a/compliance-mcp/src/tools/sbom.rs b/compliance-mcp/src/tools/sbom.rs index 78c3648..86806f0 100644 --- a/compliance-mcp/src/tools/sbom.rs +++ b/compliance-mcp/src/tools/sbom.rs @@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 { limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT) } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cap_limit_default() { + assert_eq!(cap_limit(None), 
DEFAULT_LIMIT); + } + + #[test] + fn cap_limit_clamps_high() { + assert_eq!(cap_limit(Some(999)), MAX_LIMIT); + } + + #[test] + fn cap_limit_clamps_low() { + assert_eq!(cap_limit(Some(0)), 1); + assert_eq!(cap_limit(Some(-5)), 1); + } + + #[test] + fn cap_limit_normal() { + assert_eq!(cap_limit(Some(75)), 75); + } + + #[test] + fn list_sbom_params_deserialize() { + let json = serde_json::json!({ + "repo_id": "repo1", + "has_vulns": true, + "package_manager": "npm", + "license": "MIT", + "limit": 30 + }); + let params: ListSbomPackagesParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.repo_id.as_deref(), Some("repo1")); + assert_eq!(params.has_vulns, Some(true)); + assert_eq!(params.package_manager.as_deref(), Some("npm")); + assert_eq!(params.license.as_deref(), Some("MIT")); + assert_eq!(params.limit, Some(30)); + } + + #[test] + fn list_sbom_params_all_optional() { + let params: ListSbomPackagesParams = serde_json::from_value(serde_json::json!({})).unwrap(); + assert!(params.repo_id.is_none()); + assert!(params.has_vulns.is_none()); + assert!(params.package_manager.is_none()); + assert!(params.license.is_none()); + assert!(params.limit.is_none()); + } + + #[test] + fn sbom_vuln_report_params_deserialize() { + let json = serde_json::json!({ "repo_id": "my-repo" }); + let params: SbomVulnReportParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.repo_id, "my-repo"); + } +} + #[derive(Debug, Deserialize, JsonSchema)] pub struct ListSbomPackagesParams { /// Filter by repository ID diff --git a/compliance-mcp/tests/tools.rs b/compliance-mcp/tests/tools.rs new file mode 100644 index 0000000..725e359 --- /dev/null +++ b/compliance-mcp/tests/tools.rs @@ -0,0 +1,4 @@ +// Tests for MCP tool implementations. +// +// Test tool request/response formats, parameter validation, +// and database query construction. 
diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..ac80dbe --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "compliance-fuzz" +version = "0.0.0" +publish = false +edition = "2021" + +[dependencies] +libfuzzer-sys = "0.4" +compliance-core = { path = "../compliance-core" } + +# Fuzz targets are defined below. Add new targets as [[bin]] entries. + +[[bin]] +name = "fuzz_finding_dedup" +path = "fuzz_targets/fuzz_finding_dedup.rs" +doc = false diff --git a/fuzz/fuzz_targets/fuzz_finding_dedup.rs b/fuzz/fuzz_targets/fuzz_finding_dedup.rs new file mode 100644 index 0000000..261e4a5 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_finding_dedup.rs @@ -0,0 +1,12 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +// Example fuzz target stub for finding deduplication logic. +// Replace with actual dedup function calls once ready. + +fuzz_target!(|data: &[u8]| { + if let Ok(s) = std::str::from_utf8(data) { + // TODO: Call dedup/fingerprint functions with fuzzed input + let _ = s.len(); + } +});
RequestStatusDetails