From a912ec9ad981074307061ec76b91f2d209cff0b8 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Tue, 17 Mar 2026 00:07:50 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20pentest=20feature=20improvements=20?= =?UTF-8?q?=E2=80=94=20streaming,=20pause/resume,=20encryption,=20browser?= =?UTF-8?q?=20tool,=20reports,=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - True SSE streaming via broadcast channels (DashMap per session) - Session pause/resume with watch channels + dashboard buttons - AES-256-GCM credential encryption at rest (PENTEST_ENCRYPTION_KEY) - Concurrency limiter (Semaphore, max 5 sessions, 429 on overflow) - Browser tool: headless Chrome CDP automation (navigate, click, fill, screenshot, evaluate) - Report code-level correlation: SAST findings, code graph, SBOM linked per DAST finding - Split html.rs (1919 LOC) into html/ module directory (8 files) - Wizard: target/repo dropdowns from existing data, SSH key display, close button on all steps - Auth: auto-register with optional registration URL (Playwright discovery), plus-addressing email, IMAP overrides - Attack chain: tool input/output in detail panel, running node pulse animation - Architecture docs with Mermaid diagrams + 8 screenshots Co-Authored-By: Claude Opus 4.6 (1M context) --- Cargo.lock | 130 +- Cargo.toml | 3 + compliance-agent/Cargo.toml | 3 + compliance-agent/src/agent.rs | 66 + .../api/handlers/pentest_handlers/export.rs | 99 + .../api/handlers/pentest_handlers/session.rs | 470 ++++- .../api/handlers/pentest_handlers/stream.rs | 72 +- compliance-agent/src/api/routes.rs | 12 + compliance-agent/src/pentest/crypto.rs | 117 ++ compliance-agent/src/pentest/mod.rs | 1 + compliance-agent/src/pentest/orchestrator.rs | 103 +- .../src/pentest/prompt_builder.rs | 101 +- compliance-agent/src/pentest/report/html.rs | 1851 ----------------- .../src/pentest/report/html/appendix.rs | 40 + .../src/pentest/report/html/attack_chain.rs | 175 ++ .../src/pentest/report/html/cover.rs | 56 + .../pentest/report/html/executive_summary.rs | 238 +++ .../src/pentest/report/html/findings.rs | 369 ++++ .../src/pentest/report/html/mod.rs | 473 +++++ .../src/pentest/report/html/scope.rs | 127 ++ .../src/pentest/report/html/styles.rs | 889 ++++++++ compliance-agent/src/pentest/report/mod.rs | 13 +- compliance-core/src/models/mod.rs | 6 +- compliance-core/src/models/pentest.rs | 107 + compliance-core/tests/models.rs | 81 + compliance-dashboard/assets/main.css | 309 ++- .../src/components/attack_chain/view.rs | 67 +- compliance-dashboard/src/components/mod.rs | 1 + .../src/components/pentest_wizard.rs | 925 ++++++++ .../src/infrastructure/pentest.rs | 101 + .../src/pages/pentest_dashboard.rs | 186 +- .../src/pages/pentest_session.rs | 53 +- compliance-dast/Cargo.toml | 5 + compliance-dast/src/tools/browser.rs | 488 +++++ compliance-dast/src/tools/mod.rs | 2 + docs/features/pentest-architecture.md | 273 +++ docs/features/pentest.md | 48 +- .../screenshots/pentest-attack-chain.png | Bin 0 -> 80769 bytes docs/public/screenshots/pentest-dashboard.png | Bin 0 -> 101036 bytes .../screenshots/pentest-session-findings.png | Bin 0 -> 112159 bytes .../pentest-wizard-step1-dropdown.png | Bin 0 -> 92004 bytes .../screenshots/pentest-wizard-step1.png | Bin 0 -> 86966 bytes .../screenshots/pentest-wizard-step2-auth.png | Bin 0 -> 103882 bytes .../pentest-wizard-step3-strategy.png | Bin 0 -> 95115 bytes .../pentest-wizard-step4-confirm.png | Bin 0 -> 107589 bytes 45 files changed, 5927 insertions(+), 2133 deletions(-) create mode 100644 compliance-agent/src/pentest/crypto.rs delete mode 100644 compliance-agent/src/pentest/report/html.rs create mode 100644 compliance-agent/src/pentest/report/html/appendix.rs create mode 100644 compliance-agent/src/pentest/report/html/attack_chain.rs create mode 100644 compliance-agent/src/pentest/report/html/cover.rs create mode 100644 compliance-agent/src/pentest/report/html/executive_summary.rs create mode 100644 compliance-agent/src/pentest/report/html/findings.rs create mode 100644 compliance-agent/src/pentest/report/html/mod.rs create mode 100644 compliance-agent/src/pentest/report/html/scope.rs create mode 100644 compliance-agent/src/pentest/report/html/styles.rs create mode 100644 compliance-dashboard/src/components/pentest_wizard.rs create mode 100644 compliance-dast/src/tools/browser.rs create mode 100644 docs/features/pentest-architecture.md create mode 100644 docs/public/screenshots/pentest-attack-chain.png create mode 100644 docs/public/screenshots/pentest-dashboard.png create mode 100644 docs/public/screenshots/pentest-session-findings.png create mode 100644 docs/public/screenshots/pentest-wizard-step1-dropdown.png create mode 100644 docs/public/screenshots/pentest-wizard-step1.png create mode 100644 docs/public/screenshots/pentest-wizard-step2-auth.png create mode 100644 docs/public/screenshots/pentest-wizard-step3-strategy.png create mode 100644 docs/public/screenshots/pentest-wizard-step4-confirm.png diff --git a/Cargo.lock b/Cargo.lock index 7a703ef..bc57a8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,16 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aes" version = "0.8.4" @@ -19,6 +29,20 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.12" @@ -635,13 +659,16 @@ dependencies = [ name = "compliance-agent" version = "0.1.0" dependencies = [ + "aes-gcm", "axum", "base64", "chrono", "compliance-core", "compliance-dast", "compliance-graph", + "dashmap", "dotenvy", + "futures-core", "futures-util", "git2", "hex", @@ -658,6 +685,8 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-cron-scheduler", + "tokio-stream", + "tokio-tungstenite 0.26.2", "tower-http", "tracing", "tracing-subscriber", @@ -730,11 +759,13 @@ dependencies = [ name = "compliance-dast" version = "0.1.0" dependencies = [ + "base64", "bollard", "bson", "chromiumoxide", "chrono", "compliance-core", + "futures-util", "mongodb", "native-tls", "reqwest", @@ -744,6 +775,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-native-tls", + "tokio-tungstenite 0.26.2", "tracing", "url", "uuid", @@ -1089,6 +1121,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1115,6 +1148,15 @@ dependencies = [ "syn", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "darling" version = "0.21.3" @@ -2314,6 +2356,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "git2" version = "0.20.4" @@ -2672,7 +2724,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -3513,7 +3565,7 @@ dependencies = [ "tokio-util", "typed-builder", "uuid", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -3747,6 +3799,12 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "openssl" version = "0.10.75" @@ -4052,6 +4110,18 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -4456,7 +4526,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.6", ] [[package]] @@ -5662,6 +5732,22 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084" +dependencies = [ + "futures-util", + "log", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tungstenite 0.26.2", + "webpki-roots 0.26.11", +] + [[package]] name = "tokio-tungstenite" version = "0.27.0" @@ -6060,6 +6146,25 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "rustls", + "rustls-pki-types", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "tungstenite" version = "0.27.0" @@ -6171,6 +6276,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" @@ -6448,6 +6563,15 @@ dependencies = [ "string_cache_codegen", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.6", +] + [[package]] name = "webpki-roots" version = "1.0.6" diff --git a/Cargo.toml b/Cargo.toml index 34096b8..e1b0315 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,3 +30,6 @@ uuid = { version = "1", features = ["v4", "serde"] } secrecy = { version = "0.10", features = ["serde"] } regex = "1" zip = { version = "2", features = ["aes-crypto", "deflate"] } +dashmap = "6" +tokio-stream = { version = "0.1", features = ["sync"] } +aes-gcm = "0.10" diff --git a/compliance-agent/Cargo.toml b/compliance-agent/Cargo.toml index 69522c1..b47dd2e 100644 --- a/compliance-agent/Cargo.toml +++ b/compliance-agent/Cargo.toml @@ -37,5 +37,8 @@ urlencoding = "2" futures-util = "0.3" jsonwebtoken = "9" zip = { workspace = true } +aes-gcm = { workspace = true } tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-core = "0.3" +dashmap = { workspace = true } +tokio-stream = { workspace = true } diff --git a/compliance-agent/src/agent.rs b/compliance-agent/src/agent.rs index 6c8fed7..9ad55ea 100644 --- a/compliance-agent/src/agent.rs +++ b/compliance-agent/src/agent.rs @@ -1,17 +1,30 @@ use std::sync::Arc; +use dashmap::DashMap; +use tokio::sync::{broadcast, watch, Semaphore}; + +use compliance_core::models::pentest::PentestEvent; use compliance_core::AgentConfig; use crate::database::Database; use crate::llm::LlmClient; use crate::pipeline::orchestrator::PipelineOrchestrator; +/// Default maximum concurrent pentest sessions. +const DEFAULT_MAX_CONCURRENT_SESSIONS: usize = 5; + #[derive(Clone)] pub struct ComplianceAgent { pub config: AgentConfig, pub db: Database, pub llm: Arc, pub http: reqwest::Client, + /// Per-session broadcast senders for SSE streaming. + pub session_streams: Arc>>, + /// Per-session pause controls (true = paused). + pub session_pause: Arc>>, + /// Semaphore limiting concurrent pentest sessions. + pub session_semaphore: Arc, } impl ComplianceAgent { @@ -27,6 +40,9 @@ impl ComplianceAgent { db, llm, http: reqwest::Client::new(), + session_streams: Arc::new(DashMap::new()), + session_pause: Arc::new(DashMap::new()), + session_semaphore: Arc::new(Semaphore::new(DEFAULT_MAX_CONCURRENT_SESSIONS)), } } @@ -74,4 +90,54 @@ impl ComplianceAgent { .run_pr_review(&repo, repo_id, pr_number, base_sha, head_sha) .await } + + // ── Session stream management ────────────────────────────────── + + /// Register a broadcast sender for a session. Returns the sender. + pub fn register_session_stream(&self, session_id: &str) -> broadcast::Sender { + let (tx, _) = broadcast::channel(256); + self.session_streams + .insert(session_id.to_string(), tx.clone()); + tx + } + + /// Subscribe to a session's broadcast stream. + pub fn subscribe_session(&self, session_id: &str) -> Option> { + self.session_streams + .get(session_id) + .map(|tx| tx.subscribe()) + } + + // ── Session pause/resume management ──────────────────────────── + + /// Register a pause control for a session. Returns the watch receiver. + pub fn register_pause_control(&self, session_id: &str) -> watch::Receiver { + let (tx, rx) = watch::channel(false); + self.session_pause.insert(session_id.to_string(), tx); + rx + } + + /// Pause a session. + pub fn pause_session(&self, session_id: &str) -> bool { + if let Some(tx) = self.session_pause.get(session_id) { + tx.send(true).is_ok() + } else { + false + } + } + + /// Resume a session. + pub fn resume_session(&self, session_id: &str) -> bool { + if let Some(tx) = self.session_pause.get(session_id) { + tx.send(false).is_ok() + } else { + false + } + } + + /// Clean up all per-session resources. + pub fn cleanup_session(&self, session_id: &str) { + self.session_streams.remove(session_id); + self.session_pause.remove(session_id); + } } diff --git a/compliance-agent/src/api/handlers/pentest_handlers/export.rs b/compliance-agent/src/api/handlers/pentest_handlers/export.rs index e4396c4..7377232 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/export.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/export.rs @@ -7,8 +7,12 @@ use axum::Json; use mongodb::bson::doc; use serde::Deserialize; +use futures_util::StreamExt; + use compliance_core::models::dast::DastFinding; +use compliance_core::models::finding::Finding; use compliance_core::models::pentest::*; +use compliance_core::models::sbom::SbomEntry; use crate::agent::ComplianceAgent; @@ -103,6 +107,97 @@ pub async fn export_session_report( Err(_) => Vec::new(), }; + // Fetch SAST findings, SBOM, and code context for the linked repository + let repo_id = session + .repo_id + .clone() + .or_else(|| target.as_ref().and_then(|t| t.repo_id.clone())); + + let (sast_findings, sbom_entries, code_context) = if let Some(ref rid) = repo_id { + let sast: Vec = match agent + .db + .findings() + .find(doc! { + "repo_id": rid, + "status": { "$in": ["open", "triaged"] }, + }) + .sort(doc! { "severity": -1 }) + .limit(100) + .await + { + Ok(mut cursor) => { + let mut results = Vec::new(); + while let Some(Ok(f)) = cursor.next().await { + results.push(f); + } + results + } + Err(_) => Vec::new(), + }; + + let sbom: Vec = match agent + .db + .sbom_entries() + .find(doc! { + "repo_id": rid, + "known_vulnerabilities": { "$exists": true, "$ne": [] }, + }) + .limit(50) + .await + { + Ok(mut cursor) => { + let mut results = Vec::new(); + while let Some(Ok(e)) = cursor.next().await { + results.push(e); + } + results + } + Err(_) => Vec::new(), + }; + + // Build code context from graph nodes + let code_ctx: Vec = match agent + .db + .graph_nodes() + .find(doc! { "repo_id": rid, "is_entry_point": true }) + .limit(50) + .await + { + Ok(mut cursor) => { + let mut nodes_vec = Vec::new(); + while let Some(Ok(n)) = cursor.next().await { + let linked_vulns: Vec = sast + .iter() + .filter(|f| f.file_path.as_deref() == Some(&n.file_path)) + .map(|f| { + format!( + "[{}] {}: {} (line {})", + f.severity, + f.scanner, + f.title, + f.line_number.unwrap_or(0) + ) + }) + .collect(); + nodes_vec.push(CodeContextHint { + endpoint_pattern: n.qualified_name.clone(), + handler_function: n.name.clone(), + file_path: n.file_path.clone(), + code_snippet: String::new(), + known_vulnerabilities: linked_vulns, + }); + } + nodes_vec + } + Err(_) => Vec::new(), + }; + + (sast, sbom, code_ctx) + } else { + (Vec::new(), Vec::new(), Vec::new()) + }; + + let config = session.config.clone(); let ctx = crate::pentest::report::ReportContext { session, target_name, @@ -115,6 +210,10 @@ pub async fn export_session_report( body.requester_name }, requester_email: body.requester_email, + config, + sast_findings, + sbom_entries, + code_context, }; let report = crate::pentest::generate_encrypted_report(&ctx, &body.password) diff --git a/compliance-agent/src/api/handlers/pentest_handlers/session.rs b/compliance-agent/src/api/handlers/pentest_handlers/session.rs index c768625..77268f4 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/session.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/session.rs @@ -17,10 +17,12 @@ type AgentExt = Extension>; #[derive(Deserialize)] pub struct CreateSessionRequest { - pub target_id: String, + pub target_id: Option, #[serde(default = "default_strategy")] pub strategy: String, pub message: Option, + /// Wizard configuration — if present, takes precedence over legacy fields + pub config: Option, } fn default_strategy() -> String { @@ -32,83 +34,310 @@ pub struct SendMessageRequest { pub message: String, } +#[derive(Deserialize)] +pub struct LookupRepoQuery { + pub url: String, +} + /// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator #[tracing::instrument(skip_all)] pub async fn create_session( Extension(agent): AgentExt, Json(req): Json, ) -> Result>, (StatusCode, String)> { - let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| { - ( - StatusCode::BAD_REQUEST, - "Invalid target_id format".to_string(), - ) - })?; - - // Look up the target - let target = agent - .db - .dast_targets() - .find_one(doc! { "_id": oid }) - .await - .map_err(|e| { + // Try to acquire a concurrency permit + let permit = agent + .session_semaphore + .clone() + .try_acquire_owned() + .map_err(|_| { ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Database error: {e}"), + StatusCode::TOO_MANY_REQUESTS, + "Maximum concurrent pentest sessions reached. Try again later.".to_string(), ) - })? - .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; + })?; - // Parse strategy - let strategy = match req.strategy.as_str() { + if let Some(ref config) = req.config { + // ── Wizard path ────────────────────────────────────────────── + if !config.disclaimer_accepted { + return Err(( + StatusCode::BAD_REQUEST, + "Disclaimer must be accepted".to_string(), + )); + } + + // Look up or auto-create DastTarget by app_url + let target = match agent + .db + .dast_targets() + .find_one(doc! { "base_url": &config.app_url }) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB error: {e}")))? + { + Some(t) => t, + None => { + use compliance_core::models::dast::{DastTarget, DastTargetType}; + let mut t = DastTarget::new( + config.app_url.clone(), + config.app_url.clone(), + DastTargetType::WebApp, + ); + if let Some(rl) = config.rate_limit { + t.rate_limit = rl; + } + t.allow_destructive = config.allow_destructive; + t.excluded_paths = config.scope_exclusions.clone(); + let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to create target: {e}"), + ) + })?; + t.id = res.inserted_id.as_object_id(); + t + } + }; + + let target_id = target.id.map(|oid| oid.to_hex()).unwrap_or_default(); + + // Parse strategy from config or request + let strat_str = config.strategy.as_deref().unwrap_or(req.strategy.as_str()); + let strategy = parse_strategy(strat_str); + + let mut session = PentestSession::new(target_id, strategy); + session.config = Some(config.clone()); + session.repo_id = target.repo_id.clone(); + + // Resolve repo_id from git_repo_url if provided + if let Some(ref git_url) = config.git_repo_url { + if let Ok(Some(repo)) = agent + .db + .repositories() + .find_one(doc! { "git_url": git_url }) + .await + { + session.repo_id = repo.id.map(|oid| oid.to_hex()); + } + } + + let insert_result = agent + .db + .pentest_sessions() + .insert_one(&session) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to create session: {e}"), + ) + })?; + session.id = insert_result.inserted_id.as_object_id(); + + let session_id_str = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); + + // Register broadcast stream and pause control + let event_tx = agent.register_session_stream(&session_id_str); + let pause_rx = agent.register_pause_control(&session_id_str); + + // Encrypt credentials before they linger in memory + let mut session_for_task = session.clone(); + if let Some(ref mut cfg) = session_for_task.config { + cfg.auth.username = cfg + .auth + .username + .as_ref() + .map(|u| crate::pentest::crypto::encrypt(u)); + cfg.auth.password = cfg + .auth + .password + .as_ref() + .map(|p| crate::pentest::crypto::encrypt(p)); + } + + // Persist encrypted credentials to DB + if session_for_task.config.is_some() { + if let Some(sid) = session.id { + let _ = agent + .db + .pentest_sessions() + .update_one( + doc! { "_id": sid }, + doc! { "$set": { + "config.auth.username": session_for_task.config.as_ref() + .and_then(|c| c.auth.username.as_deref()) + .map(|s| mongodb::bson::Bson::String(s.to_string())) + .unwrap_or(mongodb::bson::Bson::Null), + "config.auth.password": session_for_task.config.as_ref() + .and_then(|c| c.auth.password.as_deref()) + .map(|s| mongodb::bson::Bson::String(s.to_string())) + .unwrap_or(mongodb::bson::Bson::Null), + }}, + ) + .await; + } + } + + let initial_message = config + .initial_instructions + .clone() + .or(req.message.clone()) + .unwrap_or_else(|| { + format!( + "Begin a {} penetration test against {} ({}). \ + Identify vulnerabilities and provide evidence for each finding.", + session.strategy, target.name, target.base_url, + ) + }); + + let llm = agent.llm.clone(); + let db = agent.db.clone(); + let session_clone = session.clone(); + let target_clone = target.clone(); + let agent_ref = agent.clone(); + tokio::spawn(async move { + let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); + orchestrator + .run_session_guarded(&session_clone, &target_clone, &initial_message) + .await; + // Clean up session resources + agent_ref.cleanup_session(&session_id_str); + // Release concurrency permit + drop(permit); + }); + + // Redact credentials in response + let mut response_session = session; + if let Some(ref mut cfg) = response_session.config { + if cfg.auth.username.is_some() { + cfg.auth.username = Some("********".to_string()); + } + if cfg.auth.password.is_some() { + cfg.auth.password = Some("********".to_string()); + } + } + + Ok(Json(ApiResponse { + data: response_session, + total: None, + page: None, + })) + } else { + // ── Legacy path ────────────────────────────────────────────── + let target_id = req.target_id.clone().ok_or_else(|| { + ( + StatusCode::BAD_REQUEST, + "target_id is required for legacy creation".to_string(), + ) + })?; + + let oid = mongodb::bson::oid::ObjectId::parse_str(&target_id).map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "Invalid target_id format".to_string(), + ) + })?; + + let target = agent + .db + .dast_targets() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; + + let strategy = parse_strategy(&req.strategy); + + let mut session = PentestSession::new(target_id, strategy); + session.repo_id = target.repo_id.clone(); + + let insert_result = agent + .db + .pentest_sessions() + .insert_one(&session) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to create session: {e}"), + ) + })?; + session.id = insert_result.inserted_id.as_object_id(); + + let session_id_str = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); + + // Register broadcast stream and pause control + let event_tx = agent.register_session_stream(&session_id_str); + let pause_rx = agent.register_pause_control(&session_id_str); + + let initial_message = req.message.unwrap_or_else(|| { + format!( + "Begin a {} penetration test against {} ({}). \ + Identify vulnerabilities and provide evidence for each finding.", + session.strategy, target.name, target.base_url, + ) + }); + + let llm = agent.llm.clone(); + let db = agent.db.clone(); + let session_clone = session.clone(); + let target_clone = target.clone(); + let agent_ref = agent.clone(); + tokio::spawn(async move { + let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); + orchestrator + .run_session_guarded(&session_clone, &target_clone, &initial_message) + .await; + agent_ref.cleanup_session(&session_id_str); + drop(permit); + }); + + Ok(Json(ApiResponse { + data: session, + total: None, + page: None, + })) + } +} + +fn parse_strategy(s: &str) -> PentestStrategy { + match s { "quick" => PentestStrategy::Quick, "targeted" => PentestStrategy::Targeted, "aggressive" => PentestStrategy::Aggressive, "stealth" => PentestStrategy::Stealth, _ => PentestStrategy::Comprehensive, + } +} + +/// GET /api/v1/pentest/lookup-repo — Look up a tracked repository by git URL +#[tracing::instrument(skip_all)] +pub async fn lookup_repo( + Extension(agent): AgentExt, + Query(params): Query, +) -> Result>, StatusCode> { + let repo = agent + .db + .repositories() + .find_one(doc! { "git_url": ¶ms.url }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let data = match repo { + Some(r) => serde_json::json!({ + "name": r.name, + "default_branch": r.default_branch, + "last_scanned_commit": r.last_scanned_commit, + }), + None => serde_json::Value::Null, }; - // Create session - let mut session = PentestSession::new(req.target_id.clone(), strategy); - session.repo_id = target.repo_id.clone(); - - let insert_result = agent - .db - .pentest_sessions() - .insert_one(&session) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to create session: {e}"), - ) - })?; - - // Set the generated ID back on the session so the orchestrator has it - session.id = insert_result.inserted_id.as_object_id(); - - let initial_message = req.message.unwrap_or_else(|| { - format!( - "Begin a {} penetration test against {} ({}). \ - Identify vulnerabilities and provide evidence for each finding.", - session.strategy, target.name, target.base_url, - ) - }); - - // Spawn the orchestrator on a background task - let llm = agent.llm.clone(); - let db = agent.db.clone(); - let session_clone = session.clone(); - let target_clone = target.clone(); - tokio::spawn(async move { - let orchestrator = PentestOrchestrator::new(llm, db); - orchestrator - .run_session_guarded(&session_clone, &target_clone, &initial_message) - .await; - }); - Ok(Json(ApiResponse { - data: session, + data, total: None, page: None, })) @@ -158,7 +387,7 @@ pub async fn get_session( ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; - let session = agent + let mut session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) @@ -166,6 +395,16 @@ pub async fn get_session( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; + // Redact credentials in response + if let Some(ref mut cfg) = session.config { + if cfg.auth.username.is_some() { + cfg.auth.username = Some("********".to_string()); + } + if cfg.auth.password.is_some() { + cfg.auth.password = Some("********".to_string()); + } + } + Ok(Json(ApiResponse { data: session, total: None, @@ -241,8 +480,20 @@ pub async fn send_message( let llm = agent.llm.clone(); let db = agent.db.clone(); let message = req.message.clone(); + + // Use existing broadcast sender if available, otherwise create a new one + let event_tx = agent + .subscribe_session(&session_id) + .and_then(|_| { + agent + .session_streams + .get(&session_id) + .map(|entry| entry.value().clone()) + }) + .unwrap_or_else(|| agent.register_session_stream(&session_id)); + tokio::spawn(async move { - let orchestrator = PentestOrchestrator::new(llm, db); + let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None); orchestrator .run_session_guarded(&session, &target, &message) .await; @@ -277,10 +528,10 @@ pub async fn stop_session( })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; - if session.status != PentestStatus::Running { + if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, - format!("Session is {}, not running", session.status), + format!("Session is {}, not running or paused", session.status), )); } @@ -303,6 +554,9 @@ pub async fn stop_session( ) })?; + // Clean up session resources + agent.cleanup_session(&id); + let updated = agent .db .pentest_sessions() @@ -328,6 +582,92 @@ pub async fn stop_session( })) } +/// POST /api/v1/pentest/sessions/:id/pause — Pause a running pentest session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn pause_session( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>, (StatusCode, String)> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id) + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; + + if session.status != PentestStatus::Running { + return Err(( + StatusCode::BAD_REQUEST, + format!("Session is {}, not running", session.status), + )); + } + + if !agent.pause_session(&id) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to send pause signal".to_string(), + )); + } + + Ok(Json(ApiResponse { + data: serde_json::json!({ "status": "paused" }), + total: None, + page: None, + })) +} + +/// POST /api/v1/pentest/sessions/:id/resume — Resume a paused pentest session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn resume_session( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>, (StatusCode, String)> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id) + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; + + if session.status != PentestStatus::Paused { + return Err(( + StatusCode::BAD_REQUEST, + format!("Session is {}, not paused", session.status), + )); + } + + if !agent.resume_session(&id) { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to send resume signal".to_string(), + )); + } + + Ok(Json(ApiResponse { + data: serde_json::json!({ "status": "running" }), + total: None, + page: None, + })) +} + /// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_attack_chain( diff --git a/compliance-agent/src/api/handlers/pentest_handlers/stream.rs b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs index aa29cab..015c288 100644 --- a/compliance-agent/src/api/handlers/pentest_handlers/stream.rs +++ b/compliance-agent/src/api/handlers/pentest_handlers/stream.rs @@ -1,10 +1,14 @@ +use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use axum::extract::{Extension, Path}; use axum::http::StatusCode; -use axum::response::sse::{Event, Sse}; +use axum::response::sse::{Event, KeepAlive, Sse}; use futures_util::stream; use mongodb::bson::doc; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::StreamExt; use compliance_core::models::pentest::*; @@ -16,16 +20,13 @@ type AgentExt = Extension>; /// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events /// -/// Returns recent messages as SSE events (polling approach). -/// True real-time streaming with broadcast channels will be added in a future iteration. +/// Replays stored messages/nodes as initial burst, then subscribes to the +/// broadcast channel for live updates. Sends keepalive comments every 15s. #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn session_stream( Extension(agent): AgentExt, Path(id): Path, -) -> Result< - Sse>>, - StatusCode, -> { +) -> Result>>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; // Verify session exists @@ -37,6 +38,10 @@ pub async fn session_stream( .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; + // ── Initial burst: replay stored data ────────────────────────── + + let mut initial_events: Vec> = Vec::new(); + // Fetch recent messages for this session let messages: Vec = match agent .db @@ -63,9 +68,6 @@ pub async fn session_stream( Err(_) => Vec::new(), }; - // Build SSE events from stored data - let mut events: Vec> = Vec::new(); - for msg in &messages { let event_data = serde_json::json!({ "type": "message", @@ -74,7 +76,7 @@ pub async fn session_stream( "created_at": msg.created_at.to_rfc3339(), }); if let Ok(data) = serde_json::to_string(&event_data) { - events.push(Ok(Event::default().event("message").data(data))); + initial_events.push(Ok(Event::default().event("message").data(data))); } } @@ -87,11 +89,11 @@ pub async fn session_stream( "findings_produced": node.findings_produced, }); if let Ok(data) = serde_json::to_string(&event_data) { - events.push(Ok(Event::default().event("tool").data(data))); + initial_events.push(Ok(Event::default().event("tool").data(data))); } } - // Add session status event + // Add current session status event let session = agent .db .pentest_sessions() @@ -108,9 +110,49 @@ pub async fn session_stream( "tool_invocations": s.tool_invocations, }); if let Ok(data) = serde_json::to_string(&status_data) { - events.push(Ok(Event::default().event("status").data(data))); + initial_events.push(Ok(Event::default().event("status").data(data))); } } - Ok(Sse::new(stream::iter(events))) + // ── Live stream: subscribe to broadcast ──────────────────────── + + let live_stream = if let Some(rx) = agent.subscribe_session(&id) { + let broadcast = BroadcastStream::new(rx).filter_map(|result| match result { + Ok(event) => { + if let Ok(data) = serde_json::to_string(&event) { + let event_type = match &event { + PentestEvent::ToolStart { .. } => "tool_start", + PentestEvent::ToolComplete { .. } => "tool_complete", + PentestEvent::Finding { .. } => "finding", + PentestEvent::Message { .. } => "message", + PentestEvent::Complete { .. } => "complete", + PentestEvent::Error { .. } => "error", + PentestEvent::Thinking { .. } => "thinking", + PentestEvent::Paused => "paused", + PentestEvent::Resumed => "resumed", + }; + Some(Ok(Event::default().event(event_type).data(data))) + } else { + None + } + } + Err(_) => None, + }); + // Box to unify types + Box::pin(broadcast) + as std::pin::Pin> + Send>> + } else { + // No active broadcast — return empty stream + Box::pin(stream::empty()) + as std::pin::Pin> + Send>> + }; + + // Chain initial burst + live stream + let combined = stream::iter(initial_events).chain(live_stream); + + Ok(Sse::new(combined).keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keepalive"), + )) } diff --git a/compliance-agent/src/api/routes.rs b/compliance-agent/src/api/routes.rs index 0b72262..feeb51f 100644 --- a/compliance-agent/src/api/routes.rs +++ b/compliance-agent/src/api/routes.rs @@ -100,6 +100,10 @@ pub fn build_router() -> Router { get(handlers::chat::embedding_status), ) // Pentest API endpoints + .route( + "/api/v1/pentest/lookup-repo", + get(handlers::pentest::lookup_repo), + ) .route( "/api/v1/pentest/sessions", get(handlers::pentest::list_sessions).post(handlers::pentest::create_session), @@ -116,6 +120,14 @@ pub fn build_router() -> Router { "/api/v1/pentest/sessions/{id}/stop", post(handlers::pentest::stop_session), ) + .route( + "/api/v1/pentest/sessions/{id}/pause", + post(handlers::pentest::pause_session), + ) + .route( + "/api/v1/pentest/sessions/{id}/resume", + post(handlers::pentest::resume_session), + ) .route( "/api/v1/pentest/sessions/{id}/stream", get(handlers::pentest::session_stream), diff --git a/compliance-agent/src/pentest/crypto.rs b/compliance-agent/src/pentest/crypto.rs new file mode 100644 index 0000000..7364bcd --- /dev/null +++ b/compliance-agent/src/pentest/crypto.rs @@ -0,0 +1,117 @@ +use aes_gcm::aead::AeadCore; +use aes_gcm::{ + aead::{Aead, KeyInit, OsRng}, + Aes256Gcm, Nonce, +}; +use base64::Engine; + +/// Load the 32-byte encryption key from PENTEST_ENCRYPTION_KEY env var. +/// Returns None if not set or invalid length. +pub fn load_encryption_key() -> Option<[u8; 32]> { + let hex_key = std::env::var("PENTEST_ENCRYPTION_KEY").ok()?; + let bytes = hex::decode(hex_key).ok()?; + if bytes.len() != 32 { + return None; + } + let mut key = [0u8; 32]; + key.copy_from_slice(&bytes); + Some(key) +} + +/// Encrypt a plaintext string. Returns base64-encoded nonce+ciphertext. +/// Returns the original string if no encryption key is available. +pub fn encrypt(plaintext: &str) -> String { + let Some(key_bytes) = load_encryption_key() else { + return plaintext.to_string(); + }; + let Ok(cipher) = Aes256Gcm::new_from_slice(&key_bytes) else { + return plaintext.to_string(); + }; + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); + let Ok(ciphertext) = cipher.encrypt(&nonce, plaintext.as_bytes()) else { + return plaintext.to_string(); + }; + let mut combined = nonce.to_vec(); + combined.extend_from_slice(&ciphertext); + base64::engine::general_purpose::STANDARD.encode(&combined) +} + +/// Decrypt a base64-encoded nonce+ciphertext string. +/// Returns None if decryption fails. +pub fn decrypt(encrypted: &str) -> Option { + let key_bytes = load_encryption_key()?; + let cipher = Aes256Gcm::new_from_slice(&key_bytes).ok()?; + let combined = base64::engine::general_purpose::STANDARD + .decode(encrypted) + .ok()?; + if combined.len() < 12 { + return None; + } + let (nonce_bytes, ciphertext) = combined.split_at(12); + let nonce = Nonce::from_slice(nonce_bytes); + let plaintext = cipher.decrypt(nonce, ciphertext).ok()?; + String::from_utf8(plaintext).ok() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + // Guard to serialize tests that touch env vars + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + fn with_key(hex_key: &str, f: F) { + let _guard = ENV_LOCK.lock(); + unsafe { std::env::set_var("PENTEST_ENCRYPTION_KEY", hex_key) }; + f(); + unsafe { std::env::remove_var("PENTEST_ENCRYPTION_KEY") }; + } + + #[test] + fn round_trip() { + let key = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; + with_key(key, || { + let plaintext = "my_secret_password"; + let encrypted = encrypt(plaintext); + assert_ne!(encrypted, plaintext); + let decrypted = decrypt(&encrypted); + assert_eq!(decrypted, Some(plaintext.to_string())); + }); + } + + #[test] + fn wrong_key_fails() { + let _guard = ENV_LOCK.lock(); + let key1 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; + let key2 = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"; + let encrypted = { + unsafe { std::env::set_var("PENTEST_ENCRYPTION_KEY", key1) }; + let e = encrypt("secret"); + unsafe { std::env::remove_var("PENTEST_ENCRYPTION_KEY") }; + e + }; + unsafe { std::env::set_var("PENTEST_ENCRYPTION_KEY", key2) }; + assert!(decrypt(&encrypted).is_none()); + unsafe { std::env::remove_var("PENTEST_ENCRYPTION_KEY") }; + } + + #[test] + fn no_key_passthrough() { + let _guard = ENV_LOCK.lock(); + unsafe { std::env::remove_var("PENTEST_ENCRYPTION_KEY") }; + let result = encrypt("plain"); + assert_eq!(result, "plain"); + } + + #[test] + fn corrupted_ciphertext() { + let key = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"; + with_key(key, || { + assert!(decrypt("not-valid-base64!!!").is_none()); + // Valid base64 but wrong content + let garbage = base64::engine::general_purpose::STANDARD.encode(b"tooshort"); + assert!(decrypt(&garbage).is_none()); + }); + } +} diff --git a/compliance-agent/src/pentest/mod.rs b/compliance-agent/src/pentest/mod.rs index 6aa5bfb..2911ffd 100644 --- a/compliance-agent/src/pentest/mod.rs +++ b/compliance-agent/src/pentest/mod.rs @@ -1,4 +1,5 @@ mod context; +pub mod crypto; pub mod orchestrator; mod prompt_builder; pub mod report; diff --git a/compliance-agent/src/pentest/orchestrator.rs b/compliance-agent/src/pentest/orchestrator.rs index 2c88ce5..9bfe895 100644 --- a/compliance-agent/src/pentest/orchestrator.rs +++ b/compliance-agent/src/pentest/orchestrator.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use std::time::Duration; use mongodb::bson::doc; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, watch}; use compliance_core::models::dast::DastTarget; use compliance_core::models::pentest::*; @@ -22,29 +22,27 @@ pub struct PentestOrchestrator { pub(crate) llm: Arc, pub(crate) db: Database, pub(crate) event_tx: broadcast::Sender, + pub(crate) pause_rx: Option>, } impl PentestOrchestrator { - pub fn new(llm: Arc, db: Database) -> Self { - let (event_tx, _) = broadcast::channel(256); + /// Create a new orchestrator with an externally-provided broadcast sender + /// and an optional pause receiver. + pub fn new( + llm: Arc, + db: Database, + event_tx: broadcast::Sender, + pause_rx: Option>, + ) -> Self { Self { tool_registry: ToolRegistry::new(), llm, db, event_tx, + pause_rx, } } - #[allow(dead_code)] - pub fn subscribe(&self) -> broadcast::Receiver { - self.event_tx.subscribe() - } - - #[allow(dead_code)] - pub fn event_sender(&self) -> broadcast::Sender { - self.event_tx.clone() - } - /// Run a pentest session with timeout and automatic failure marking on errors. pub async fn run_session_guarded( &self, @@ -54,8 +52,18 @@ impl PentestOrchestrator { ) { let session_id = session.id; + // Use config-specified timeout or default + let timeout_duration = session + .config + .as_ref() + .and_then(|c| c.max_duration_minutes) + .map(|m| Duration::from_secs(m as u64 * 60)) + .unwrap_or(SESSION_TIMEOUT); + + let timeout_minutes = timeout_duration.as_secs() / 60; + match tokio::time::timeout( - SESSION_TIMEOUT, + timeout_duration, self.run_session(session, target, initial_message), ) .await @@ -72,12 +80,10 @@ impl PentestOrchestrator { }); } Err(_) => { - tracing::warn!(?session_id, "Pentest session timed out after 30 minutes"); - self.mark_session_failed(session_id, "Session timed out after 30 minutes") - .await; - let _ = self.event_tx.send(PentestEvent::Error { - message: "Session timed out after 30 minutes".to_string(), - }); + let msg = format!("Session timed out after {timeout_minutes} minutes"); + tracing::warn!(?session_id, "{msg}"); + self.mark_session_failed(session_id, &msg).await; + let _ = self.event_tx.send(PentestEvent::Error { message: msg }); } } } @@ -103,6 +109,45 @@ impl PentestOrchestrator { } } + /// Check if the session is paused; if so, update DB status and wait until resumed. + async fn wait_if_paused(&self, session: &PentestSession) { + let Some(ref pause_rx) = self.pause_rx else { + return; + }; + let mut rx = pause_rx.clone(); + + if !*rx.borrow() { + return; + } + + // We are paused — update DB status + if let Some(sid) = session.id { + let _ = self + .db + .pentest_sessions() + .update_one(doc! { "_id": sid }, doc! { "$set": { "status": "paused" }}) + .await; + } + let _ = self.event_tx.send(PentestEvent::Paused); + + // Wait until unpaused + while *rx.borrow() { + if rx.changed().await.is_err() { + break; + } + } + + // Resumed — update DB status back to running + if let Some(sid) = session.id { + let _ = self + .db + .pentest_sessions() + .update_one(doc! { "_id": sid }, doc! { "$set": { "status": "running" }}) + .await; + } + let _ = self.event_tx.send(PentestEvent::Resumed); + } + async fn run_session( &self, session: &PentestSession, @@ -175,6 +220,9 @@ impl PentestOrchestrator { let mut prev_node_ids: Vec = Vec::new(); for _iteration in 0..max_iterations { + // Check pause state at top of each iteration + self.wait_if_paused(session).await; + let response = self .llm .chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192)) @@ -417,6 +465,21 @@ impl PentestOrchestrator { .await; } + // If cleanup_test_user is requested, append a cleanup instruction + if session + .config + .as_ref() + .is_some_and(|c| c.auth.cleanup_test_user) + { + let cleanup_msg = PentestMessage::user( + session_id.clone(), + "Testing is complete. Now please clean up: navigate to the application and delete \ + the test user account that was created during this session. Confirm once done." + .to_string(), + ); + let _ = self.db.pentest_messages().insert_one(&cleanup_msg).await; + } + let _ = self.event_tx.send(PentestEvent::Complete { summary: format!( "Pentest complete. {} findings from {} tool invocations.", diff --git a/compliance-agent/src/pentest/prompt_builder.rs b/compliance-agent/src/pentest/prompt_builder.rs index ac8ee97..13da6da 100644 --- a/compliance-agent/src/pentest/prompt_builder.rs +++ b/compliance-agent/src/pentest/prompt_builder.rs @@ -5,6 +5,100 @@ use compliance_core::models::sbom::SbomEntry; use super::orchestrator::PentestOrchestrator; +/// Attempt to decrypt a field; if decryption fails, return the original value +/// (which may be plaintext from before encryption was enabled). +fn decrypt_field(value: &str) -> String { + super::crypto::decrypt(value).unwrap_or_else(|| value.to_string()) +} + +/// Build additional prompt sections from PentestConfig when present. +fn build_config_sections(config: &PentestConfig) -> String { + let mut sections = String::new(); + + // Authentication section + match config.auth.mode { + AuthMode::Manual => { + sections.push_str("\n## Authentication\n"); + sections.push_str("- **Mode**: Manual credentials\n"); + if let Some(ref u) = config.auth.username { + let decrypted = decrypt_field(u); + sections.push_str(&format!("- **Username**: {decrypted}\n")); + } + if let Some(ref p) = config.auth.password { + let decrypted = decrypt_field(p); + sections.push_str(&format!("- **Password**: {decrypted}\n")); + } + sections.push_str( + "Use these credentials to log in before testing authenticated endpoints.\n", + ); + } + AuthMode::AutoRegister => { + sections.push_str("\n## Authentication\n"); + sections.push_str("- **Mode**: Auto-register\n"); + if let Some(ref url) = config.auth.registration_url { + sections.push_str(&format!("- **Registration URL**: {url}\n")); + } else { + sections.push_str( + "- **Registration URL**: Not provided — use Playwright to discover the registration page.\n", + ); + } + if let Some(ref email) = config.auth.verification_email { + sections.push_str(&format!( + "- **Verification Email**: Use plus-addressing from `{email}` \ + (e.g. `{base}+{{session_id}}@{domain}`) for email verification. \ + The system will poll the IMAP mailbox for verification links.\n", + base = email.split('@').next().unwrap_or(email), + domain = email.split('@').nth(1).unwrap_or("example.com"), + )); + } + sections.push_str( + "Register a new test account using the registration page, then use it for testing.\n", + ); + } + AuthMode::None => {} + } + + // Custom headers + if !config.custom_headers.is_empty() { + sections.push_str("\n## Custom HTTP Headers\n"); + sections.push_str("Include these headers in all HTTP requests:\n"); + for (k, v) in &config.custom_headers { + sections.push_str(&format!("- `{k}: {v}`\n")); + } + } + + // Scope exclusions + if !config.scope_exclusions.is_empty() { + sections.push_str("\n## Scope Exclusions\n"); + sections.push_str("Do NOT test the following paths:\n"); + for path in &config.scope_exclusions { + sections.push_str(&format!("- `{path}`\n")); + } + } + + // Git context + if config.git_repo_url.is_some() || config.branch.is_some() || config.commit_hash.is_some() { + sections.push_str("\n## Git Context\n"); + if let Some(ref url) = config.git_repo_url { + sections.push_str(&format!("- **Repository**: {url}\n")); + } + if let Some(ref branch) = config.branch { + sections.push_str(&format!("- **Branch**: {branch}\n")); + } + if let Some(ref commit) = config.commit_hash { + sections.push_str(&format!("- **Commit**: {commit}\n")); + } + } + + // Environment + sections.push_str(&format!( + "\n## Environment\n- **Target environment**: {}\n", + config.environment + )); + + sections +} + /// Return strategy guidance text for the given strategy. fn strategy_guidance(strategy: &PentestStrategy) -> &'static str { match strategy { @@ -155,6 +249,11 @@ impl PentestOrchestrator { let sast_section = build_sast_section(sast_findings); let sbom_section = build_sbom_section(sbom_entries); let code_section = build_code_section(code_context); + let config_sections = session + .config + .as_ref() + .map(build_config_sections) + .unwrap_or_default(); format!( r#"You are an expert penetration tester conducting an authorized security assessment. @@ -178,7 +277,7 @@ impl PentestOrchestrator { ## Code Entry Points (Knowledge Graph) {code_section} - +{config_sections} ## Available Tools {tool_names} diff --git a/compliance-agent/src/pentest/report/html.rs b/compliance-agent/src/pentest/report/html.rs deleted file mode 100644 index 3882f76..0000000 --- a/compliance-agent/src/pentest/report/html.rs +++ /dev/null @@ -1,1851 +0,0 @@ -use compliance_core::models::dast::DastFinding; -use compliance_core::models::pentest::AttackChainNode; - -use super::ReportContext; - -#[allow(clippy::format_in_format_args)] -pub(super) fn build_html_report(ctx: &ReportContext) -> String { - let session = &ctx.session; - let session_id = session - .id - .map(|oid| oid.to_hex()) - .unwrap_or_else(|| "-".to_string()); - let date_str = session - .started_at - .format("%B %d, %Y at %H:%M UTC") - .to_string(); - let date_short = session.started_at.format("%B %d, %Y").to_string(); - let completed_str = session - .completed_at - .map(|d| d.format("%B %d, %Y at %H:%M UTC").to_string()) - .unwrap_or_else(|| "In Progress".to_string()); - - let critical = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == "critical") - .count(); - let high = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == "high") - .count(); - let medium = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == "medium") - .count(); - let low = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == "low") - .count(); - let info = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == "info") - .count(); - let exploitable = ctx.findings.iter().filter(|f| f.exploitable).count(); - let total = ctx.findings.len(); - - let overall_risk = if critical > 0 { - "CRITICAL" - } else if high > 0 { - "HIGH" - } else if medium > 0 { - "MEDIUM" - } else if low > 0 { - "LOW" - } else { - "INFORMATIONAL" - }; - - let risk_color = match overall_risk { - "CRITICAL" => "#991b1b", - "HIGH" => "#c2410c", - "MEDIUM" => "#a16207", - "LOW" => "#1d4ed8", - _ => "#4b5563", - }; - - // Risk score 0-100 - let risk_score: usize = - std::cmp::min(100, critical * 25 + high * 15 + medium * 8 + low * 3 + info); - - // Collect unique tool names used - let tool_names: Vec = { - let mut names: Vec = ctx - .attack_chain - .iter() - .map(|n| n.tool_name.clone()) - .collect(); - names.sort(); - names.dedup(); - names - }; - - // Severity distribution bar - let severity_bar = if total > 0 { - let crit_pct = (critical as f64 / total as f64 * 100.0) as usize; - let high_pct = (high as f64 / total as f64 * 100.0) as usize; - let med_pct = (medium as f64 / total as f64 * 100.0) as usize; - let low_pct = (low as f64 / total as f64 * 100.0) as usize; - let info_pct = 100_usize.saturating_sub(crit_pct + high_pct + med_pct + low_pct); - - let mut bar = String::from(r#"
"#); - if critical > 0 { - bar.push_str(&format!( - r#"
{}
"#, - std::cmp::max(crit_pct, 4), critical - )); - } - if high > 0 { - bar.push_str(&format!( - r#"
{}
"#, - std::cmp::max(high_pct, 4), - high - )); - } - if medium > 0 { - bar.push_str(&format!( - r#"
{}
"#, - std::cmp::max(med_pct, 4), medium - )); - } - if low > 0 { - bar.push_str(&format!( - r#"
{}
"#, - std::cmp::max(low_pct, 4), - low - )); - } - if info > 0 { - bar.push_str(&format!( - r#"
{}
"#, - std::cmp::max(info_pct, 4), - info - )); - } - bar.push_str("
"); - bar.push_str(r#"
"#); - if critical > 0 { - bar.push_str( - r#" Critical"#, - ); - } - if high > 0 { - bar.push_str(r#" High"#); - } - if medium > 0 { - bar.push_str( - r#" Medium"#, - ); - } - if low > 0 { - bar.push_str(r#" Low"#); - } - if info > 0 { - bar.push_str(r#" Info"#); - } - bar.push_str("
"); - bar - } else { - String::new() - }; - - // Build findings grouped by severity - let severity_order = ["critical", "high", "medium", "low", "info"]; - let severity_labels = ["Critical", "High", "Medium", "Low", "Informational"]; - let severity_colors = ["#991b1b", "#c2410c", "#a16207", "#1d4ed8", "#4b5563"]; - - let mut findings_html = String::new(); - let mut finding_num = 0usize; - - for (si, &sev_key) in severity_order.iter().enumerate() { - let sev_findings: Vec<&DastFinding> = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == sev_key) - .collect(); - if sev_findings.is_empty() { - continue; - } - - findings_html.push_str(&format!( - r#"

