feat: add pentest MCP tools, session timeout, and error recovery

Add five MCP tools for querying pentest sessions, attack chains, messages,
and stats. Add a 30-minute session timeout and automatic failure marking
via the run_session_guarded wrapper.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-03-11 19:49:39 +01:00
parent 03d8e16e13
commit ad9036e5ad
6 changed files with 393 additions and 21 deletions

View File

@@ -102,18 +102,9 @@ pub async fn create_session(
let target_clone = target.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator
.run_session(&session_clone, &target_clone, &initial_message)
.await
{
tracing::error!(
"Pentest orchestrator failed for session {}: {e}",
session_clone
.id
.map(|oid| oid.to_hex())
.unwrap_or_default()
);
}
orchestrator
.run_session_guarded(&session_clone, &target_clone, &initial_message)
.await;
});
Ok(Json(ApiResponse {
@@ -254,9 +245,9 @@ pub async fn send_message(
let message = req.message.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator.run_session(&session, &target, &message).await {
tracing::error!("Pentest orchestrator failed for session {session_id}: {e}");
}
orchestrator
.run_session_guarded(&session, &target, &message)
.await;
});
Ok(Json(ApiResponse {
@@ -474,7 +465,6 @@ pub async fn pentest_stats(
};
// Severity distribution from pentest-related DAST findings
let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } };
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
@@ -501,8 +491,6 @@ pub async fn pentest_stats(
.await
.unwrap_or(0) as u32;
let _ = pentest_filter; // used above inline
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use mongodb::bson::doc;
@@ -17,6 +18,9 @@ use crate::llm::client::{
};
use crate::llm::LlmClient;
/// Maximum duration for a single pentest session before timeout
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
@@ -43,7 +47,65 @@ impl PentestOrchestrator {
self.event_tx.clone()
}
pub async fn run_session(
/// Run a pentest session with timeout and automatic failure marking on errors.
pub async fn run_session_guarded(
&self,
session: &PentestSession,
target: &DastTarget,
initial_message: &str,
) {
let session_id = session.id;
match tokio::time::timeout(
SESSION_TIMEOUT,
self.run_session(session, target, initial_message),
)
.await
{
Ok(Ok(())) => {
tracing::info!(?session_id, "Pentest session completed successfully");
}
Ok(Err(e)) => {
tracing::error!(?session_id, error = %e, "Pentest session failed");
self.mark_session_failed(session_id, &format!("Error: {e}"))
.await;
let _ = self.event_tx.send(PentestEvent::Error {
message: format!("Session failed: {e}"),
});
}
Err(_) => {
tracing::warn!(?session_id, "Pentest session timed out after 30 minutes");
self.mark_session_failed(session_id, "Session timed out after 30 minutes")
.await;
let _ = self.event_tx.send(PentestEvent::Error {
message: "Session timed out after 30 minutes".to_string(),
});
}
}
}
async fn mark_session_failed(
&self,
session_id: Option<mongodb::bson::oid::ObjectId>,
reason: &str,
) {
if let Some(sid) = session_id {
let _ = self
.db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
doc! { "$set": {
"status": "failed",
"completed_at": mongodb::bson::DateTime::now(),
"error_message": reason,
}},
)
.await;
}
}
async fn run_session(
&self,
session: &PentestSession,
target: &DastTarget,

View File

@@ -31,4 +31,16 @@ impl Database {
/// Returns a typed handle to the `dast_scan_runs` collection.
pub fn dast_scan_runs(&self) -> Collection<DastScanRun> {
self.inner.collection("dast_scan_runs")
}
/// Returns a typed handle to the `pentest_sessions` collection.
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
self.inner.collection("pentest_sessions")
}
/// Returns a typed handle to the `attack_chain_nodes` collection.
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
self.inner.collection("attack_chain_nodes")
}
/// Returns a typed handle to the `pentest_messages` collection.
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
self.inner.collection("pentest_messages")
}
}

View File

@@ -3,7 +3,7 @@ use rmcp::{
};
use crate::database::Database;
use crate::tools::{dast, findings, sbom};
use crate::tools::{dast, findings, pentest, sbom};
pub struct ComplianceMcpServer {
db: Database,
@@ -89,6 +89,54 @@ impl ComplianceMcpServer {
) -> Result<CallToolResult, rmcp::ErrorData> {
dast::dast_scan_summary(&self.db, params).await
}
// ── Pentest ─────────────────────────────────────────────
// MCP tool entry point: unwraps the parsed parameters and forwards them,
// with this server's database handle, to `pentest::list_pentest_sessions`.
#[tool(
description = "List AI pentest sessions with optional filters for target, status, and strategy"
)]
async fn list_pentest_sessions(
&self,
Parameters(params): Parameters<pentest::ListPentestSessionsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
pentest::list_pentest_sessions(&self.db, params).await
}
// MCP tool entry point: unwraps the parsed parameters and forwards them,
// with this server's database handle, to `pentest::get_pentest_session`.
#[tool(description = "Get a single AI pentest session by its ID")]
async fn get_pentest_session(
&self,
Parameters(params): Parameters<pentest::GetPentestSessionParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
pentest::get_pentest_session(&self.db, params).await
}
// MCP tool entry point: unwraps the parsed parameters and forwards them,
// with this server's database handle, to `pentest::get_attack_chain`.
#[tool(
description = "Get the attack chain DAG for a pentest session showing each tool invocation, its reasoning, and results"
)]
async fn get_attack_chain(
&self,
Parameters(params): Parameters<pentest::GetAttackChainParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
pentest::get_attack_chain(&self.db, params).await
}
// MCP tool entry point: unwraps the parsed parameters and forwards them,
// with this server's database handle, to `pentest::get_pentest_messages`.
#[tool(description = "Get chat messages from a pentest session")]
async fn get_pentest_messages(
&self,
Parameters(params): Parameters<pentest::GetPentestMessagesParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
pentest::get_pentest_messages(&self.db, params).await
}
// MCP tool entry point: unwraps the parsed parameters and forwards them,
// with this server's database handle, to `pentest::pentest_stats`.
#[tool(
description = "Get aggregated pentest statistics including running sessions, vulnerability counts, and severity distribution"
)]
async fn pentest_stats(
&self,
Parameters(params): Parameters<pentest::PentestStatsParams>,
) -> Result<CallToolResult, rmcp::ErrorData> {
pentest::pentest_stats(&self.db, params).await
}
}
#[tool_handler]
@@ -101,7 +149,7 @@ impl ServerHandler for ComplianceMcpServer {
.build(),
server_info: Implementation::from_build_env(),
instructions: Some(
"Compliance Scanner MCP server. Query security findings, SBOM data, and DAST results."
"Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions."
.to_string(),
),
}

