From 71d8741e10710e716de2b687b1a47b2af6d4436b Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Wed, 11 Mar 2026 19:23:21 +0100 Subject: [PATCH] feat: AI-driven automated penetration testing system Add a complete AI pentest system where Claude autonomously drives security testing via tool-calling. The LLM selects from 16 tools, chains results, and builds an attack chain DAG. Core: - PentestTool trait (dyn-compatible) with PentestToolContext/Result - PentestSession, AttackChainNode, PentestMessage, PentestEvent models - 10 new DastVulnType variants (DNS, DMARC, TLS, cookies, CSP, CORS, etc.) - LLM client chat_with_tools() for OpenAI-compatible tool calling Tools (16 total): - 5 agent wrappers: SQL injection, XSS, auth bypass, SSRF, API fuzzer - 11 new infra tools: DNS checker, DMARC checker, TLS analyzer, security headers, cookie analyzer, CSP analyzer, rate limit tester, console log detector, CORS checker, OpenAPI parser, recon - ToolRegistry for tool lookup and LLM definition generation Orchestrator: - PentestOrchestrator with iterative tool-calling loop (max 50 rounds) - Attack chain node recording per tool invocation - SSE event broadcasting for real-time progress - Strategy-aware system prompts (quick/comprehensive/targeted/aggressive/stealth) API (9 endpoints): - POST/GET /pentest/sessions, GET /pentest/sessions/:id - POST /pentest/sessions/:id/chat, GET /pentest/sessions/:id/stream - GET /pentest/sessions/:id/attack-chain, messages, findings - GET /pentest/stats Dashboard: - Pentest dashboard with stat cards, severity distribution, session list - Chat-based session page with split layout (chat + findings/attack chain) - Inline tool execution indicators, auto-polling, new session modal - Sidebar navigation item Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 89 ++- compliance-agent/src/api/handlers/mod.rs | 3 +- compliance-agent/src/api/handlers/pentest.rs | 564 ++++++++++++++++++ compliance-agent/src/api/routes.rs | 30 + compliance-agent/src/database.rs | 45 ++ compliance-agent/src/llm/client.rs | 297 ++++++--- compliance-agent/src/main.rs | 1 + compliance-agent/src/pentest/mod.rs | 3 + compliance-agent/src/pentest/orchestrator.rs | 393 ++++++++++++ compliance-core/src/models/dast.rs | 23 + compliance-core/src/models/mod.rs | 6 + compliance-core/src/models/pentest.rs | 294 +++++++++ compliance-core/src/traits/mod.rs | 2 + compliance-core/src/traits/pentest_tool.rs | 63 ++ compliance-dashboard/src/app.rs | 4 + .../src/components/sidebar.rs | 6 + .../src/infrastructure/mod.rs | 1 + .../src/infrastructure/pentest.rs | 190 ++++++ compliance-dashboard/src/pages/mod.rs | 4 + .../src/pages/pentest_dashboard.rs | 396 ++++++++++++ .../src/pages/pentest_session.rs | 445 ++++++++++++++ compliance-dast/Cargo.toml | 4 + compliance-dast/src/lib.rs | 2 + compliance-dast/src/tools/api_fuzzer.rs | 146 +++++ compliance-dast/src/tools/auth_bypass.rs | 130 ++++ .../src/tools/console_log_detector.rs | 326 ++++++++++ compliance-dast/src/tools/cookie_analyzer.rs | 401 +++++++++++++ compliance-dast/src/tools/cors_checker.rs | 410 +++++++++++++ compliance-dast/src/tools/csp_analyzer.rs | 447 ++++++++++++++ compliance-dast/src/tools/dmarc_checker.rs | 401 +++++++++++++ compliance-dast/src/tools/dns_checker.rs | 389 ++++++++++++ compliance-dast/src/tools/mod.rs | 141 +++++ compliance-dast/src/tools/openapi_parser.rs | 422 +++++++++++++ .../src/tools/rate_limit_tester.rs | 285 +++++++++ compliance-dast/src/tools/recon.rs | 125 ++++ compliance-dast/src/tools/security_headers.rs | 300 ++++++++++ compliance-dast/src/tools/sql_injection.rs | 138 +++++ compliance-dast/src/tools/ssrf.rs | 134 +++++ compliance-dast/src/tools/tls_analyzer.rs | 442 ++++++++++++++ compliance-dast/src/tools/xss.rs | 134 +++++ 40 files changed, 7546 insertions(+), 90 deletions(-) create mode 100644 compliance-agent/src/api/handlers/pentest.rs create mode 100644 compliance-agent/src/pentest/mod.rs create mode 100644 compliance-agent/src/pentest/orchestrator.rs create mode 100644 compliance-core/src/models/pentest.rs create mode 100644 compliance-core/src/traits/pentest_tool.rs create mode 100644 compliance-dashboard/src/infrastructure/pentest.rs create mode 100644 compliance-dashboard/src/pages/pentest_dashboard.rs create mode 100644 compliance-dashboard/src/pages/pentest_session.rs create mode 100644 compliance-dast/src/tools/api_fuzzer.rs create mode 100644 compliance-dast/src/tools/auth_bypass.rs create mode 100644 compliance-dast/src/tools/console_log_detector.rs create mode 100644 compliance-dast/src/tools/cookie_analyzer.rs create mode 100644 compliance-dast/src/tools/cors_checker.rs create mode 100644 compliance-dast/src/tools/csp_analyzer.rs create mode 100644 compliance-dast/src/tools/dmarc_checker.rs create mode 100644 compliance-dast/src/tools/dns_checker.rs create mode 100644 compliance-dast/src/tools/mod.rs create mode 100644 compliance-dast/src/tools/openapi_parser.rs create mode 100644 compliance-dast/src/tools/rate_limit_tester.rs create mode 100644 compliance-dast/src/tools/recon.rs create mode 100644 compliance-dast/src/tools/security_headers.rs create mode 100644 compliance-dast/src/tools/sql_injection.rs create mode 100644 compliance-dast/src/tools/ssrf.rs create mode 100644 compliance-dast/src/tools/tls_analyzer.rs create mode 100644 compliance-dast/src/tools/xss.rs diff --git a/Cargo.lock b/Cargo.lock index d3e3e28..4dc02f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -680,12 +680,14 @@ dependencies = [ "chrono", "compliance-core", "mongodb", + "native-tls", "reqwest", "scraper", "serde", "serde_json", "thiserror 2.0.18", "tokio", + "tokio-native-tls", "tracing", "url", "uuid", @@ -1994,6 +1996,21 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -2824,15 +2841,6 @@ dependencies = [ "serde", ] -[[package]] -name = "itertools" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.14.0" @@ -3399,6 +3407,23 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe 0.2.1", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndk" version = "0.9.0" @@ -3578,6 +3603,32 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "openssl-probe" version = "0.1.6" @@ -3949,7 +4000,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools", "proc-macro2", "quote", "syn", @@ -4899,7 +4950,7 @@ version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", "syn", @@ -5116,7 +5167,7 @@ dependencies = [ "fs4", "htmlescape", "hyperloglogplus", - "itertools 0.14.0", + "itertools", "levenshtein_automata", "log", "lru 0.12.5", @@ -5164,7 +5215,7 @@ checksum = "8b628488ae936c83e92b5c4056833054ca56f76c0e616aee8339e24ac89119cd" dependencies = [ "downcast-rs", "fastdivide", - "itertools 0.14.0", + "itertools", "serde", "tantivy-bitpacker", "tantivy-common", @@ -5214,7 +5265,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8292095d1a8a2c2b36380ec455f910ab52dde516af36321af332c93f20ab7d5" dependencies = [ "futures-util", - "itertools 0.14.0", + "itertools", "tantivy-bitpacker", "tantivy-common", "tantivy-fst", @@ -5428,6 +5479,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.4" diff --git a/compliance-agent/src/api/handlers/mod.rs b/compliance-agent/src/api/handlers/mod.rs index 8f353dd..63d64ef 100644 --- a/compliance-agent/src/api/handlers/mod.rs +++ b/compliance-agent/src/api/handlers/mod.rs @@ -1,6 +1,7 @@ pub mod chat; pub mod dast; pub mod graph; +pub mod pentest; use std::sync::Arc; @@ -1108,7 +1109,7 @@ pub async fn list_scan_runs( })) } -async fn collect_cursor_async( +pub(crate) async fn collect_cursor_async( mut cursor: mongodb::Cursor, ) -> Vec { use futures_util::StreamExt; diff --git a/compliance-agent/src/api/handlers/pentest.rs b/compliance-agent/src/api/handlers/pentest.rs new file mode 100644 index 0000000..ae8014b --- /dev/null +++ b/compliance-agent/src/api/handlers/pentest.rs @@ -0,0 +1,564 @@ +use std::sync::Arc; + +use axum::extract::{Extension, Path, Query}; +use axum::http::StatusCode; +use axum::response::sse::{Event, Sse}; +use axum::Json; +use futures_util::stream; +use mongodb::bson::doc; +use serde::Deserialize; + +use compliance_core::models::dast::DastFinding; +use compliance_core::models::pentest::*; + +use crate::agent::ComplianceAgent; +use crate::pentest::PentestOrchestrator; + +use super::{collect_cursor_async, ApiResponse, PaginationParams}; + +type AgentExt = Extension>; + +#[derive(Deserialize)] +pub struct CreateSessionRequest { + pub target_id: String, + #[serde(default = "default_strategy")] + pub strategy: String, + pub message: Option, +} + +fn default_strategy() -> String { + "comprehensive".to_string() +} + +#[derive(Deserialize)] +pub struct SendMessageRequest { + pub message: String, +} + +/// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator +#[tracing::instrument(skip_all)] +pub async fn create_session( + Extension(agent): AgentExt, + Json(req): Json, +) -> Result>, (StatusCode, String)> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "Invalid target_id format".to_string(), + ) + })?; + + // Look up the target + let target = agent + .db + .dast_targets() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; + + // Parse strategy + let strategy = match req.strategy.as_str() { + "quick" => PentestStrategy::Quick, + "targeted" => PentestStrategy::Targeted, + "aggressive" => PentestStrategy::Aggressive, + "stealth" => PentestStrategy::Stealth, + _ => PentestStrategy::Comprehensive, + }; + + // Create session + let mut session = PentestSession::new(req.target_id.clone(), strategy); + session.repo_id = target.repo_id.clone(); + + agent + .db + .pentest_sessions() + .insert_one(&session) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to create session: {e}"), + ) + })?; + + let initial_message = req.message.unwrap_or_else(|| { + format!( + "Begin a {} penetration test against {} ({}). \ + Identify vulnerabilities and provide evidence for each finding.", + session.strategy, target.name, target.base_url, + ) + }); + + // Spawn the orchestrator on a background task + let llm = agent.llm.clone(); + let db = agent.db.clone(); + let session_clone = session.clone(); + let target_clone = target.clone(); + tokio::spawn(async move { + let orchestrator = PentestOrchestrator::new(llm, db); + if let Err(e) = orchestrator + .run_session(&session_clone, &target_clone, &initial_message) + .await + { + tracing::error!( + "Pentest orchestrator failed for session {}: {e}", + session_clone + .id + .map(|oid| oid.to_hex()) + .unwrap_or_default() + ); + } + }); + + Ok(Json(ApiResponse { + data: session, + total: None, + page: None, + })) +} + +/// GET /api/v1/pentest/sessions — List pentest sessions +#[tracing::instrument(skip_all)] +pub async fn list_sessions( + Extension(agent): AgentExt, + Query(params): Query, +) -> Result>>, StatusCode> { + let db = &agent.db; + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = db + .pentest_sessions() + .count_documents(doc! {}) + .await + .unwrap_or(0); + + let sessions = match db + .pentest_sessions() + .find(doc! {}) + .sort(doc! { "started_at": -1 }) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch pentest sessions: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: sessions, + total: Some(total), + page: Some(params.page), + })) +} + +/// GET /api/v1/pentest/sessions/:id — Get a single pentest session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn get_session( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>, StatusCode> { + let oid = + mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; + + Ok(Json(ApiResponse { + data: session, + total: None, + page: None, + })) +} + +/// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn send_message( + Extension(agent): AgentExt, + Path(id): Path, + Json(req): Json, +) -> Result>, (StatusCode, String)> { + let oid = mongodb::bson::oid::ObjectId::parse_str(&id) + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; + + // Verify session exists and is running + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; + + if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { + return Err(( + StatusCode::BAD_REQUEST, + format!("Session is {}, cannot send messages", session.status), + )); + } + + // Look up the target + let target_oid = + mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Invalid target_id in session".to_string(), + ) + })?; + + let target = agent + .db + .dast_targets() + .find_one(doc! { "_id": target_oid }) + .await + .map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Database error: {e}"), + ) + })? + .ok_or_else(|| { + ( + StatusCode::NOT_FOUND, + "Target for session not found".to_string(), + ) + })?; + + // Store user message + let session_id = id.clone(); + let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); + let _ = agent.db.pentest_messages().insert_one(&user_msg).await; + + let response_msg = user_msg.clone(); + + // Spawn orchestrator to continue the session + let llm = agent.llm.clone(); + let db = agent.db.clone(); + let message = req.message.clone(); + tokio::spawn(async move { + let orchestrator = PentestOrchestrator::new(llm, db); + if let Err(e) = orchestrator.run_session(&session, &target, &message).await { + tracing::error!("Pentest orchestrator failed for session {session_id}: {e}"); + } + }); + + Ok(Json(ApiResponse { + data: response_msg, + total: None, + page: None, + })) +} + +/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events +/// +/// Returns recent messages as SSE events (polling approach). +/// True real-time streaming with broadcast channels will be added in a future iteration. +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn session_stream( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>>, StatusCode> +{ + let oid = + mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + // Verify session exists + let _session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; + + // Fetch recent messages for this session + let messages: Vec = match agent + .db + .pentest_messages() + .find(doc! { "session_id": &id }) + .sort(doc! { "created_at": 1 }) + .limit(100) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + // Fetch recent attack chain nodes + let nodes: Vec = match agent + .db + .attack_chain_nodes() + .find(doc! { "session_id": &id }) + .sort(doc! { "started_at": 1 }) + .limit(100) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + // Build SSE events from stored data + let mut events: Vec> = Vec::new(); + + for msg in &messages { + let event_data = serde_json::json!({ + "type": "message", + "role": msg.role, + "content": msg.content, + "created_at": msg.created_at.to_rfc3339(), + }); + if let Ok(data) = serde_json::to_string(&event_data) { + events.push(Ok(Event::default().event("message").data(data))); + } + } + + for node in &nodes { + let event_data = serde_json::json!({ + "type": "tool_execution", + "node_id": node.node_id, + "tool_name": node.tool_name, + "status": node.status, + "findings_produced": node.findings_produced, + }); + if let Ok(data) = serde_json::to_string(&event_data) { + events.push(Ok(Event::default().event("tool").data(data))); + } + } + + // Add session status event + let session = agent + .db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .ok() + .flatten(); + + if let Some(s) = session { + let status_data = serde_json::json!({ + "type": "status", + "status": s.status, + "findings_count": s.findings_count, + "tool_invocations": s.tool_invocations, + }); + if let Ok(data) = serde_json::to_string(&status_data) { + events.push(Ok(Event::default().event("status").data(data))); + } + } + + Ok(Sse::new(stream::iter(events))) +} + +/// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn get_attack_chain( + Extension(agent): AgentExt, + Path(id): Path, +) -> Result>>, StatusCode> { + // Verify the session ID is valid + let _oid = + mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + let nodes = match agent + .db + .attack_chain_nodes() + .find(doc! { "session_id": &id }) + .sort(doc! { "started_at": 1 }) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch attack chain nodes: {e}"); + Vec::new() + } + }; + + let total = nodes.len() as u64; + Ok(Json(ApiResponse { + data: nodes, + total: Some(total), + page: None, + })) +} + +/// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn get_messages( + Extension(agent): AgentExt, + Path(id): Path, + Query(params): Query, +) -> Result>>, StatusCode> { + let _oid = + mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = agent + .db + .pentest_messages() + .count_documents(doc! { "session_id": &id }) + .await + .unwrap_or(0); + + let messages = match agent + .db + .pentest_messages() + .find(doc! { "session_id": &id }) + .sort(doc! { "created_at": 1 }) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch pentest messages: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: messages, + total: Some(total), + page: Some(params.page), + })) +} + +/// GET /api/v1/pentest/stats — Aggregated pentest statistics +#[tracing::instrument(skip_all)] +pub async fn pentest_stats( + Extension(agent): AgentExt, +) -> Result>, StatusCode> { + let db = &agent.db; + + let running_sessions = db + .pentest_sessions() + .count_documents(doc! { "status": "running" }) + .await + .unwrap_or(0) as u32; + + // Count DAST findings from pentest sessions + let total_vulnerabilities = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null } }) + .await + .unwrap_or(0) as u32; + + // Aggregate tool invocations from all sessions + let sessions: Vec = match db.pentest_sessions().find(doc! {}).await { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(_) => Vec::new(), + }; + + let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum(); + let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum(); + let tool_success_rate = if total_tool_invocations == 0 { + 100.0 + } else { + (total_successes as f64 / total_tool_invocations as f64) * 100.0 + }; + + // Severity distribution from pentest-related DAST findings + let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } }; + let critical = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }) + .await + .unwrap_or(0) as u32; + let high = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" }) + .await + .unwrap_or(0) as u32; + let medium = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" }) + .await + .unwrap_or(0) as u32; + let low = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" }) + .await + .unwrap_or(0) as u32; + let info = db + .dast_findings() + .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" }) + .await + .unwrap_or(0) as u32; + + let _ = pentest_filter; // used above inline + + Ok(Json(ApiResponse { + data: PentestStats { + running_sessions, + total_vulnerabilities, + total_tool_invocations, + tool_success_rate, + severity_distribution: SeverityDistribution { + critical, + high, + medium, + low, + info, + }, + }, + total: None, + page: None, + })) +} + +/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session +#[tracing::instrument(skip_all, fields(session_id = %id))] +pub async fn get_session_findings( + Extension(agent): AgentExt, + Path(id): Path, + Query(params): Query, +) -> Result>>, StatusCode> { + let _oid = + mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; + + let skip = (params.page.saturating_sub(1)) * params.limit as u64; + let total = agent + .db + .dast_findings() + .count_documents(doc! { "session_id": &id }) + .await + .unwrap_or(0); + + let findings = match agent + .db + .dast_findings() + .find(doc! { "session_id": &id }) + .sort(doc! { "created_at": -1 }) + .skip(skip) + .limit(params.limit) + .await + { + Ok(cursor) => collect_cursor_async(cursor).await, + Err(e) => { + tracing::warn!("Failed to fetch pentest session findings: {e}"); + Vec::new() + } + }; + + Ok(Json(ApiResponse { + data: findings, + total: Some(total), + page: Some(params.page), + })) +} diff --git a/compliance-agent/src/api/routes.rs b/compliance-agent/src/api/routes.rs index 502984b..4c755e1 100644 --- a/compliance-agent/src/api/routes.rs +++ b/compliance-agent/src/api/routes.rs @@ -99,6 +99,36 @@ pub fn build_router() -> Router { "/api/v1/chat/{repo_id}/status", get(handlers::chat::embedding_status), ) + // Pentest API endpoints + .route( + "/api/v1/pentest/sessions", + get(handlers::pentest::list_sessions).post(handlers::pentest::create_session), + ) + .route( + "/api/v1/pentest/sessions/{id}", + get(handlers::pentest::get_session), + ) + .route( + "/api/v1/pentest/sessions/{id}/chat", + post(handlers::pentest::send_message), + ) + .route( + "/api/v1/pentest/sessions/{id}/stream", + get(handlers::pentest::session_stream), + ) + .route( + "/api/v1/pentest/sessions/{id}/attack-chain", + get(handlers::pentest::get_attack_chain), + ) + .route( + "/api/v1/pentest/sessions/{id}/messages", + get(handlers::pentest::get_messages), + ) + .route( + "/api/v1/pentest/sessions/{id}/findings", + get(handlers::pentest::get_session_findings), + ) + .route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats)) // Webhook endpoints (proxied through dashboard) .route( "/webhook/github/{repo_id}", diff --git a/compliance-agent/src/database.rs b/compliance-agent/src/database.rs index c2b0740..6b0c0d9 100644 --- a/compliance-agent/src/database.rs +++ b/compliance-agent/src/database.rs @@ -166,6 +166,38 @@ impl Database { ) .await?; + // pentest_sessions: compound (target_id, started_at DESC) + self.pentest_sessions() + .create_index( + IndexModel::builder() + .keys(doc! { "target_id": 1, "started_at": -1 }) + .build(), + ) + .await?; + + // pentest_sessions: status index + self.pentest_sessions() + .create_index(IndexModel::builder().keys(doc! { "status": 1 }).build()) + .await?; + + // attack_chain_nodes: compound (session_id, node_id) + self.attack_chain_nodes() + .create_index( + IndexModel::builder() + .keys(doc! { "session_id": 1, "node_id": 1 }) + .build(), + ) + .await?; + + // pentest_messages: compound (session_id, created_at) + self.pentest_messages() + .create_index( + IndexModel::builder() + .keys(doc! { "session_id": 1, "created_at": 1 }) + .build(), + ) + .await?; + tracing::info!("Database indexes ensured"); Ok(()) } @@ -235,6 +267,19 @@ impl Database { self.inner.collection("embedding_builds") } + // Pentest collections + pub fn pentest_sessions(&self) -> Collection { + self.inner.collection("pentest_sessions") + } + + pub fn attack_chain_nodes(&self) -> Collection { + self.inner.collection("attack_chain_nodes") + } + + pub fn pentest_messages(&self) -> Collection { + self.inner.collection("pentest_messages") + } + #[allow(dead_code)] pub fn raw_collection(&self, name: &str) -> Collection { self.inner.collection(name) diff --git a/compliance-agent/src/llm/client.rs b/compliance-agent/src/llm/client.rs index c7a571d..826bda5 100644 --- a/compliance-agent/src/llm/client.rs +++ b/compliance-agent/src/llm/client.rs @@ -12,10 +12,16 @@ pub struct LlmClient { http: reqwest::Client, } -#[derive(Serialize)] -struct ChatMessage { - role: String, - content: String, +// ── Request types ────────────────────────────────────────────── + +#[derive(Serialize, Clone, Debug)] +pub struct ChatMessage { + pub role: String, + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } #[derive(Serialize)] @@ -26,8 +32,25 @@ struct ChatCompletionRequest { temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, } +#[derive(Serialize)] +struct ToolDefinitionPayload { + r#type: String, + function: ToolFunctionPayload, +} + +#[derive(Serialize)] +struct ToolFunctionPayload { + name: String, + description: String, + parameters: serde_json::Value, +} + +// ── Response types ───────────────────────────────────────────── + #[derive(Deserialize)] struct ChatCompletionResponse { choices: Vec, @@ -40,29 +63,84 @@ struct ChatChoice { #[derive(Deserialize)] struct ChatResponseMessage { - content: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, } -/// Request body for the embeddings API +#[derive(Deserialize)] +struct ToolCallResponse { + id: String, + function: ToolCallFunction, +} + +#[derive(Deserialize)] +struct ToolCallFunction { + name: String, + arguments: String, +} + +// ── Public types for tool calling ────────────────────────────── + +/// Definition of a tool that the LLM can invoke +#[derive(Debug, Clone, Serialize)] +pub struct ToolDefinition { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +/// A tool call request from the LLM +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LlmToolCall { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +/// A tool call in the request message format (for sending back tool_calls in assistant messages) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequest { + pub id: String, + pub r#type: String, + pub function: ToolCallRequestFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequestFunction { + pub name: String, + pub arguments: String, +} + +/// Response from the LLM — either content or tool calls +#[derive(Debug, Clone)] +pub enum LlmResponse { + Content(String), + ToolCalls(Vec), +} + +// ── Embedding types ──────────────────────────────────────────── + #[derive(Serialize)] struct EmbeddingRequest { model: String, input: Vec, } -/// Response from the embeddings API #[derive(Deserialize)] struct EmbeddingResponse { data: Vec, } -/// A single embedding result #[derive(Deserialize)] struct EmbeddingData { embedding: Vec, index: usize, } +// ── Implementation ───────────────────────────────────────────── + impl LlmClient { pub fn new( base_url: String, @@ -83,98 +161,142 @@ impl LlmClient { &self.embed_model } + fn chat_url(&self) -> String { + format!( + "{}/v1/chat/completions", + self.base_url.trim_end_matches('/') + ) + } + + fn auth_header(&self) -> Option { + let key = self.api_key.expose_secret(); + if key.is_empty() { + None + } else { + Some(format!("Bearer {key}")) + } + } + + /// Simple chat: system + user prompt → text response pub async fn chat( &self, system_prompt: &str, user_prompt: &str, temperature: Option, ) -> Result { - let url = format!( - "{}/v1/chat/completions", - self.base_url.trim_end_matches('/') - ); + let messages = vec![ + ChatMessage { + role: "system".to_string(), + content: Some(system_prompt.to_string()), + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "user".to_string(), + content: Some(user_prompt.to_string()), + tool_calls: None, + tool_call_id: None, + }, + ]; let request_body = ChatCompletionRequest { model: self.model.clone(), - messages: vec![ - ChatMessage { - role: "system".to_string(), - content: system_prompt.to_string(), - }, - ChatMessage { - role: "user".to_string(), - content: user_prompt.to_string(), - }, - ], + messages, temperature, max_tokens: Some(4096), + tools: None, }; - let mut req = self - .http - .post(&url) - .header("content-type", "application/json") - .json(&request_body); - - let key = self.api_key.expose_secret(); - if !key.is_empty() { - req = req.header("Authorization", format!("Bearer {key}")); - } - - let resp = req - .send() - .await - .map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(AgentError::Other(format!( - "LiteLLM returned {status}: {body}" - ))); - } - - let body: ChatCompletionResponse = resp - .json() - .await - .map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?; - - body.choices - .first() - .map(|c| c.message.content.clone()) - .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string())) + self.send_chat_request(&request_body).await.map(|resp| { + match resp { + LlmResponse::Content(c) => c, + LlmResponse::ToolCalls(_) => String::new(), // shouldn't happen without tools + } + }) } + /// Chat with a list of (role, content) messages → text response #[allow(dead_code)] pub async fn chat_with_messages( &self, messages: Vec<(String, String)>, temperature: Option, ) -> Result { - let url = format!( - "{}/v1/chat/completions", - self.base_url.trim_end_matches('/') - ); + let messages = messages + .into_iter() + .map(|(role, content)| ChatMessage { + role, + content: Some(content), + tool_calls: None, + tool_call_id: None, + }) + .collect(); let request_body = ChatCompletionRequest { model: self.model.clone(), - messages: messages - .into_iter() - .map(|(role, content)| ChatMessage { role, content }) - .collect(), + messages, temperature, max_tokens: Some(4096), + tools: None, }; + self.send_chat_request(&request_body).await.map(|resp| { + match resp { + LlmResponse::Content(c) => c, + LlmResponse::ToolCalls(_) => String::new(), + } + }) + } + + /// Chat with tool definitions — returns either content or tool calls. + /// Use this for the AI pentest orchestrator loop. + pub async fn chat_with_tools( + &self, + messages: Vec, + tools: &[ToolDefinition], + temperature: Option, + max_tokens: Option, + ) -> Result { + let tool_payloads: Vec = tools + .iter() + .map(|t| ToolDefinitionPayload { + r#type: "function".to_string(), + function: ToolFunctionPayload { + name: t.name.clone(), + description: t.description.clone(), + parameters: t.parameters.clone(), + }, + }) + .collect(); + + let request_body = ChatCompletionRequest { + model: self.model.clone(), + messages, + temperature, + max_tokens: Some(max_tokens.unwrap_or(8192)), + tools: if tool_payloads.is_empty() { + None + } else { + Some(tool_payloads) + }, + }; + + self.send_chat_request(&request_body).await + } + + /// Internal method to send a chat completion request and parse the response + async fn send_chat_request( + &self, + request_body: &ChatCompletionRequest, + ) -> Result { let mut req = self .http - .post(&url) + .post(&self.chat_url()) .header("content-type", "application/json") - .json(&request_body); + .json(request_body); - let key = self.api_key.expose_secret(); - if !key.is_empty() { - req = req.header("Authorization", format!("Bearer {key}")); + if let Some(auth) = self.auth_header() { + req = req.header("Authorization", auth); } let resp = req @@ -195,10 +317,37 @@ impl LlmClient { .await .map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?; - body.choices + let choice = body + .choices .first() - .map(|c| c.message.content.clone()) - .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string())) + .ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?; + + // Check for tool calls first + if let Some(tool_calls) = &choice.message.tool_calls { + if !tool_calls.is_empty() { + let calls: Vec = tool_calls + .iter() + .map(|tc| { + let arguments = serde_json::from_str(&tc.function.arguments) + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); + LlmToolCall { + id: tc.id.clone(), + name: tc.function.name.clone(), + arguments, + } + }) + .collect(); + return Ok(LlmResponse::ToolCalls(calls)); + } + } + + // Otherwise return content + let content = choice + .message + .content + .clone() + .unwrap_or_default(); + Ok(LlmResponse::Content(content)) } /// Generate embeddings for a batch of texts @@ -216,9 +365,8 @@ impl LlmClient { .header("content-type", "application/json") .json(&request_body); - let key = self.api_key.expose_secret(); - if !key.is_empty() { - req = req.header("Authorization", format!("Bearer {key}")); + if let Some(auth) = self.auth_header() { + req = req.header("Authorization", auth); } let resp = req @@ -239,7 +387,6 @@ impl LlmClient { .await .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?; - // Sort by index to maintain input order let mut data = body.data; data.sort_by_key(|d| d.index); diff --git a/compliance-agent/src/main.rs b/compliance-agent/src/main.rs index 97ec23e..9e346e9 100644 --- a/compliance-agent/src/main.rs +++ b/compliance-agent/src/main.rs @@ -4,6 +4,7 @@ mod config; mod database; mod error; mod llm; +mod pentest; mod pipeline; mod rag; mod scheduler; diff --git a/compliance-agent/src/pentest/mod.rs b/compliance-agent/src/pentest/mod.rs new file mode 100644 index 0000000..ba0e0c8 --- /dev/null +++ b/compliance-agent/src/pentest/mod.rs @@ -0,0 +1,3 @@ +pub mod orchestrator; + +pub use orchestrator::PentestOrchestrator; diff --git a/compliance-agent/src/pentest/orchestrator.rs b/compliance-agent/src/pentest/orchestrator.rs new file mode 100644 index 0000000..b8ff183 --- /dev/null +++ b/compliance-agent/src/pentest/orchestrator.rs @@ -0,0 +1,393 @@ +use std::sync::Arc; + +use tokio::sync::broadcast; + +use compliance_core::models::dast::DastTarget; +use compliance_core::models::pentest::*; +use compliance_core::traits::pentest_tool::PentestToolContext; +use compliance_dast::ToolRegistry; + +use crate::database::Database; +use crate::llm::client::{ + ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition, +}; +use crate::llm::LlmClient; + +pub struct PentestOrchestrator { + tool_registry: ToolRegistry, + llm: Arc, + db: Database, + event_tx: broadcast::Sender, +} + +impl PentestOrchestrator { + pub fn new(llm: Arc, db: Database) -> Self { + let (event_tx, _) = broadcast::channel(256); + Self { + tool_registry: ToolRegistry::new(), + llm, + db, + event_tx, + } + } + + pub fn subscribe(&self) -> broadcast::Receiver { + self.event_tx.subscribe() + } + + pub fn event_sender(&self) -> broadcast::Sender { + self.event_tx.clone() + } + + pub async fn run_session( + &self, + session: &PentestSession, + target: &DastTarget, + initial_message: &str, + ) -> Result<(), crate::error::AgentError> { + let session_id = session + .id + .map(|oid| oid.to_hex()) + .unwrap_or_default(); + + // Build system prompt + let system_prompt = self.build_system_prompt(session, target); + + // Build tool definitions for LLM + let tool_defs: Vec = self + .tool_registry + .all_definitions() + .into_iter() + .map(|td| ToolDefinition { + name: td.name, + description: td.description, + parameters: td.input_schema, + }) + .collect(); + + // Initialize messages + let mut messages = vec![ + ChatMessage { + role: "system".to_string(), + content: Some(system_prompt), + tool_calls: None, + tool_call_id: None, + }, + ChatMessage { + role: "user".to_string(), + content: Some(initial_message.to_string()), + tool_calls: None, + tool_call_id: None, + }, + ]; + + // Store user message + let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string()); + let _ = self.db.pentest_messages().insert_one(&user_msg).await; + + // Build tool context + let tool_context = PentestToolContext { + target: target.clone(), + session_id: session_id.clone(), + sast_findings: Vec::new(), + sbom_entries: Vec::new(), + code_context: Vec::new(), + rate_limit: target.rate_limit, + allow_destructive: target.allow_destructive, + }; + + let max_iterations = 50; + let mut total_findings = 0u32; + let mut total_tool_calls = 0u32; + let mut total_successes = 0u32; + + for _iteration in 0..max_iterations { + // Call LLM with tools + let response = self + .llm + .chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192)) + .await?; + + match response { + LlmResponse::Content(content) => { + // Store assistant message + let msg = + PentestMessage::assistant(session_id.clone(), content.clone()); + let _ = self.db.pentest_messages().insert_one(&msg).await; + + // Emit message event + let _ = self.event_tx.send(PentestEvent::Message { + content: content.clone(), + }); + + // Add to messages + messages.push(ChatMessage { + role: "assistant".to_string(), + content: Some(content.clone()), + tool_calls: None, + tool_call_id: None, + }); + + // Check if the LLM considers itself done + let done_indicators = [ + "pentest complete", + "testing complete", + "scan complete", + "analysis complete", + "finished", + "that concludes", + ]; + let content_lower = content.to_lowercase(); + if done_indicators + .iter() + .any(|ind| content_lower.contains(ind)) + { + break; + } + + // If not done, break and wait for user input + break; + } + LlmResponse::ToolCalls(tool_calls) => { + // Build the assistant message with tool_calls + let tc_requests: Vec = tool_calls + .iter() + .map(|tc| ToolCallRequest { + id: tc.id.clone(), + r#type: "function".to_string(), + function: ToolCallRequestFunction { + name: tc.name.clone(), + arguments: serde_json::to_string(&tc.arguments) + .unwrap_or_default(), + }, + }) + .collect(); + + messages.push(ChatMessage { + role: "assistant".to_string(), + content: None, + tool_calls: Some(tc_requests), + tool_call_id: None, + }); + + // Execute each tool call + for tc in &tool_calls { + total_tool_calls += 1; + let node_id = uuid::Uuid::new_v4().to_string(); + + // Create attack chain node + let mut node = AttackChainNode::new( + session_id.clone(), + node_id.clone(), + tc.name.clone(), + tc.arguments.clone(), + String::new(), + ); + node.status = AttackNodeStatus::Running; + node.started_at = Some(chrono::Utc::now()); + let _ = self.db.attack_chain_nodes().insert_one(&node).await; + + // Emit tool start event + let _ = self.event_tx.send(PentestEvent::ToolStart { + node_id: node_id.clone(), + tool_name: tc.name.clone(), + input: tc.arguments.clone(), + }); + + // Execute the tool + let result = if let Some(tool) = self.tool_registry.get(&tc.name) { + match tool.execute(tc.arguments.clone(), &tool_context).await { + Ok(result) => { + total_successes += 1; + let findings_count = result.findings.len() as u32; + total_findings += findings_count; + + // Store findings + for mut finding in result.findings { + finding.scan_run_id = session_id.clone(); + finding.session_id = Some(session_id.clone()); + let _ = + self.db.dast_findings().insert_one(&finding).await; + + let _ = + self.event_tx.send(PentestEvent::Finding { + finding_id: finding + .id + .map(|oid| oid.to_hex()) + .unwrap_or_default(), + title: finding.title.clone(), + severity: finding.severity.to_string(), + }); + } + + // Emit tool complete event + let _ = self.event_tx.send(PentestEvent::ToolComplete { + node_id: node_id.clone(), + summary: result.summary.clone(), + findings_count, + }); + + // Update attack chain node + let _ = self + .db + .attack_chain_nodes() + .update_one( + mongodb::bson::doc! { + "session_id": &session_id, + "node_id": &node_id, + }, + mongodb::bson::doc! { "$set": { + "status": "completed", + "tool_output": mongodb::bson::to_bson(&result.data) + .unwrap_or(mongodb::bson::Bson::Null), + "completed_at": mongodb::bson::DateTime::now(), + }}, + ) + .await; + + serde_json::json!({ + "summary": result.summary, + "findings_count": findings_count, + "data": result.data, + }) + .to_string() + } + Err(e) => { + // Update node as failed + let _ = self + .db + .attack_chain_nodes() + .update_one( + mongodb::bson::doc! { + "session_id": &session_id, + "node_id": &node_id, + }, + mongodb::bson::doc! { "$set": { + "status": "failed", + "completed_at": mongodb::bson::DateTime::now(), + }}, + ) + .await; + + format!("Tool execution failed: {e}") + } + } + } else { + format!("Unknown tool: {}", tc.name) + }; + + // Add tool result to messages + messages.push(ChatMessage { + role: "tool".to_string(), + content: Some(result), + tool_calls: None, + tool_call_id: Some(tc.id.clone()), + }); + } + + // Update session stats + if let Some(sid) = session.id { + let _ = self + .db + .pentest_sessions() + .update_one( + mongodb::bson::doc! { "_id": sid }, + mongodb::bson::doc! { "$set": { + "tool_invocations": total_tool_calls as i64, + "tool_successes": total_successes as i64, + "findings_count": total_findings as i64, + }}, + ) + .await; + } + } + } + } + + // Mark session as completed + if let Some(sid) = session.id { + let _ = self + .db + .pentest_sessions() + .update_one( + mongodb::bson::doc! { "_id": sid }, + mongodb::bson::doc! { "$set": { + "status": "completed", + "completed_at": mongodb::bson::DateTime::now(), + "tool_invocations": total_tool_calls as i64, + "tool_successes": total_successes as i64, + "findings_count": total_findings as i64, + }}, + ) + .await; + } + + let _ = self.event_tx.send(PentestEvent::Complete { + summary: format!( + "Pentest complete. {} findings from {} tool invocations.", + total_findings, total_tool_calls + ), + }); + + Ok(()) + } + + fn build_system_prompt(&self, session: &PentestSession, target: &DastTarget) -> String { + let tool_names = self.tool_registry.list_names().join(", "); + let strategy_guidance = match session.strategy { + PentestStrategy::Quick => { + "Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas." + } + PentestStrategy::Comprehensive => { + "Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface." + } + PentestStrategy::Targeted => { + "Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses." + } + PentestStrategy::Aggressive => { + "Use all available tools aggressively. Test with maximum payloads and attempt full exploitation." + } + PentestStrategy::Stealth => { + "Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes." + } + }; + + format!( + r#"You are an expert penetration tester conducting an authorized security assessment. + +## Target +- **Name**: {target_name} +- **URL**: {base_url} +- **Type**: {target_type} +- **Rate Limit**: {rate_limit} req/s +- **Destructive Tests Allowed**: {allow_destructive} + +## Strategy +{strategy_guidance} + +## Available Tools +{tool_names} + +## Instructions +1. Start by running reconnaissance and crawling to understand the target. +2. Based on what you discover, select appropriate vulnerability scanning tools. +3. For each tool invocation, provide the discovered endpoints and parameters. +4. Analyze tool results and chain findings — if you find one vulnerability, explore whether it enables others. +5. When testing is complete, provide a summary of all findings with severity and remediation recommendations. +6. Always explain your reasoning before invoking each tool. +7. Focus on actionable findings with evidence. Avoid false positives. +8. When you have completed all relevant testing, say "Testing complete" followed by a final summary. + +## Important +- This is an authorized penetration test. All testing is permitted within the target scope. +- Respect the rate limit of {rate_limit} requests per second. +- Only use destructive tests if explicitly allowed ({allow_destructive}). +"#, + target_name = target.name, + base_url = target.base_url, + target_type = target.target_type, + rate_limit = target.rate_limit, + allow_destructive = target.allow_destructive, + ) + } +} diff --git a/compliance-core/src/models/dast.rs b/compliance-core/src/models/dast.rs index 7ff9599..4972aca 100644 --- a/compliance-core/src/models/dast.rs +++ b/compliance-core/src/models/dast.rs @@ -176,6 +176,16 @@ pub enum DastVulnType { InformationDisclosure, SecurityMisconfiguration, BrokenAuth, + DnsMisconfiguration, + EmailSecurity, + TlsMisconfiguration, + CookieSecurity, + CspIssue, + CorsMisconfiguration, + RateLimitAbsent, + ConsoleLogLeakage, + SecurityHeaderMissing, + KnownCveExploit, Other, } @@ -192,6 +202,16 @@ impl std::fmt::Display for DastVulnType { Self::InformationDisclosure => write!(f, "information_disclosure"), Self::SecurityMisconfiguration => write!(f, "security_misconfiguration"), Self::BrokenAuth => write!(f, "broken_auth"), + Self::DnsMisconfiguration => write!(f, "dns_misconfiguration"), + Self::EmailSecurity => write!(f, "email_security"), + Self::TlsMisconfiguration => write!(f, "tls_misconfiguration"), + Self::CookieSecurity => write!(f, "cookie_security"), + Self::CspIssue => write!(f, "csp_issue"), + Self::CorsMisconfiguration => write!(f, "cors_misconfiguration"), + Self::RateLimitAbsent => write!(f, "rate_limit_absent"), + Self::ConsoleLogLeakage => write!(f, "console_log_leakage"), + Self::SecurityHeaderMissing => write!(f, "security_header_missing"), + Self::KnownCveExploit => write!(f, "known_cve_exploit"), Self::Other => write!(f, "other"), } } @@ -244,6 +264,8 @@ pub struct DastFinding { pub remediation: Option, /// Linked SAST finding ID (if correlated) pub linked_sast_finding_id: Option, + /// Pentest session that produced this finding (if AI-driven) + pub session_id: Option, #[serde(with = "super::serde_helpers::bson_datetime")] pub created_at: DateTime, } @@ -276,6 +298,7 @@ impl DastFinding { evidence: Vec::new(), remediation: None, linked_sast_finding_id: None, + session_id: None, created_at: Utc::now(), } } diff --git a/compliance-core/src/models/mod.rs b/compliance-core/src/models/mod.rs index 8d9f064..daf0503 100644 --- a/compliance-core/src/models/mod.rs +++ b/compliance-core/src/models/mod.rs @@ -7,6 +7,7 @@ pub mod finding; pub mod graph; pub mod issue; pub mod mcp; +pub mod pentest; pub mod repository; pub mod sbom; pub mod scan; @@ -26,6 +27,11 @@ pub use graph::{ }; pub use issue::{IssueStatus, TrackerIssue, TrackerType}; pub use mcp::{McpServerConfig, McpServerStatus, McpTransport}; +pub use pentest::{ + AttackChainNode, AttackNodeStatus, CodeContextHint, PentestEvent, PentestMessage, + PentestSession, PentestStats, PentestStatus, PentestStrategy, SeverityDistribution, + ToolCallRecord, +}; pub use repository::{ScanTrigger, TrackedRepository}; pub use sbom::{SbomEntry, VulnRef}; pub use scan::{ScanPhase, ScanRun, ScanRunStatus, ScanType}; diff --git a/compliance-core/src/models/pentest.rs b/compliance-core/src/models/pentest.rs new file mode 100644 index 0000000..dc0fa4d --- /dev/null +++ b/compliance-core/src/models/pentest.rs @@ -0,0 +1,294 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +/// Status of a pentest session +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum PentestStatus { + Running, + Paused, + Completed, + Failed, +} + +impl std::fmt::Display for PentestStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Running => write!(f, "running"), + Self::Paused => write!(f, "paused"), + Self::Completed => write!(f, "completed"), + Self::Failed => write!(f, "failed"), + } + } +} + +/// Strategy for the AI pentest orchestrator +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum PentestStrategy { + /// Quick scan focusing on common vulnerabilities + Quick, + /// Standard comprehensive scan + Comprehensive, + /// Focus on specific vulnerability types guided by SAST/SBOM + Targeted, + /// Aggressive testing with more payloads and deeper exploitation + Aggressive, + /// Stealth mode with slower rate and fewer noisy payloads + Stealth, +} + +impl std::fmt::Display for PentestStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Quick => write!(f, "quick"), + Self::Comprehensive => write!(f, "comprehensive"), + Self::Targeted => write!(f, "targeted"), + Self::Aggressive => write!(f, "aggressive"), + Self::Stealth => write!(f, "stealth"), + } + } +} + +/// A pentest session initiated via the chat interface +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PentestSession { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub target_id: String, + /// Linked repository for code-aware testing + pub repo_id: Option, + pub status: PentestStatus, + pub strategy: PentestStrategy, + pub created_by: Option, + /// Total number of tool invocations in this session + pub tool_invocations: u32, + /// Total successful tool invocations + pub tool_successes: u32, + /// Number of findings discovered + pub findings_count: u32, + /// Number of confirmed exploitable findings + pub exploitable_count: u32, + #[serde(with = "super::serde_helpers::bson_datetime")] + pub started_at: DateTime, + #[serde(default, with = "super::serde_helpers::opt_bson_datetime")] + pub completed_at: Option>, +} + +impl PentestSession { + pub fn new(target_id: String, strategy: PentestStrategy) -> Self { + Self { + id: None, + target_id, + repo_id: None, + status: PentestStatus::Running, + strategy, + created_by: None, + tool_invocations: 0, + tool_successes: 0, + findings_count: 0, + exploitable_count: 0, + started_at: Utc::now(), + completed_at: None, + } + } + + pub fn success_rate(&self) -> f64 { + if self.tool_invocations == 0 { + return 100.0; + } + (self.tool_successes as f64 / self.tool_invocations as f64) * 100.0 + } +} + +/// Status of a node in the attack chain +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum AttackNodeStatus { + Pending, + Running, + Completed, + Failed, + Skipped, +} + +/// A single step in the LLM-driven attack chain DAG +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttackChainNode { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub session_id: String, + /// Unique ID for DAG references + pub node_id: String, + /// Parent node IDs (multiple for merge nodes) + pub parent_node_ids: Vec, + /// Tool that was invoked + pub tool_name: String, + /// Input parameters passed to the tool + pub tool_input: serde_json::Value, + /// Output from the tool + pub tool_output: Option, + pub status: AttackNodeStatus, + /// LLM's reasoning for choosing this action + pub llm_reasoning: String, + /// IDs of DastFindings produced by this step + pub findings_produced: Vec, + /// Risk score (0-100) assigned by the LLM + pub risk_score: Option, + #[serde(default, with = "super::serde_helpers::opt_bson_datetime")] + pub started_at: Option>, + #[serde(default, with = "super::serde_helpers::opt_bson_datetime")] + pub completed_at: Option>, +} + +impl AttackChainNode { + pub fn new( + session_id: String, + node_id: String, + tool_name: String, + tool_input: serde_json::Value, + llm_reasoning: String, + ) -> Self { + Self { + id: None, + session_id, + node_id, + parent_node_ids: Vec::new(), + tool_name, + tool_input, + tool_output: None, + status: AttackNodeStatus::Pending, + llm_reasoning, + findings_produced: Vec::new(), + risk_score: None, + started_at: None, + completed_at: None, + } + } +} + +/// Chat message within a pentest session +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PentestMessage { + #[serde(rename = "_id", skip_serializing_if = "Option::is_none")] + pub id: Option, + pub session_id: String, + /// "user", "assistant", "tool_result", "system" + pub role: String, + pub content: String, + /// Tool calls made by the assistant in this message + pub tool_calls: Option>, + /// Link to the attack chain node (for tool_result messages) + pub attack_node_id: Option, + #[serde(with = "super::serde_helpers::bson_datetime")] + pub created_at: DateTime, +} + +impl PentestMessage { + pub fn user(session_id: String, content: String) -> Self { + Self { + id: None, + session_id, + role: "user".to_string(), + content, + tool_calls: None, + attack_node_id: None, + created_at: Utc::now(), + } + } + + pub fn assistant(session_id: String, content: String) -> Self { + Self { + id: None, + session_id, + role: "assistant".to_string(), + content, + tool_calls: None, + attack_node_id: None, + created_at: Utc::now(), + } + } + + pub fn tool_result(session_id: String, content: String, node_id: String) -> Self { + Self { + id: None, + session_id, + role: "tool_result".to_string(), + content, + tool_calls: None, + attack_node_id: Some(node_id), + created_at: Utc::now(), + } + } +} + +/// Record of a tool call made by the LLM +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRecord { + pub call_id: String, + pub tool_name: String, + pub arguments: serde_json::Value, + pub result: Option, +} + +/// SSE event types for real-time pentest streaming +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum PentestEvent { + /// LLM is thinking/reasoning + Thinking { reasoning: String }, + /// A tool execution has started + ToolStart { + node_id: String, + tool_name: String, + input: serde_json::Value, + }, + /// A tool execution completed + ToolComplete { + node_id: String, + summary: String, + findings_count: u32, + }, + /// A new finding was discovered + Finding { finding_id: String, title: String, severity: String }, + /// Assistant message (streaming text) + Message { content: String }, + /// Session completed + Complete { summary: String }, + /// Error occurred + Error { message: String }, +} + +/// Aggregated stats for the pentest dashboard +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PentestStats { + pub running_sessions: u32, + pub total_vulnerabilities: u32, + pub total_tool_invocations: u32, + pub tool_success_rate: f64, + pub severity_distribution: SeverityDistribution, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SeverityDistribution { + pub critical: u32, + pub high: u32, + pub medium: u32, + pub low: u32, + pub info: u32, +} + +/// Code context hint linking a discovered endpoint to source code +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CodeContextHint { + /// HTTP route pattern (e.g., "GET /api/users/:id") + pub endpoint_pattern: String, + /// Handler function name + pub handler_function: String, + /// Source file path + pub file_path: String, + /// Relevant code snippet + pub code_snippet: String, + /// SAST findings associated with this code + pub known_vulnerabilities: Vec, +} diff --git a/compliance-core/src/traits/mod.rs b/compliance-core/src/traits/mod.rs index 2d10e8b..4677153 100644 --- a/compliance-core/src/traits/mod.rs +++ b/compliance-core/src/traits/mod.rs @@ -1,9 +1,11 @@ pub mod dast_agent; pub mod graph_builder; pub mod issue_tracker; +pub mod pentest_tool; pub mod scanner; pub use dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; pub use graph_builder::{LanguageParser, ParseOutput}; pub use issue_tracker::IssueTracker; +pub use pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; pub use scanner::{ScanOutput, Scanner}; diff --git a/compliance-core/src/traits/pentest_tool.rs b/compliance-core/src/traits/pentest_tool.rs new file mode 100644 index 0000000..ad2d73c --- /dev/null +++ b/compliance-core/src/traits/pentest_tool.rs @@ -0,0 +1,63 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::CoreError; +use crate::models::dast::{DastFinding, DastTarget}; +use crate::models::finding::Finding; +use crate::models::pentest::CodeContextHint; +use crate::models::sbom::SbomEntry; + +/// Context passed to pentest tools during execution. +/// +/// The HTTP client is not included here because `compliance-core` does not +/// depend on `reqwest`. Tools that need HTTP should hold their own client +/// or receive one via the `compliance-dast` orchestrator. +pub struct PentestToolContext { + /// The DAST target being tested + pub target: DastTarget, + /// Session ID for this pentest run + pub session_id: String, + /// SAST findings for the linked repo (if any) + pub sast_findings: Vec, + /// SBOM entries with known CVEs (if any) + pub sbom_entries: Vec, + /// Code knowledge graph hints mapping endpoints to source code + pub code_context: Vec, + /// Rate limit (requests per second) + pub rate_limit: u32, + /// Whether destructive operations are allowed + pub allow_destructive: bool, +} + +/// Result from a pentest tool execution +pub struct PentestToolResult { + /// Human-readable summary of what the tool found + pub summary: String, + /// DAST findings produced by this tool + pub findings: Vec, + /// Tool-specific structured output data + pub data: serde_json::Value, +} + +/// A tool that the LLM pentest orchestrator can invoke. +/// +/// Each tool represents a specific security testing capability +/// (e.g., SQL injection scanner, DNS checker, TLS analyzer). +/// Uses boxed futures for dyn-compatibility. +pub trait PentestTool: Send + Sync { + /// Tool name for LLM tool_use (e.g., "sql_injection_scanner") + fn name(&self) -> &str; + + /// Human-readable description for the LLM system prompt + fn description(&self) -> &str; + + /// JSON Schema for the tool's input parameters + fn input_schema(&self) -> serde_json::Value; + + /// Execute the tool with the given input + fn execute<'a>( + &'a self, + input: serde_json::Value, + context: &'a PentestToolContext, + ) -> Pin> + Send + 'a>>; +} diff --git a/compliance-dashboard/src/app.rs b/compliance-dashboard/src/app.rs index 91e382f..e64bd68 100644 --- a/compliance-dashboard/src/app.rs +++ b/compliance-dashboard/src/app.rs @@ -38,6 +38,10 @@ pub enum Route { DastFindingsPage {}, #[route("/dast/findings/:id")] DastFindingDetailPage { id: String }, + #[route("/pentest")] + PentestDashboardPage {}, + #[route("/pentest/:session_id")] + PentestSessionPage { session_id: String }, #[route("/mcp-servers")] McpServersPage {}, #[route("/settings")] diff --git a/compliance-dashboard/src/components/sidebar.rs b/compliance-dashboard/src/components/sidebar.rs index 4356c1a..4522fae 100644 --- a/compliance-dashboard/src/components/sidebar.rs +++ b/compliance-dashboard/src/components/sidebar.rs @@ -47,6 +47,11 @@ pub fn Sidebar() -> Element { route: Route::DastOverviewPage {}, icon: rsx! { Icon { icon: BsBug, width: 18, height: 18 } }, }, + NavItem { + label: "Pentest", + route: Route::PentestDashboardPage {}, + icon: rsx! { Icon { icon: BsLightningCharge, width: 18, height: 18 } }, + }, NavItem { label: "Settings", route: Route::SettingsPage {}, @@ -78,6 +83,7 @@ pub fn Sidebar() -> Element { (Route::DastTargetsPage {}, Route::DastOverviewPage {}) => true, (Route::DastFindingsPage {}, Route::DastOverviewPage {}) => true, (Route::DastFindingDetailPage { .. }, Route::DastOverviewPage {}) => true, + (Route::PentestSessionPage { .. }, Route::PentestDashboardPage {}) => true, (a, b) => a == b, }; let class = if is_active { "nav-item active" } else { "nav-item" }; diff --git a/compliance-dashboard/src/infrastructure/mod.rs b/compliance-dashboard/src/infrastructure/mod.rs index 1033ae9..490c63b 100644 --- a/compliance-dashboard/src/infrastructure/mod.rs +++ b/compliance-dashboard/src/infrastructure/mod.rs @@ -7,6 +7,7 @@ pub mod findings; pub mod graph; pub mod issues; pub mod mcp; +pub mod pentest; #[allow(clippy::too_many_arguments)] pub mod repositories; pub mod sbom; diff --git a/compliance-dashboard/src/infrastructure/pentest.rs b/compliance-dashboard/src/infrastructure/pentest.rs new file mode 100644 index 0000000..6546eb3 --- /dev/null +++ b/compliance-dashboard/src/infrastructure/pentest.rs @@ -0,0 +1,190 @@ +use dioxus::prelude::*; +use serde::{Deserialize, Serialize}; + +use super::dast::DastFindingsResponse; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PentestSessionsResponse { + pub data: Vec, + pub total: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PentestSessionResponse { + pub data: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PentestMessagesResponse { + pub data: Vec, + pub total: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct PentestStatsResponse { + pub data: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct AttackChainResponse { + pub data: Vec, +} + +#[server] +pub async fn fetch_pentest_sessions() -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestSessionsResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn fetch_pentest_session(id: String) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!("{}/api/v1/pentest/sessions/{id}", state.agent_api_url); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestSessionResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn fetch_pentest_messages( + session_id: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!( + "{}/api/v1/pentest/sessions/{session_id}/messages", + state.agent_api_url + ); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestMessagesResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn fetch_pentest_stats() -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!("{}/api/v1/pentest/stats", state.agent_api_url); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestStatsResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn fetch_attack_chain( + session_id: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!( + "{}/api/v1/pentest/sessions/{session_id}/attack-chain", + state.agent_api_url + ); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: AttackChainResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn create_pentest_session( + target_id: String, + strategy: String, + message: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url); + let client = reqwest::Client::new(); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "target_id": target_id, + "strategy": strategy, + "message": message, + })) + .send() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestSessionResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn send_pentest_message( + session_id: String, + message: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!( + "{}/api/v1/pentest/sessions/{session_id}/messages", + state.agent_api_url + ); + let client = reqwest::Client::new(); + let resp = client + .post(&url) + .json(&serde_json::json!({ + "message": message, + })) + .send() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: PentestMessagesResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} + +#[server] +pub async fn fetch_pentest_findings( + session_id: String, +) -> Result { + let state: super::server_state::ServerState = + dioxus_fullstack::FullstackContext::extract().await?; + let url = format!( + "{}/api/v1/pentest/sessions/{session_id}/findings", + state.agent_api_url + ); + let resp = reqwest::get(&url) + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + let body: DastFindingsResponse = resp + .json() + .await + .map_err(|e| ServerFnError::new(e.to_string()))?; + Ok(body) +} diff --git a/compliance-dashboard/src/pages/mod.rs b/compliance-dashboard/src/pages/mod.rs index 623ec4a..bdc9281 100644 --- a/compliance-dashboard/src/pages/mod.rs +++ b/compliance-dashboard/src/pages/mod.rs @@ -12,6 +12,8 @@ pub mod impact_analysis; pub mod issues; pub mod mcp_servers; pub mod overview; +pub mod pentest_dashboard; +pub mod pentest_session; pub mod repositories; pub mod sbom; pub mod settings; @@ -30,6 +32,8 @@ pub use impact_analysis::ImpactAnalysisPage; pub use issues::IssuesPage; pub use mcp_servers::McpServersPage; pub use overview::OverviewPage; +pub use pentest_dashboard::PentestDashboardPage; +pub use pentest_session::PentestSessionPage; pub use repositories::RepositoriesPage; pub use sbom::SbomPage; pub use settings::SettingsPage; diff --git a/compliance-dashboard/src/pages/pentest_dashboard.rs b/compliance-dashboard/src/pages/pentest_dashboard.rs new file mode 100644 index 0000000..647f581 --- /dev/null +++ b/compliance-dashboard/src/pages/pentest_dashboard.rs @@ -0,0 +1,396 @@ +use dioxus::prelude::*; +use dioxus_free_icons::icons::bs_icons::*; +use dioxus_free_icons::Icon; + +use crate::app::Route; +use crate::components::page_header::PageHeader; +use crate::infrastructure::dast::fetch_dast_targets; +use crate::infrastructure::pentest::{ + create_pentest_session, fetch_pentest_sessions, fetch_pentest_stats, +}; + +#[component] +pub fn PentestDashboardPage() -> Element { + let mut sessions = use_resource(|| async { fetch_pentest_sessions().await.ok() }); + let stats = use_resource(|| async { fetch_pentest_stats().await.ok() }); + let targets = use_resource(|| async { fetch_dast_targets().await.ok() }); + + let mut show_modal = use_signal(|| false); + let mut new_target_id = use_signal(String::new); + let mut new_strategy = use_signal(|| "comprehensive".to_string()); + let mut new_message = use_signal(String::new); + let mut creating = use_signal(|| false); + + let on_create = move |_| { + let tid = new_target_id.read().clone(); + let strat = new_strategy.read().clone(); + let msg = new_message.read().clone(); + if tid.is_empty() || msg.is_empty() { + return; + } + creating.set(true); + spawn(async move { + match create_pentest_session(tid, strat, msg).await { + Ok(resp) => { + let session_id = resp + .data + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + creating.set(false); + show_modal.set(false); + new_target_id.set(String::new()); + new_message.set(String::new()); + if !session_id.is_empty() { + navigator().push(Route::PentestSessionPage { + session_id: session_id.clone(), + }); + } else { + sessions.restart(); + } + } + Err(_) => { + creating.set(false); + } + } + }); + }; + + // Extract stats values + let running_sessions = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("running_sessions") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let total_vulns = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("total_vulnerabilities") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let tool_invocations = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("tool_invocations") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let success_rate = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("success_rate") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0), + _ => 0.0, + } + }; + + // Severity counts from stats + let severity_critical = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("severity_critical") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let severity_high = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("severity_high") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let severity_medium = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("severity_medium") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + let severity_low = { + let s = stats.read(); + match &*s { + Some(Some(data)) => data + .data + .get("severity_low") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + + rsx! { + PageHeader { + title: "Pentest Dashboard", + description: "AI-powered penetration testing sessions — autonomous security assessment", + } + + // Stat cards + div { class: "stat-cards", style: "margin-bottom: 24px;", + div { class: "stat-card-item", + div { class: "stat-card-value", "{running_sessions}" } + div { class: "stat-card-label", + Icon { icon: BsPlayCircle, width: 14, height: 14 } + " Running Sessions" + } + } + div { class: "stat-card-item", + div { class: "stat-card-value", "{total_vulns}" } + div { class: "stat-card-label", + Icon { icon: BsShieldExclamation, width: 14, height: 14 } + " Total Vulnerabilities" + } + } + div { class: "stat-card-item", + div { class: "stat-card-value", "{tool_invocations}" } + div { class: "stat-card-label", + Icon { icon: BsWrench, width: 14, height: 14 } + " Tool Invocations" + } + } + div { class: "stat-card-item", + div { class: "stat-card-value", "{success_rate:.0}%" } + div { class: "stat-card-label", + Icon { icon: BsCheckCircle, width: 14, height: 14 } + " Success Rate" + } + } + } + + // Severity distribution + div { class: "card", style: "margin-bottom: 24px; padding: 16px;", + div { style: "display: flex; align-items: center; gap: 16px; flex-wrap: wrap;", + span { style: "font-weight: 600; color: var(--text-secondary); font-size: 0.85rem;", "Severity Distribution" } + span { + class: "badge", + style: "background: #dc2626; color: #fff;", + "Critical: {severity_critical}" + } + span { + class: "badge", + style: "background: #ea580c; color: #fff;", + "High: {severity_high}" + } + span { + class: "badge", + style: "background: #d97706; color: #fff;", + "Medium: {severity_medium}" + } + span { + class: "badge", + style: "background: #2563eb; color: #fff;", + "Low: {severity_low}" + } + } + } + + // Actions row + div { style: "display: flex; gap: 12px; margin-bottom: 24px;", + button { + class: "btn btn-primary", + onclick: move |_| show_modal.set(true), + Icon { icon: BsPlusCircle, width: 14, height: 14 } + " New Pentest" + } + } + + // Sessions list + div { class: "card", + div { class: "card-header", "Recent Pentest Sessions" } + match &*sessions.read() { + Some(Some(data)) => { + let sess_list = &data.data; + if sess_list.is_empty() { + rsx! { + div { style: "padding: 32px; text-align: center; color: var(--text-secondary);", + p { "No pentest sessions yet. Start one to begin autonomous security testing." } + } + } + } else { + rsx! { + div { style: "display: grid; gap: 12px; padding: 16px;", + for session in sess_list { + { + let id = session.get("id").and_then(|v| v.as_str()).unwrap_or("-").to_string(); + let target_name = session.get("target_name").and_then(|v| v.as_str()).unwrap_or("Unknown Target").to_string(); + let status = session.get("status").and_then(|v| v.as_str()).unwrap_or("unknown").to_string(); + let strategy = session.get("strategy").and_then(|v| v.as_str()).unwrap_or("-").to_string(); + let findings_count = session.get("findings_count").and_then(|v| v.as_u64()).unwrap_or(0); + let tool_count = session.get("tool_invocations").and_then(|v| v.as_u64()).unwrap_or(0); + let created_at = session.get("created_at").and_then(|v| v.as_str()).unwrap_or("-").to_string(); + let status_style = match status.as_str() { + "running" => "background: #16a34a; color: #fff;", + "completed" => "background: #2563eb; color: #fff;", + "failed" => "background: #dc2626; color: #fff;", + "paused" => "background: #d97706; color: #fff;", + _ => "background: var(--bg-tertiary); color: var(--text-secondary);", + }; + rsx! { + Link { + to: Route::PentestSessionPage { session_id: id.clone() }, + class: "card", + style: "padding: 16px; text-decoration: none; cursor: pointer; transition: border-color 0.15s;", + div { style: "display: flex; justify-content: space-between; align-items: flex-start;", + div { + div { style: "font-weight: 600; font-size: 1rem; margin-bottom: 4px; color: var(--text-primary);", + "{target_name}" + } + div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;", + span { + class: "badge", + style: "{status_style}", + "{status}" + } + span { + class: "badge", + style: "background: var(--bg-tertiary); color: var(--text-secondary);", + "{strategy}" + } + } + } + div { style: "text-align: right; font-size: 0.85rem; color: var(--text-secondary);", + div { style: "margin-bottom: 4px;", + Icon { icon: BsShieldExclamation, width: 12, height: 12 } + " {findings_count} findings" + } + div { style: "margin-bottom: 4px;", + Icon { icon: BsWrench, width: 12, height: 12 } + " {tool_count} tools" + } + div { "{created_at}" } + } + } + } + } + } + } + } + } + } + }, + Some(None) => rsx! { p { style: "padding: 16px;", "Failed to load sessions." } }, + None => rsx! { p { style: "padding: 16px;", "Loading..." } }, + } + } + + // New Pentest Modal + if *show_modal.read() { + div { + style: "position: fixed; inset: 0; background: rgba(0,0,0,0.6); display: flex; align-items: center; justify-content: center; z-index: 1000;", + onclick: move |_| show_modal.set(false), + div { + style: "background: var(--bg-secondary); border: 1px solid var(--border-color); border-radius: 12px; padding: 24px; width: 480px; max-width: 90vw;", + onclick: move |e| e.stop_propagation(), + h3 { style: "margin: 0 0 16px 0;", "New Pentest Session" } + + // Target selection + div { style: "margin-bottom: 12px;", + label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;", + "Target" + } + select { + class: "chat-input", + style: "width: 100%; padding: 8px; resize: none; height: auto;", + value: "{new_target_id}", + onchange: move |e| new_target_id.set(e.value()), + option { value: "", "Select a target..." } + match &*targets.read() { + Some(Some(data)) => { + rsx! { + for target in &data.data { + { + let tid = target.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let tname = target.get("name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string(); + let turl = target.get("base_url").and_then(|v| v.as_str()).unwrap_or("").to_string(); + rsx! { + option { value: "{tid}", "{tname} ({turl})" } + } + } + } + } + }, + _ => rsx! {}, + } + } + } + + // Strategy selection + div { style: "margin-bottom: 12px;", + label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;", + "Strategy" + } + select { + class: "chat-input", + style: "width: 100%; padding: 8px; resize: none; height: auto;", + value: "{new_strategy}", + onchange: move |e| new_strategy.set(e.value()), + option { value: "comprehensive", "Comprehensive" } + option { value: "quick", "Quick Scan" } + option { value: "owasp_top_10", "OWASP Top 10" } + option { value: "api_focused", "API Focused" } + option { value: "authentication", "Authentication" } + } + } + + // Initial message + div { style: "margin-bottom: 16px;", + label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;", + "Initial Instructions" + } + textarea { + class: "chat-input", + style: "width: 100%; min-height: 80px;", + placeholder: "Describe the scope and goals of this pentest...", + value: "{new_message}", + oninput: move |e| new_message.set(e.value()), + } + } + + div { style: "display: flex; justify-content: flex-end; gap: 8px;", + button { + class: "btn btn-ghost", + onclick: move |_| show_modal.set(false), + "Cancel" + } + button { + class: "btn btn-primary", + disabled: *creating.read() || new_target_id.read().is_empty() || new_message.read().is_empty(), + onclick: on_create, + if *creating.read() { "Creating..." } else { "Start Pentest" } + } + } + } + } + } + } +} diff --git a/compliance-dashboard/src/pages/pentest_session.rs b/compliance-dashboard/src/pages/pentest_session.rs new file mode 100644 index 0000000..3cf99d1 --- /dev/null +++ b/compliance-dashboard/src/pages/pentest_session.rs @@ -0,0 +1,445 @@ +use dioxus::prelude::*; +use dioxus_free_icons::icons::bs_icons::*; +use dioxus_free_icons::Icon; + +use crate::app::Route; +use crate::infrastructure::pentest::{ + fetch_attack_chain, fetch_pentest_findings, fetch_pentest_messages, fetch_pentest_session, + send_pentest_message, +}; + +#[component] +pub fn PentestSessionPage(session_id: String) -> Element { + let sid = session_id.clone(); + let sid_for_session = session_id.clone(); + let sid_for_findings = session_id.clone(); + let sid_for_chain = session_id.clone(); + + let mut session = use_resource(move || { + let id = sid_for_session.clone(); + async move { fetch_pentest_session(id).await.ok() } + }); + let mut messages_res = use_resource(move || { + let id = sid.clone(); + async move { fetch_pentest_messages(id).await.ok() } + }); + let mut findings = use_resource(move || { + let id = sid_for_findings.clone(); + async move { fetch_pentest_findings(id).await.ok() } + }); + let mut attack_chain = use_resource(move || { + let id = sid_for_chain.clone(); + async move { fetch_attack_chain(id).await.ok() } + }); + + let mut input_text = use_signal(String::new); + let mut sending = use_signal(|| false); + let mut right_tab = use_signal(|| "findings".to_string()); + + // Auto-poll messages every 3s when session is running + let session_status = { + let s = session.read(); + match &*s { + Some(Some(resp)) => resp + .data + .get("status") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(), + _ => "unknown".to_string(), + } + }; + + let is_running = session_status == "running"; + + let sid_for_poll = session_id.clone(); + use_effect(move || { + if is_running { + let _sid = sid_for_poll.clone(); + spawn(async move { + #[cfg(feature = "web")] + gloo_timers::future::TimeoutFuture::new(3_000).await; + #[cfg(not(feature = "web"))] + tokio::time::sleep(std::time::Duration::from_secs(3)).await; + messages_res.restart(); + findings.restart(); + attack_chain.restart(); + session.restart(); + }); + } + }); + + // Send message handler + let sid_for_send = session_id.clone(); + let mut do_send = move || { + let text = input_text.read().trim().to_string(); + if text.is_empty() || *sending.read() { + return; + } + let sid = sid_for_send.clone(); + input_text.set(String::new()); + sending.set(true); + spawn(async move { + let _ = send_pentest_message(sid, text).await; + sending.set(false); + messages_res.restart(); + }); + }; + + let mut do_send_click = do_send.clone(); + + // Session header info + let target_name = { + let s = session.read(); + match &*s { + Some(Some(resp)) => resp + .data + .get("target_name") + .and_then(|v| v.as_str()) + .unwrap_or("Pentest Session") + .to_string(), + _ => "Pentest Session".to_string(), + } + }; + + let strategy = { + let s = session.read(); + match &*s { + Some(Some(resp)) => resp + .data + .get("strategy") + .and_then(|v| v.as_str()) + .unwrap_or("-") + .to_string(), + _ => "-".to_string(), + } + }; + + let header_tool_count = { + let s = session.read(); + match &*s { + Some(Some(resp)) => resp + .data + .get("tool_invocations") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + _ => 0, + } + }; + + let header_findings_count = { + let f = findings.read(); + match &*f { + Some(Some(data)) => data.total.unwrap_or(0), + _ => 0, + } + }; + + let status_style = match session_status.as_str() { + "running" => "background: #16a34a; color: #fff;", + "completed" => "background: #2563eb; color: #fff;", + "failed" => "background: #dc2626; color: #fff;", + "paused" => "background: #d97706; color: #fff;", + _ => "background: var(--bg-tertiary); color: var(--text-secondary);", + }; + + rsx! { + div { class: "back-nav", + Link { + to: Route::PentestDashboardPage {}, + class: "btn btn-ghost btn-back", + Icon { icon: BsArrowLeft, width: 16, height: 16 } + "Back to Pentest Dashboard" + } + } + + // Session header + div { style: "display: flex; align-items: center; justify-content: space-between; margin-bottom: 16px; flex-wrap: wrap; gap: 8px;", + div { + h2 { style: "margin: 0 0 4px 0;", "{target_name}" } + div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;", + span { class: "badge", style: "{status_style}", "{session_status}" } + span { class: "badge", style: "background: var(--bg-tertiary); color: var(--text-secondary);", + "{strategy}" + } + } + } + div { style: "display: flex; gap: 16px; font-size: 0.85rem; color: var(--text-secondary);", + span { + Icon { icon: BsWrench, width: 14, height: 14 } + " {header_tool_count} tools" + } + span { + Icon { icon: BsShieldExclamation, width: 14, height: 14 } + " {header_findings_count} findings" + } + } + } + + // Split layout: chat left, findings/chain right + div { style: "display: grid; grid-template-columns: 1fr 380px; gap: 16px; height: calc(100vh - 220px); min-height: 400px;", + + // Left: Chat area + div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;", + div { class: "card-header", style: "flex-shrink: 0;", "Chat" } + + // Messages + div { + style: "flex: 1; overflow-y: auto; padding: 16px; display: flex; flex-direction: column; gap: 12px;", + match &*messages_res.read() { + Some(Some(data)) => { + let msgs = &data.data; + if msgs.is_empty() { + rsx! { + div { style: "text-align: center; color: var(--text-secondary); padding: 32px;", + h3 { style: "margin-bottom: 8px;", "Start the conversation" } + p { "Send a message to guide the pentest agent." } + } + } + } else { + rsx! { + for (i, msg) in msgs.iter().enumerate() { + { + let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("assistant").to_string(); + let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let msg_type = msg.get("type").and_then(|v| v.as_str()).unwrap_or("text").to_string(); + let tool_name = msg.get("tool_name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let tool_status = msg.get("tool_status").and_then(|v| v.as_str()).unwrap_or("").to_string(); + + if msg_type == "tool_call" || msg_type == "tool_result" { + // Tool invocation indicator + let tool_icon_style = match tool_status.as_str() { + "success" => "color: #16a34a;", + "error" => "color: #dc2626;", + "running" => "color: #d97706;", + _ => "color: var(--text-secondary);", + }; + rsx! { + div { + key: "{i}", + style: "display: flex; align-items: center; gap: 8px; padding: 6px 12px; background: var(--bg-tertiary); border-radius: 6px; font-size: 0.8rem; color: var(--text-secondary);", + span { style: "{tool_icon_style}", + Icon { icon: BsWrench, width: 12, height: 12 } + } + span { style: "font-family: monospace;", "{tool_name}" } + if !tool_status.is_empty() { + span { class: "badge", style: "font-size: 0.7rem;", "{tool_status}" } + } + if !content.is_empty() { + details { style: "margin-left: auto; cursor: pointer;", + summary { style: "font-size: 0.75rem;", "details" } + pre { style: "margin-top: 4px; padding: 8px; background: var(--bg-primary); border-radius: 4px; font-size: 0.75rem; overflow-x: auto; max-height: 200px; white-space: pre-wrap;", + "{content}" + } + } + } + } + } + } else if role == "user" { + // User message - right aligned + rsx! { + div { + key: "{i}", + style: "display: flex; justify-content: flex-end;", + div { + style: "max-width: 80%; padding: 10px 14px; background: #2563eb; color: #fff; border-radius: 12px 12px 2px 12px; font-size: 0.9rem; line-height: 1.5; white-space: pre-wrap;", + "{content}" + } + } + } + } else { + // Assistant message - left aligned + rsx! { + div { + key: "{i}", + style: "display: flex; gap: 8px; align-items: flex-start;", + div { + style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;", + Icon { icon: BsCpu, width: 14, height: 14 } + } + div { + style: "max-width: 80%; padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; line-height: 1.5; white-space: pre-wrap;", + "{content}" + } + } + } + } + } + } + } + } + }, + Some(None) => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Failed to load messages." } }, + None => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Loading messages..." } }, + } + + if *sending.read() { + div { style: "display: flex; gap: 8px; align-items: flex-start;", + div { + style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;", + Icon { icon: BsCpu, width: 14, height: 14 } + } + div { + style: "padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; color: var(--text-secondary);", + "Thinking..." + } + } + } + } + + // Input area + div { style: "flex-shrink: 0; padding: 12px; border-top: 1px solid var(--border-color); display: flex; gap: 8px;", + textarea { + class: "chat-input", + style: "flex: 1;", + placeholder: "Guide the pentest agent...", + value: "{input_text}", + oninput: move |e| input_text.set(e.value()), + onkeydown: move |e: Event| { + if e.key() == Key::Enter && !e.modifiers().shift() { + e.prevent_default(); + do_send(); + } + }, + } + button { + class: "btn btn-primary", + style: "align-self: flex-end;", + disabled: *sending.read(), + onclick: move |_| do_send_click(), + "Send" + } + } + } + + // Right: Findings / Attack Chain tabs + div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;", + // Tab bar + div { style: "display: flex; border-bottom: 1px solid var(--border-color); flex-shrink: 0;", + button { + style: if *right_tab.read() == "findings" { + "flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;" + } else { + "flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;" + }, + onclick: move |_| right_tab.set("findings".to_string()), + Icon { icon: BsShieldExclamation, width: 14, height: 14 } + " Findings ({header_findings_count})" + } + button { + style: if *right_tab.read() == "chain" { + "flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;" + } else { + "flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;" + }, + onclick: move |_| right_tab.set("chain".to_string()), + Icon { icon: BsDiagram3, width: 14, height: 14 } + " Attack Chain" + } + } + + // Tab content + div { style: "flex: 1; overflow-y: auto; padding: 12px;", + if *right_tab.read() == "findings" { + // Findings tab + match &*findings.read() { + Some(Some(data)) => { + let finding_list = &data.data; + if finding_list.is_empty() { + rsx! { + div { style: "text-align: center; color: var(--text-secondary); padding: 24px;", + p { "No findings yet." } + } + } + } else { + rsx! { + div { style: "display: flex; flex-direction: column; gap: 8px;", + for finding in finding_list { + { + let title = finding.get("title").and_then(|v| v.as_str()).unwrap_or("Untitled").to_string(); + let severity = finding.get("severity").and_then(|v| v.as_str()).unwrap_or("info").to_string(); + let vuln_type = finding.get("vulnerability_type").and_then(|v| v.as_str()).unwrap_or("-").to_string(); + let sev_style = match severity.as_str() { + "critical" => "background: #dc2626; color: #fff;", + "high" => "background: #ea580c; color: #fff;", + "medium" => "background: #d97706; color: #fff;", + "low" => "background: #2563eb; color: #fff;", + _ => "background: var(--bg-tertiary); color: var(--text-secondary);", + }; + rsx! { + div { style: "padding: 10px; background: var(--bg-tertiary); border-radius: 8px;", + div { style: "display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;", + span { style: "font-weight: 600; font-size: 0.85rem;", "{title}" } + span { class: "badge", style: "{sev_style}", "{severity}" } + } + div { style: "font-size: 0.8rem; color: var(--text-secondary);", "{vuln_type}" } + } + } + } + } + } + } + } + }, + Some(None) => rsx! { p { style: "color: var(--text-secondary);", "Failed to load findings." } }, + None => rsx! { p { style: "color: var(--text-secondary);", "Loading..." } }, + } + } else { + // Attack chain tab + match &*attack_chain.read() { + Some(Some(data)) => { + let steps = &data.data; + if steps.is_empty() { + rsx! { + div { style: "text-align: center; color: var(--text-secondary); padding: 24px;", + p { "No attack chain steps yet." } + } + } + } else { + rsx! { + div { style: "display: flex; flex-direction: column; gap: 4px;", + for (i, step) in steps.iter().enumerate() { + { + let step_name = step.get("name").and_then(|v| v.as_str()).unwrap_or("Step").to_string(); + let step_status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string(); + let description = step.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let step_num = i + 1; + let dot_color = match step_status.as_str() { + "completed" => "#16a34a", + "running" => "#d97706", + "failed" => "#dc2626", + _ => "var(--text-secondary)", + }; + rsx! { + div { style: "display: flex; gap: 10px; padding: 8px 0;", + div { style: "display: flex; flex-direction: column; align-items: center;", + div { style: "width: 10px; height: 10px; border-radius: 50%; background: {dot_color}; flex-shrink: 0;" } + if i < steps.len() - 1 { + div { style: "width: 2px; flex: 1; background: var(--border-color); margin-top: 4px;" } + } + } + div { + div { style: "font-size: 0.85rem; font-weight: 600;", "{step_num}. {step_name}" } + if !description.is_empty() { + div { style: "font-size: 0.8rem; color: var(--text-secondary); margin-top: 2px;", + "{description}" + } + } + } + } + } + } + } + } + } + } + }, + Some(None) => rsx! { p { style: "color: var(--text-secondary);", "Failed to load attack chain." } }, + None => rsx! { p { style: "color: var(--text-secondary);", "Loading..." } }, + } + } + } + } + } + } +} diff --git a/compliance-dast/Cargo.toml b/compliance-dast/Cargo.toml index 4613e24..e91ef99 100644 --- a/compliance-dast/Cargo.toml +++ b/compliance-dast/Cargo.toml @@ -27,6 +27,10 @@ chromiumoxide = { version = "0.7", features = ["tokio-runtime"], default-feature # Docker sandboxing bollard = "0.18" +# TLS analysis +native-tls = "0.2" +tokio-native-tls = "0.3" + # Serialization bson = { version = "2", features = ["chrono-0_4"] } url = "2" diff --git a/compliance-dast/src/lib.rs b/compliance-dast/src/lib.rs index 38c2445..aefc932 100644 --- a/compliance-dast/src/lib.rs +++ b/compliance-dast/src/lib.rs @@ -2,5 +2,7 @@ pub mod agents; pub mod crawler; pub mod orchestrator; pub mod recon; +pub mod tools; pub use orchestrator::DastOrchestrator; +pub use tools::ToolRegistry; diff --git a/compliance-dast/src/tools/api_fuzzer.rs b/compliance-dast/src/tools/api_fuzzer.rs new file mode 100644 index 0000000..39d3f7e --- /dev/null +++ b/compliance-dast/src/tools/api_fuzzer.rs @@ -0,0 +1,146 @@ +use compliance_core::error::CoreError; +use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; +use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; +use serde_json::json; + +use crate::agents::api_fuzzer::ApiFuzzerAgent; + +/// PentestTool wrapper around the existing ApiFuzzerAgent. +pub struct ApiFuzzerTool { + http: reqwest::Client, + agent: ApiFuzzerAgent, +} + +impl ApiFuzzerTool { + pub fn new(http: reqwest::Client) -> Self { + let agent = ApiFuzzerAgent::new(http.clone()); + Self { http, agent } + } + + fn parse_endpoints(input: &serde_json::Value) -> Vec { + let mut endpoints = Vec::new(); + if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { + for ep in arr { + let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string(); + let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string(); + let mut parameters = Vec::new(); + if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { + for p in params { + parameters.push(EndpointParameter { + name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(), + location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(), + param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from), + example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from), + }); + } + } + endpoints.push(DiscoveredEndpoint { + url, + method, + parameters, + content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from), + requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false), + }); + } + } + endpoints + } +} + +impl PentestTool for ApiFuzzerTool { + fn name(&self) -> &str { + "api_fuzzer" + } + + fn description(&self) -> &str { + "Fuzzes API endpoints to discover misconfigurations, information disclosure, and hidden \ + endpoints. Probes common sensitive paths and tests for verbose error messages." + } + + fn input_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "endpoints": { + "type": "array", + "description": "Known endpoints to fuzz", + "items": { + "type": "object", + "properties": { + "url": { "type": "string" }, + "method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] }, + "parameters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "location": { "type": "string" }, + "param_type": { "type": "string" }, + "example_value": { "type": "string" } + }, + "required": ["name"] + } + } + }, + "required": ["url"] + } + }, + "base_url": { + "type": "string", + "description": "Base URL to probe for common sensitive paths (used if no endpoints provided)" + } + } + }) + } + + fn execute<'a>( + &'a self, + input: serde_json::Value, + context: &'a PentestToolContext, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let mut endpoints = Self::parse_endpoints(&input); + + // If a base_url is provided but no endpoints, create a default endpoint + if endpoints.is_empty() { + if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) { + endpoints.push(DiscoveredEndpoint { + url: base.to_string(), + method: "GET".to_string(), + parameters: Vec::new(), + content_type: None, + requires_auth: false, + }); + } + } + + if endpoints.is_empty() { + return Ok(PentestToolResult { + summary: "No endpoints or base_url provided to fuzz.".to_string(), + findings: Vec::new(), + data: json!({}), + }); + } + + let dast_context = DastContext { + endpoints, + technologies: Vec::new(), + sast_hints: Vec::new(), + }; + + let findings = self.agent.run(&context.target, &dast_context).await?; + let count = findings.len(); + + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} API misconfigurations or information disclosures.") + } else { + "No API misconfigurations detected.".to_string() + }, + findings, + data: json!({ "endpoints_tested": dast_context.endpoints.len() }), + }) + }) + } +} diff --git a/compliance-dast/src/tools/auth_bypass.rs b/compliance-dast/src/tools/auth_bypass.rs new file mode 100644 index 0000000..6f0bf87 --- /dev/null +++ b/compliance-dast/src/tools/auth_bypass.rs @@ -0,0 +1,130 @@ +use compliance_core::error::CoreError; +use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter}; +use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; +use serde_json::json; + +use crate::agents::auth_bypass::AuthBypassAgent; + +/// PentestTool wrapper around the existing AuthBypassAgent. +pub struct AuthBypassTool { + http: reqwest::Client, + agent: AuthBypassAgent, +} + +impl AuthBypassTool { + pub fn new(http: reqwest::Client) -> Self { + let agent = AuthBypassAgent::new(http.clone()); + Self { http, agent } + } + + fn parse_endpoints(input: &serde_json::Value) -> Vec { + let mut endpoints = Vec::new(); + if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { + for ep in arr { + let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string(); + let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string(); + let mut parameters = Vec::new(); + if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { + for p in params { + parameters.push(EndpointParameter { + name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(), + location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(), + param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from), + example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from), + }); + } + } + endpoints.push(DiscoveredEndpoint { + url, + method, + parameters, + content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from), + requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false), + }); + } + } + endpoints + } +} + +impl PentestTool for AuthBypassTool { + fn name(&self) -> &str { + "auth_bypass_scanner" + } + + fn description(&self) -> &str { + "Tests endpoints for authentication bypass vulnerabilities. Tries accessing protected \ + endpoints without credentials, with manipulated tokens, and with common default credentials." + } + + fn input_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "endpoints": { + "type": "array", + "description": "Endpoints to test for authentication bypass", + "items": { + "type": "object", + "properties": { + "url": { "type": "string" }, + "method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] }, + "parameters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "location": { "type": "string" }, + "param_type": { "type": "string" }, + "example_value": { "type": "string" } + }, + "required": ["name"] + } + }, + "requires_auth": { "type": "boolean", "description": "Whether this endpoint requires authentication" } + }, + "required": ["url", "method"] + } + } + }, + "required": ["endpoints"] + }) + } + + fn execute<'a>( + &'a self, + input: serde_json::Value, + context: &'a PentestToolContext, + ) -> std::pin::Pin> + Send + 'a>> { + Box::pin(async move { + let endpoints = Self::parse_endpoints(&input); + if endpoints.is_empty() { + return Ok(PentestToolResult { + summary: "No endpoints provided to test.".to_string(), + findings: Vec::new(), + data: json!({}), + }); + } + + let dast_context = DastContext { + endpoints, + technologies: Vec::new(), + sast_hints: Vec::new(), + }; + + let findings = self.agent.run(&context.target, &dast_context).await?; + let count = findings.len(); + + Ok(PentestToolResult { + summary: if count > 0 { + format!("Found {count} authentication bypass vulnerabilities.") + } else { + "No authentication bypass vulnerabilities detected.".to_string() + }, + findings, + data: json!({ "endpoints_tested": dast_context.endpoints.len() }), + }) + }) + } +} diff --git a/compliance-dast/src/tools/console_log_detector.rs b/compliance-dast/src/tools/console_log_detector.rs new file mode 100644 index 0000000..c1bc9cd --- /dev/null +++ b/compliance-dast/src/tools/console_log_detector.rs @@ -0,0 +1,326 @@ +use compliance_core::error::CoreError; +use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType}; +use compliance_core::models::Severity; +use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; +use serde_json::json; +use tracing::info; + +/// Tool that detects console.log and similar debug statements in frontend JavaScript. +pub struct ConsoleLogDetectorTool { + http: reqwest::Client, +} + +/// A detected console statement with its context. +#[derive(Debug)] +struct ConsoleMatch { + pattern: String, + file_url: String, + line_snippet: String, + line_number: Option, +} + +impl ConsoleLogDetectorTool { + pub fn new(http: reqwest::Client) -> Self { + Self { http } + } + + /// Patterns that indicate debug/logging statements left in production code. + fn patterns() -> Vec<&'static str> { + vec![ + "console.log(", + "console.debug(", + "console.error(", + "console.warn(", + "console.info(", + "console.trace(", + "console.dir(", + "console.table(", + "debugger;", + "alert(", + ] + } + + /// Extract JavaScript file URLs from an HTML page body. + fn extract_js_urls(html: &str, base_url: &str) -> Vec { + let mut urls = Vec::new(); + let base = url::Url::parse(base_url).ok(); + + // Simple regex-free extraction of