{label} ({count})

"#, - color = severity_colors[si], - label = severity_labels[si], - count = sev_findings.len(), - )); - - for f in sev_findings { - finding_num += 1; - let sev_color = severity_colors[si]; - let exploitable_badge = if f.exploitable { - r#"EXPLOITABLE"# - } else { - "" - }; - let cwe_cell = f - .cwe - .as_deref() - .map(|c| format!("CWE{}", html_escape(c))) - .unwrap_or_default(); - let param_row = f - .parameter - .as_deref() - .map(|p| { - format!( - "Parameter{}", - html_escape(p) - ) - }) - .unwrap_or_default(); - let remediation = f - .remediation - .as_deref() - .unwrap_or("Refer to industry best practices for this vulnerability class."); - - let evidence_html = if f.evidence.is_empty() { - String::new() - } else { - let mut eh = String::from( - r#"
Evidence
"#, - ); - for ev in &f.evidence { - let payload_info = ev - .payload - .as_deref() - .map(|p| format!("
Payload: {}", html_escape(p))) - .unwrap_or_default(); - eh.push_str(&format!( - "", - html_escape(&ev.request_method), - html_escape(&ev.request_url), - ev.response_status, - ev.response_snippet - .as_deref() - .map(html_escape) - .unwrap_or_default(), - payload_info, - )); - } - eh.push_str("
RequestStatusDetails
{} {}{}{}{}
"); - eh - }; - - let linked_sast = f - .linked_sast_finding_id - .as_deref() - .map(|id| { - format!( - r#"
Correlated SAST Finding: {id}
"# - ) - }) - .unwrap_or_default(); - - findings_html.push_str(&format!( - r#" -
-
- F-{num:03} - {title} - {exploitable_badge} -
- - - - {param_row} - {cwe_cell} -
Type{vuln_type}
Endpoint{method} {endpoint}
-
{description}
- {evidence_html} - {linked_sast} -
-
Recommendation
- {remediation} -
-
- "#, - num = finding_num, - title = html_escape(&f.title), - vuln_type = f.vuln_type, - method = f.method, - endpoint = html_escape(&f.endpoint), - description = html_escape(&f.description), - )); - } - } - - // Build attack chain — group by phase using BFS - let mut chain_html = String::new(); - if !ctx.attack_chain.is_empty() { - // Compute phases via BFS from root nodes - let mut phase_map: std::collections::HashMap = - std::collections::HashMap::new(); - let mut queue: std::collections::VecDeque = std::collections::VecDeque::new(); - - for node in &ctx.attack_chain { - if node.parent_node_ids.is_empty() { - let nid = node.node_id.clone(); - if !nid.is_empty() { - phase_map.insert(nid.clone(), 0); - queue.push_back(nid); - } - } - } - - while let Some(nid) = queue.pop_front() { - let parent_phase = phase_map.get(&nid).copied().unwrap_or(0); - for node in &ctx.attack_chain { - if node.parent_node_ids.contains(&nid) { - let child_id = node.node_id.clone(); - if !child_id.is_empty() && !phase_map.contains_key(&child_id) { - phase_map.insert(child_id.clone(), parent_phase + 1); - queue.push_back(child_id); - } - } - } - } - - // Assign phase 0 to any unassigned nodes - for node in &ctx.attack_chain { - let nid = node.node_id.clone(); - if !nid.is_empty() && !phase_map.contains_key(&nid) { - phase_map.insert(nid, 0); - } - } - - // Group nodes by phase - let max_phase = phase_map.values().copied().max().unwrap_or(0); - let phase_labels = [ - "Reconnaissance", - "Enumeration", - "Exploitation", - "Validation", - "Post-Exploitation", - ]; - - for phase_idx in 0..=max_phase { - let phase_nodes: Vec<&AttackChainNode> = ctx - .attack_chain - .iter() - .filter(|n| { - let nid = n.node_id.clone(); - phase_map.get(&nid).copied().unwrap_or(0) == phase_idx - }) - .collect(); - - if phase_nodes.is_empty() { - continue; - } - - let label = if phase_idx < phase_labels.len() { - phase_labels[phase_idx] - } else { - "Additional Testing" - }; - - chain_html.push_str(&format!( - r#"
-
- Phase {} - {} - {} step{} -
-
"#, - phase_idx + 1, - label, - phase_nodes.len(), - if phase_nodes.len() == 1 { "" } else { "s" }, - )); - - for (i, node) in phase_nodes.iter().enumerate() { - let status_label = format!("{:?}", node.status); - let status_class = match status_label.to_lowercase().as_str() { - "completed" => "step-completed", - "failed" => "step-failed", - _ => "step-running", - }; - let findings_badge = if !node.findings_produced.is_empty() { - format!( - r#"{} finding{}"#, - node.findings_produced.len(), - if node.findings_produced.len() == 1 { - "" - } else { - "s" - }, - ) - } else { - String::new() - }; - let risk_badge = node - .risk_score - .map(|r| { - let risk_class = if r >= 70 { - "risk-high" - } else if r >= 40 { - "risk-med" - } else { - "risk-low" - }; - format!(r#"Risk: {r}"#) - }) - .unwrap_or_default(); - - let reasoning_html = if node.llm_reasoning.is_empty() { - String::new() - } else { - format!( - r#"
{}
"#, - html_escape(&node.llm_reasoning) - ) - }; - - chain_html.push_str(&format!( - r#"
-
{num}
-
-
-
- {tool_name} - {status_label} - {findings_badge} - {risk_badge} -
- {reasoning_html} -
-
"#, - num = i + 1, - tool_name = html_escape(&node.tool_name), - )); - } - - chain_html.push_str("
"); - } - } - - // Tools methodology table - let tools_table: String = tool_names - .iter() - .enumerate() - .map(|(i, t)| { - let category = tool_category(t); - format!( - "{}{}{}", - i + 1, - html_escape(t), - category, - ) - }) - .collect::>() - .join("\n"); - - // Table of contents - let toc_findings_sub = if !ctx.findings.is_empty() { - let mut sub = String::new(); - let mut fnum = 0usize; - for &sev_key in severity_order.iter() { - let count = ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == sev_key) - .count(); - if count == 0 { - continue; - } - for f in ctx - .findings - .iter() - .filter(|f| f.severity.to_string() == sev_key) - { - fnum += 1; - sub.push_str(&format!( - r#"
F-{:03} — {}
"#, - fnum, - html_escape(&f.title), - )); - } - } - sub - } else { - String::new() - }; - - format!( - r##" - - - - -Penetration Test Report — {target_name} - - - - - - - - -
- - - - - - - - - - - - - - -
CONFIDENTIAL
- -
Penetration Test Report
-
{target_name}
- -
- -
- Report ID: {session_id}
- Date: {date_short}
- Target: {target_url}
- Prepared for: {requester_name} ({requester_email}) -
- - -
- - -
- -
-