View File

@@ -1,3 +1,4 @@
pub mod dast;
pub mod findings;
pub mod pentest;
pub mod sbom;

View File

@@ -0,0 +1,261 @@
use mongodb::bson::doc;
use rmcp::{model::*, ErrorData as McpError};
use schemars::JsonSchema;
use serde::Deserialize;
use crate::database::Database;
/// Largest number of documents a single tool call may request.
const MAX_LIMIT: i64 = 200;
/// Number of documents returned when the caller supplies no limit.
const DEFAULT_LIMIT: i64 = 50;

/// Resolves a caller-supplied limit to a value within `1..=MAX_LIMIT`.
///
/// `None` falls back to `DEFAULT_LIMIT`; anything below 1 becomes 1 and
/// anything above `MAX_LIMIT` becomes `MAX_LIMIT`.
fn cap_limit(limit: Option<i64>) -> i64 {
    let requested = limit.unwrap_or(DEFAULT_LIMIT);
    requested.clamp(1, MAX_LIMIT)
}
// ── List Pentest Sessions ──────────────────────────────────────
// Input parameters for `list_pentest_sessions`; every field is optional.
// NOTE(review): `JsonSchema` is derived here, so the field doc comments below
// presumably end up in the schema shown to MCP clients — confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct ListPentestSessionsParams {
/// Filter by target ID
pub target_id: Option<String>,
/// Filter by status: running, paused, completed, failed
pub status: Option<String>,
/// Filter by strategy: quick, comprehensive, targeted, aggressive, stealth
pub strategy: Option<String>,
/// Maximum number of results (default 50, max 200)
pub limit: Option<i64>,
}
pub async fn list_pentest_sessions(
db: &Database,
params: ListPentestSessionsParams,
) -> Result<CallToolResult, McpError> {
let mut filter = doc! {};
if let Some(ref target_id) = params.target_id {
filter.insert("target_id", target_id);
}
if let Some(ref status) = params.status {
filter.insert("status", status);
}
if let Some(ref strategy) = params.strategy {
filter.insert("strategy", strategy);
}
let limit = cap_limit(params.limit);
let mut cursor = db
.pentest_sessions()
.find(filter)
.sort(doc! { "started_at": -1 })
.limit(limit)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
let mut results = Vec::new();
while cursor
.advance()
.await
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
{
let session = cursor
.deserialize_current()
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
results.push(session);
}
let json = serde_json::to_string_pretty(&results)
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(json)]))
}
// ── Get Pentest Session ────────────────────────────────────────
// Input parameters for `get_pentest_session`.
// NOTE(review): `JsonSchema` is derived here, so the field doc comment below
// presumably ends up in the schema shown to MCP clients — confirm before rewording it.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct GetPentestSessionParams {
/// Pentest session ID (MongoDB ObjectId hex string)
pub id: String,
}
pub async fn get_pentest_session(
db: &Database,
params: GetPentestSessionParams,
) -> Result<CallToolResult, McpError> {
let oid = bson::oid::ObjectId::parse_str(&params.id)
.map_err(|e| McpError::invalid_params(format!("invalid id: {e}"), None))?;
let session = db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?
.ok_or_else(|| McpError::invalid_params("session not found", None))?;
let json = serde_json::to_string_pretty(&session)
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(json)]))
}
// ── Get Attack Chain ───────────────────────────────────────────
// Input parameters for `get_attack_chain`.
// NOTE(review): `JsonSchema` is derived here, so the field doc comments below
// presumably end up in the schema shown to MCP clients — confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct GetAttackChainParams {
/// Pentest session ID to get the attack chain for
pub session_id: String,
/// Maximum number of nodes (default 50, max 200)
pub limit: Option<i64>,
}
pub async fn get_attack_chain(
db: &Database,
params: GetAttackChainParams,
) -> Result<CallToolResult, McpError> {
let limit = cap_limit(params.limit);
let mut cursor = db
.attack_chain_nodes()
.find(doc! { "session_id": &params.session_id })
.sort(doc! { "started_at": 1 })
.limit(limit)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
let mut results = Vec::new();
while cursor
.advance()
.await
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
{
let node = cursor
.deserialize_current()
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
results.push(node);
}
let json = serde_json::to_string_pretty(&results)
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(json)]))
}
// ── Get Pentest Messages ───────────────────────────────────────
// Input parameters for `get_pentest_messages`.
// NOTE(review): `JsonSchema` is derived here, so the field doc comments below
// presumably end up in the schema shown to MCP clients — confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct GetPentestMessagesParams {
/// Pentest session ID
pub session_id: String,
/// Maximum number of messages (default 50, max 200)
pub limit: Option<i64>,
}
pub async fn get_pentest_messages(
db: &Database,
params: GetPentestMessagesParams,
) -> Result<CallToolResult, McpError> {
let limit = cap_limit(params.limit);
let mut cursor = db
.pentest_messages()
.find(doc! { "session_id": &params.session_id })
.sort(doc! { "created_at": 1 })
.limit(limit)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
let mut results = Vec::new();
while cursor
.advance()
.await
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
{
let msg = cursor
.deserialize_current()
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
results.push(msg);
}
let json = serde_json::to_string_pretty(&results)
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(json)]))
}
// ── Pentest Stats ──────────────────────────────────────────────
// Input parameters for `pentest_stats`; the single field is optional.
// NOTE(review): `JsonSchema` is derived here, so the field doc comment below
// presumably ends up in the schema shown to MCP clients — confirm before rewording it.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct PentestStatsParams {
/// Filter stats by target ID
pub target_id: Option<String>,
}
pub async fn pentest_stats(
db: &Database,
params: PentestStatsParams,
) -> Result<CallToolResult, McpError> {
let mut base_filter = doc! {};
if let Some(ref target_id) = params.target_id {
base_filter.insert("target_id", target_id);
}
// Count running sessions
let mut running_filter = base_filter.clone();
running_filter.insert("status", "running");
let running = db
.pentest_sessions()
.count_documents(running_filter)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
// Count total sessions
let total_sessions = db
.pentest_sessions()
.count_documents(base_filter.clone())
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
// Get findings for these sessions — query DAST findings with session_id set
let mut findings_filter = doc! { "session_id": { "$ne": null } };
if let Some(ref target_id) = params.target_id {
findings_filter.insert("target_id", target_id);
}
let total_findings = db
.dast_findings()
.count_documents(findings_filter.clone())
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
let mut exploitable_filter = findings_filter.clone();
exploitable_filter.insert("exploitable", true);
let exploitable = db
.dast_findings()
.count_documents(exploitable_filter)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
// Severity counts
let mut severity = serde_json::Map::new();
for sev in ["critical", "high", "medium", "low", "info"] {
let mut sf = findings_filter.clone();
sf.insert("severity", sev);
let count = db
.dast_findings()
.count_documents(sf)
.await
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
severity.insert(sev.to_string(), serde_json::json!(count));
}
let summary = serde_json::json!({
"running_sessions": running,
"total_sessions": total_sessions,
"total_findings": total_findings,
"exploitable_findings": exploitable,
"severity_distribution": severity,
});
let json = serde_json::to_string_pretty(&summary)
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
Ok(CallToolResult::success(vec![Content::text(json)]))
}