From ad9036e5ad4460a2203e82d4cd322b722c9d4f10 Mon Sep 17 00:00:00 2001 From: Sharang Parnerkar Date: Wed, 11 Mar 2026 19:49:39 +0100 Subject: [PATCH] feat: add pentest MCP tools, session timeout, and error recovery Add 5 MCP tools for querying pentest sessions, attack chains, messages, and stats. Add session timeout (30min) and automatic failure marking with run_session_guarded wrapper. Co-Authored-By: Claude Opus 4.6 --- compliance-agent/src/api/handlers/pentest.rs | 24 +- compliance-agent/src/pentest/orchestrator.rs | 64 ++++- compliance-mcp/src/database.rs | 12 + compliance-mcp/src/server.rs | 52 +++- compliance-mcp/src/tools/mod.rs | 1 + compliance-mcp/src/tools/pentest.rs | 261 +++++++++++++++++++ 6 files changed, 393 insertions(+), 21 deletions(-) create mode 100644 compliance-mcp/src/tools/pentest.rs diff --git a/compliance-agent/src/api/handlers/pentest.rs b/compliance-agent/src/api/handlers/pentest.rs index ae8014b..020c2e6 100644 --- a/compliance-agent/src/api/handlers/pentest.rs +++ b/compliance-agent/src/api/handlers/pentest.rs @@ -102,18 +102,9 @@ pub async fn create_session( let target_clone = target.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); - if let Err(e) = orchestrator - .run_session(&session_clone, &target_clone, &initial_message) - .await - { - tracing::error!( - "Pentest orchestrator failed for session {}: {e}", - session_clone - .id - .map(|oid| oid.to_hex()) - .unwrap_or_default() - ); - } + orchestrator + .run_session_guarded(&session_clone, &target_clone, &initial_message) + .await; }); Ok(Json(ApiResponse { @@ -254,9 +245,9 @@ pub async fn send_message( let message = req.message.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); - if let Err(e) = orchestrator.run_session(&session, &target, &message).await { - tracing::error!("Pentest orchestrator failed for session {session_id}: {e}"); - } + orchestrator + .run_session_guarded(&session, &target, &message) + 
.await; }); Ok(Json(ApiResponse { @@ -474,7 +465,6 @@ pub async fn pentest_stats( }; // Severity distribution from pentest-related DAST findings - let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } }; let critical = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }) @@ -501,8 +491,6 @@ pub async fn pentest_stats( .await .unwrap_or(0) as u32; - let _ = pentest_filter; // used above inline - Ok(Json(ApiResponse { data: PentestStats { running_sessions, diff --git a/compliance-agent/src/pentest/orchestrator.rs b/compliance-agent/src/pentest/orchestrator.rs index 8da2eef..4fd9e2f 100644 --- a/compliance-agent/src/pentest/orchestrator.rs +++ b/compliance-agent/src/pentest/orchestrator.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use futures_util::StreamExt; use mongodb::bson::doc; @@ -17,6 +18,9 @@ use crate::llm::client::{ }; use crate::llm::LlmClient; +/// Maximum duration for a single pentest session before timeout +const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes + pub struct PentestOrchestrator { tool_registry: ToolRegistry, llm: Arc, @@ -43,7 +47,65 @@ impl PentestOrchestrator { self.event_tx.clone() } - pub async fn run_session( + /// Run a pentest session with timeout and automatic failure marking on errors. 
+ pub async fn run_session_guarded( + &self, + session: &PentestSession, + target: &DastTarget, + initial_message: &str, + ) { + let session_id = session.id; + + match tokio::time::timeout( + SESSION_TIMEOUT, + self.run_session(session, target, initial_message), + ) + .await + { + Ok(Ok(())) => { + tracing::info!(?session_id, "Pentest session completed successfully"); + } + Ok(Err(e)) => { + tracing::error!(?session_id, error = %e, "Pentest session failed"); + self.mark_session_failed(session_id, &format!("Error: {e}")) + .await; + let _ = self.event_tx.send(PentestEvent::Error { + message: format!("Session failed: {e}"), + }); + } + Err(_) => { + tracing::warn!(?session_id, "Pentest session timed out after 30 minutes"); + self.mark_session_failed(session_id, "Session timed out after 30 minutes") + .await; + let _ = self.event_tx.send(PentestEvent::Error { + message: "Session timed out after 30 minutes".to_string(), + }); + } + } + } + + async fn mark_session_failed( + &self, + session_id: Option, + reason: &str, + ) { + if let Some(sid) = session_id { + let _ = self + .db + .pentest_sessions() + .update_one( + doc! { "_id": sid }, + doc! 
{ "$set": { + "status": "failed", + "completed_at": mongodb::bson::DateTime::now(), + "error_message": reason, + }}, + ) + .await; + } + } + + async fn run_session( &self, session: &PentestSession, target: &DastTarget, diff --git a/compliance-mcp/src/database.rs b/compliance-mcp/src/database.rs index 2d4e6c9..041d151 100644 --- a/compliance-mcp/src/database.rs +++ b/compliance-mcp/src/database.rs @@ -31,4 +31,16 @@ impl Database { pub fn dast_scan_runs(&self) -> Collection { self.inner.collection("dast_scan_runs") } + + pub fn pentest_sessions(&self) -> Collection { + self.inner.collection("pentest_sessions") + } + + pub fn attack_chain_nodes(&self) -> Collection { + self.inner.collection("attack_chain_nodes") + } + + pub fn pentest_messages(&self) -> Collection { + self.inner.collection("pentest_messages") + } } diff --git a/compliance-mcp/src/server.rs b/compliance-mcp/src/server.rs index 93ee55c..d40ba82 100644 --- a/compliance-mcp/src/server.rs +++ b/compliance-mcp/src/server.rs @@ -3,7 +3,7 @@ use rmcp::{ }; use crate::database::Database; -use crate::tools::{dast, findings, sbom}; +use crate::tools::{dast, findings, pentest, sbom}; pub struct ComplianceMcpServer { db: Database, @@ -89,6 +89,54 @@ impl ComplianceMcpServer { ) -> Result { dast::dast_scan_summary(&self.db, params).await } + + // ── Pentest ───────────────────────────────────────────── + + #[tool( + description = "List AI pentest sessions with optional filters for target, status, and strategy" + )] + async fn list_pentest_sessions( + &self, + Parameters(params): Parameters, + ) -> Result { + pentest::list_pentest_sessions(&self.db, params).await + } + + #[tool(description = "Get a single AI pentest session by its ID")] + async fn get_pentest_session( + &self, + Parameters(params): Parameters, + ) -> Result { + pentest::get_pentest_session(&self.db, params).await + } + + #[tool( + description = "Get the attack chain DAG for a pentest session showing each tool invocation, its reasoning, and results" 
+ )] + async fn get_attack_chain( + &self, + Parameters(params): Parameters, + ) -> Result { + pentest::get_attack_chain(&self.db, params).await + } + + #[tool(description = "Get chat messages from a pentest session")] + async fn get_pentest_messages( + &self, + Parameters(params): Parameters, + ) -> Result { + pentest::get_pentest_messages(&self.db, params).await + } + + #[tool( + description = "Get aggregated pentest statistics including running sessions, vulnerability counts, and severity distribution" + )] + async fn pentest_stats( + &self, + Parameters(params): Parameters, + ) -> Result { + pentest::pentest_stats(&self.db, params).await + } } #[tool_handler] @@ -101,7 +149,7 @@ impl ServerHandler for ComplianceMcpServer { .build(), server_info: Implementation::from_build_env(), instructions: Some( - "Compliance Scanner MCP server. Query security findings, SBOM data, and DAST results." + "Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions." 
.to_string(), ), } diff --git a/compliance-mcp/src/tools/mod.rs b/compliance-mcp/src/tools/mod.rs index cf383fc..373d1a3 100644 --- a/compliance-mcp/src/tools/mod.rs +++ b/compliance-mcp/src/tools/mod.rs @@ -1,3 +1,4 @@ pub mod dast; pub mod findings; +pub mod pentest; pub mod sbom; diff --git a/compliance-mcp/src/tools/pentest.rs b/compliance-mcp/src/tools/pentest.rs new file mode 100644 index 0000000..f6c7db9 --- /dev/null +++ b/compliance-mcp/src/tools/pentest.rs @@ -0,0 +1,261 @@ +use mongodb::bson::doc; +use rmcp::{model::*, ErrorData as McpError}; +use schemars::JsonSchema; +use serde::Deserialize; + +use crate::database::Database; + +const MAX_LIMIT: i64 = 200; +const DEFAULT_LIMIT: i64 = 50; + +fn cap_limit(limit: Option) -> i64 { + limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT) +} + +// ── List Pentest Sessions ────────────────────────────────────── + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct ListPentestSessionsParams { + /// Filter by target ID + pub target_id: Option, + /// Filter by status: running, paused, completed, failed + pub status: Option, + /// Filter by strategy: quick, comprehensive, targeted, aggressive, stealth + pub strategy: Option, + /// Maximum number of results (default 50, max 200) + pub limit: Option, +} + +pub async fn list_pentest_sessions( + db: &Database, + params: ListPentestSessionsParams, +) -> Result { + let mut filter = doc! {}; + if let Some(ref target_id) = params.target_id { + filter.insert("target_id", target_id); + } + if let Some(ref status) = params.status { + filter.insert("status", status); + } + if let Some(ref strategy) = params.strategy { + filter.insert("strategy", strategy); + } + + let limit = cap_limit(params.limit); + + let mut cursor = db + .pentest_sessions() + .find(filter) + .sort(doc! 
{ "started_at": -1 }) + .limit(limit) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + let mut results = Vec::new(); + while cursor + .advance() + .await + .map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))? + { + let session = cursor + .deserialize_current() + .map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?; + results.push(session); + } + + let json = serde_json::to_string_pretty(&results) + .map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) +} + +// ── Get Pentest Session ──────────────────────────────────────── + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct GetPentestSessionParams { + /// Pentest session ID (MongoDB ObjectId hex string) + pub id: String, +} + +pub async fn get_pentest_session( + db: &Database, + params: GetPentestSessionParams, +) -> Result { + let oid = bson::oid::ObjectId::parse_str(&params.id) + .map_err(|e| McpError::invalid_params(format!("invalid id: {e}"), None))?; + + let session = db + .pentest_sessions() + .find_one(doc! { "_id": oid }) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))? 
+ .ok_or_else(|| McpError::invalid_params("session not found", None))?; + + let json = serde_json::to_string_pretty(&session) + .map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) +} + +// ── Get Attack Chain ─────────────────────────────────────────── + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct GetAttackChainParams { + /// Pentest session ID to get the attack chain for + pub session_id: String, + /// Maximum number of nodes (default 50, max 200) + pub limit: Option, +} + +pub async fn get_attack_chain( + db: &Database, + params: GetAttackChainParams, +) -> Result { + let limit = cap_limit(params.limit); + + let mut cursor = db + .attack_chain_nodes() + .find(doc! { "session_id": &params.session_id }) + .sort(doc! { "started_at": 1 }) + .limit(limit) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + let mut results = Vec::new(); + while cursor + .advance() + .await + .map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))? + { + let node = cursor + .deserialize_current() + .map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?; + results.push(node); + } + + let json = serde_json::to_string_pretty(&results) + .map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) +} + +// ── Get Pentest Messages ─────────────────────────────────────── + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct GetPentestMessagesParams { + /// Pentest session ID + pub session_id: String, + /// Maximum number of messages (default 50, max 200) + pub limit: Option, +} + +pub async fn get_pentest_messages( + db: &Database, + params: GetPentestMessagesParams, +) -> Result { + let limit = cap_limit(params.limit); + + let mut cursor = db + .pentest_messages() + .find(doc! { "session_id": &params.session_id }) + .sort(doc! 
{ "created_at": 1 }) + .limit(limit) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + let mut results = Vec::new(); + while cursor + .advance() + .await + .map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))? + { + let msg = cursor + .deserialize_current() + .map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?; + results.push(msg); + } + + let json = serde_json::to_string_pretty(&results) + .map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) +} + +// ── Pentest Stats ────────────────────────────────────────────── + +#[derive(Debug, Deserialize, JsonSchema)] +pub struct PentestStatsParams { + /// Filter stats by target ID + pub target_id: Option, +} + +pub async fn pentest_stats( + db: &Database, + params: PentestStatsParams, +) -> Result { + let mut base_filter = doc! {}; + if let Some(ref target_id) = params.target_id { + base_filter.insert("target_id", target_id); + } + + // Count running sessions + let mut running_filter = base_filter.clone(); + running_filter.insert("status", "running"); + let running = db + .pentest_sessions() + .count_documents(running_filter) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + // Count total sessions + let total_sessions = db + .pentest_sessions() + .count_documents(base_filter.clone()) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + // Get findings for these sessions — query DAST findings with session_id set + let mut findings_filter = doc! 
{ "session_id": { "$ne": null } }; + if let Some(ref target_id) = params.target_id { + findings_filter.insert("target_id", target_id); + } + let total_findings = db + .dast_findings() + .count_documents(findings_filter.clone()) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + let mut exploitable_filter = findings_filter.clone(); + exploitable_filter.insert("exploitable", true); + let exploitable = db + .dast_findings() + .count_documents(exploitable_filter) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + + // Severity counts + let mut severity = serde_json::Map::new(); + for sev in ["critical", "high", "medium", "low", "info"] { + let mut sf = findings_filter.clone(); + sf.insert("severity", sev); + let count = db + .dast_findings() + .count_documents(sf) + .await + .map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?; + severity.insert(sev.to_string(), serde_json::json!(count)); + } + + let summary = serde_json::json!({ + "running_sessions": running, + "total_sessions": total_sessions, + "total_findings": total_findings, + "exploitable_findings": exploitable, + "severity_distribution": severity, + }); + + let json = serde_json::to_string_pretty(&summary) + .map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?; + + Ok(CallToolResult::success(vec![Content::text(json)])) +}