Table of Contents

-
1Executive Summary
-
2Scope & Methodology
-
3Findings ({total_findings})
- {toc_findings_sub} -
4Attack Chain Timeline
-
5Appendix
-
- - -

1. Executive Summary

- -
-
-
-
-
-
{risk_score} / 100
-
-
-
Overall Risk: {overall_risk}
-
- Based on {total_findings} finding{findings_plural} identified across the target application. -
-
-
- -
-
-
{total_findings}
-
Total Findings
-
-
-
{critical_high}
-
Critical / High
-
-
-
{exploitable_count}
-
Exploitable
-
-
-
{tool_count}
-
Tools Used
-
-
- -

Severity Distribution

-{severity_bar} - -

- This report presents the results of an automated penetration test conducted against - {target_name} ({target_url}) using the Compliance Scanner - AI-powered testing engine. A total of {total_findings} vulnerabilities were - identified, of which {exploitable_count} were confirmed exploitable with - working proof-of-concept payloads. The assessment employed {tool_count} security tools - across {tool_invocations} invocations ({success_rate:.0}% success rate). -

- - -
-

2. Scope & Methodology

- -

- The assessment was performed using an AI-driven orchestrator that autonomously selects and - executes security testing tools based on the target's attack surface, technology stack, and - any available static analysis (SAST) findings and SBOM data. -

- -

Engagement Details

- - - - - - - - -
Target{target_name}
URL{target_url}
Strategy{strategy}
Status{status}
Started{date_str}
Completed{completed_str}
Tool Invocations{tool_invocations} ({tool_successes} successful, {success_rate:.1}% success rate)
- -

Tools Employed

- - - {tools_table} -
#ToolCategory
- - -
-

3. Findings

- -{findings_section} - - -
-

4. Attack Chain Timeline

- -

- The following sequence shows each tool invocation made by the AI orchestrator during the assessment, - grouped by phase. Each step includes the tool's name, execution status, and the AI's reasoning - for choosing that action. -

- -
- {chain_section} -
- - -
-

5. Appendix

- -

Severity Definitions

- - - - - - -
CriticalVulnerabilities that can be exploited remotely without authentication to execute arbitrary code, exfiltrate sensitive data, or fully compromise the system.
HighVulnerabilities that allow significant unauthorized access or data exposure, typically requiring minimal user interaction or privileges.
MediumVulnerabilities that may lead to limited data exposure or require specific conditions to exploit, but still represent meaningful risk.
LowMinor issues with limited direct impact. May contribute to broader attack chains or indicate defense-in-depth weaknesses.
InfoObservations and best-practice recommendations that do not represent direct security vulnerabilities.
- -

Disclaimer

-

- This report was generated by an automated AI-powered penetration testing engine. While the system - employs advanced techniques to identify vulnerabilities, no automated assessment can guarantee - complete coverage. The results should be reviewed by qualified security professionals and validated - in the context of the target application's threat model. Findings are point-in-time observations - and may change as the application evolves. -

- - - - -
- - -"##, - target_name = html_escape(&ctx.target_name), - target_url = html_escape(&ctx.target_url), - session_id = html_escape(&session_id), - date_str = date_str, - date_short = date_short, - completed_str = completed_str, - requester_name = html_escape(&ctx.requester_name), - requester_email = html_escape(&ctx.requester_email), - risk_color = risk_color, - risk_score = risk_score, - overall_risk = overall_risk, - total_findings = total, - findings_plural = if total == 1 { "" } else { "s" }, - critical_high = format!("{} / {}", critical, high), - exploitable_count = exploitable, - tool_count = tool_names.len(), - strategy = session.strategy, - status = session.status, - tool_invocations = session.tool_invocations, - tool_successes = session.tool_successes, - success_rate = session.success_rate(), - severity_bar = severity_bar, - tools_table = tools_table, - toc_findings_sub = toc_findings_sub, - findings_section = if ctx.findings.is_empty() { - "

No vulnerabilities were identified during this assessment.

".to_string() - } else { - findings_html - }, - chain_section = if ctx.attack_chain.is_empty() { - "

No attack chain steps recorded.

".to_string() - } else { - chain_html - }, - ) -} - -fn tool_category(tool_name: &str) -> &'static str { - let name = tool_name.to_lowercase(); - if name.contains("nmap") || name.contains("port") { - return "Network Reconnaissance"; - } - if name.contains("nikto") || name.contains("header") { - return "Web Server Analysis"; - } - if name.contains("zap") || name.contains("spider") || name.contains("crawl") { - return "Web Application Scanning"; - } - if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { - return "SQL Injection Testing"; - } - if name.contains("xss") || name.contains("cross-site") { - return "Cross-Site Scripting Testing"; - } - if name.contains("dir") - || name.contains("brute") - || name.contains("fuzz") - || name.contains("gobuster") - { - return "Directory Enumeration"; - } - if name.contains("ssl") || name.contains("tls") || name.contains("cert") { - return "SSL/TLS Analysis"; - } - if name.contains("api") || name.contains("endpoint") { - return "API Security Testing"; - } - if name.contains("auth") || name.contains("login") || name.contains("credential") { - return "Authentication Testing"; - } - if name.contains("cors") { - return "CORS Testing"; - } - if name.contains("csrf") { - return "CSRF Testing"; - } - if name.contains("nuclei") || name.contains("template") { - return "Vulnerability Scanning"; - } - if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { - return "Technology Fingerprinting"; - } - "Security Testing" -} - -fn html_escape(s: &str) -> String { - s.replace('&', "&") - .replace('<', "<") - .replace('>', ">") - .replace('"', """) -} - -#[cfg(test)] -mod tests { - use super::*; - use compliance_core::models::dast::{DastFinding, DastVulnType}; - use compliance_core::models::finding::Severity; - use compliance_core::models::pentest::{ - AttackChainNode, AttackNodeStatus, PentestSession, PentestStrategy, - }; - - // ── html_escape ────────────────────────────────────────────────── - - #[test] - fn html_escape_handles_ampersand() { - assert_eq!(html_escape("a & b"), "a & b"); - } - - #[test] - fn html_escape_handles_angle_brackets() { - assert_eq!(html_escape("