feat: add pentest MCP tools, session timeout, and error recovery
Add 5 MCP tools for querying pentest sessions, attack chains, messages, and stats. Add session timeout (30min) and automatic failure marking with run_session_guarded wrapper. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -102,18 +102,9 @@ pub async fn create_session(
|
||||
let target_clone = target.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db);
|
||||
if let Err(e) = orchestrator
|
||||
.run_session(&session_clone, &target_clone, &initial_message)
|
||||
.await
|
||||
{
|
||||
tracing::error!(
|
||||
"Pentest orchestrator failed for session {}: {e}",
|
||||
session_clone
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default()
|
||||
);
|
||||
}
|
||||
orchestrator
|
||||
.run_session_guarded(&session_clone, &target_clone, &initial_message)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
@@ -254,9 +245,9 @@ pub async fn send_message(
|
||||
let message = req.message.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db);
|
||||
if let Err(e) = orchestrator.run_session(&session, &target, &message).await {
|
||||
tracing::error!("Pentest orchestrator failed for session {session_id}: {e}");
|
||||
}
|
||||
orchestrator
|
||||
.run_session_guarded(&session, &target, &message)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
@@ -474,7 +465,6 @@ pub async fn pentest_stats(
|
||||
};
|
||||
|
||||
// Severity distribution from pentest-related DAST findings
|
||||
let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } };
|
||||
let critical = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
|
||||
@@ -501,8 +491,6 @@ pub async fn pentest_stats(
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let _ = pentest_filter; // used above inline
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: PentestStats {
|
||||
running_sessions,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use mongodb::bson::doc;
|
||||
@@ -17,6 +18,9 @@ use crate::llm::client::{
|
||||
};
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// Maximum duration for a single pentest session before timeout
|
||||
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
|
||||
|
||||
pub struct PentestOrchestrator {
|
||||
tool_registry: ToolRegistry,
|
||||
llm: Arc<LlmClient>,
|
||||
@@ -43,7 +47,65 @@ impl PentestOrchestrator {
|
||||
self.event_tx.clone()
|
||||
}
|
||||
|
||||
pub async fn run_session(
|
||||
/// Run a pentest session with timeout and automatic failure marking on errors.
|
||||
pub async fn run_session_guarded(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
initial_message: &str,
|
||||
) {
|
||||
let session_id = session.id;
|
||||
|
||||
match tokio::time::timeout(
|
||||
SESSION_TIMEOUT,
|
||||
self.run_session(session, target, initial_message),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
tracing::info!(?session_id, "Pentest session completed successfully");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::error!(?session_id, error = %e, "Pentest session failed");
|
||||
self.mark_session_failed(session_id, &format!("Error: {e}"))
|
||||
.await;
|
||||
let _ = self.event_tx.send(PentestEvent::Error {
|
||||
message: format!("Session failed: {e}"),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(?session_id, "Pentest session timed out after 30 minutes");
|
||||
self.mark_session_failed(session_id, "Session timed out after 30 minutes")
|
||||
.await;
|
||||
let _ = self.event_tx.send(PentestEvent::Error {
|
||||
message: "Session timed out after 30 minutes".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn mark_session_failed(
|
||||
&self,
|
||||
session_id: Option<mongodb::bson::oid::ObjectId>,
|
||||
reason: &str,
|
||||
) {
|
||||
if let Some(sid) = session_id {
|
||||
let _ = self
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": sid },
|
||||
doc! { "$set": {
|
||||
"status": "failed",
|
||||
"completed_at": mongodb::bson::DateTime::now(),
|
||||
"error_message": reason,
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_session(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
|
||||
@@ -31,4 +31,16 @@ impl Database {
|
||||
pub fn dast_scan_runs(&self) -> Collection<DastScanRun> {
|
||||
self.inner.collection("dast_scan_runs")
|
||||
}
|
||||
|
||||
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
|
||||
self.inner.collection("pentest_sessions")
|
||||
}
|
||||
|
||||
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
|
||||
self.inner.collection("attack_chain_nodes")
|
||||
}
|
||||
|
||||
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
|
||||
self.inner.collection("pentest_messages")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use rmcp::{
|
||||
};
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::tools::{dast, findings, sbom};
|
||||
use crate::tools::{dast, findings, pentest, sbom};
|
||||
|
||||
pub struct ComplianceMcpServer {
|
||||
db: Database,
|
||||
@@ -89,6 +89,54 @@ impl ComplianceMcpServer {
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
dast::dast_scan_summary(&self.db, params).await
|
||||
}
|
||||
|
||||
// ── Pentest ─────────────────────────────────────────────
|
||||
|
||||
#[tool(
|
||||
description = "List AI pentest sessions with optional filters for target, status, and strategy"
|
||||
)]
|
||||
async fn list_pentest_sessions(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::ListPentestSessionsParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::list_pentest_sessions(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(description = "Get a single AI pentest session by its ID")]
|
||||
async fn get_pentest_session(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetPentestSessionParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_pentest_session(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Get the attack chain DAG for a pentest session showing each tool invocation, its reasoning, and results"
|
||||
)]
|
||||
async fn get_attack_chain(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetAttackChainParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_attack_chain(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(description = "Get chat messages from a pentest session")]
|
||||
async fn get_pentest_messages(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetPentestMessagesParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_pentest_messages(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Get aggregated pentest statistics including running sessions, vulnerability counts, and severity distribution"
|
||||
)]
|
||||
async fn pentest_stats(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::PentestStatsParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::pentest_stats(&self.db, params).await
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
@@ -101,7 +149,7 @@ impl ServerHandler for ComplianceMcpServer {
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some(
|
||||
"Compliance Scanner MCP server. Query security findings, SBOM data, and DAST results."
|
||||
"Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions."
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dast;
|
||||
pub mod findings;
|
||||
pub mod pentest;
|
||||
pub mod sbom;
|
||||
|
||||
261
compliance-mcp/src/tools/pentest.rs
Normal file
261
compliance-mcp/src/tools/pentest.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use mongodb::bson::doc;
|
||||
use rmcp::{model::*, ErrorData as McpError};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::database::Database;
|
||||
|
||||
const MAX_LIMIT: i64 = 200;
|
||||
const DEFAULT_LIMIT: i64 = 50;
|
||||
|
||||
fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
// ── List Pentest Sessions ──────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct ListPentestSessionsParams {
|
||||
/// Filter by target ID
|
||||
pub target_id: Option<String>,
|
||||
/// Filter by status: running, paused, completed, failed
|
||||
pub status: Option<String>,
|
||||
/// Filter by strategy: quick, comprehensive, targeted, aggressive, stealth
|
||||
pub strategy: Option<String>,
|
||||
/// Maximum number of results (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn list_pentest_sessions(
|
||||
db: &Database,
|
||||
params: ListPentestSessionsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let mut filter = doc! {};
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
filter.insert("target_id", target_id);
|
||||
}
|
||||
if let Some(ref status) = params.status {
|
||||
filter.insert("status", status);
|
||||
}
|
||||
if let Some(ref strategy) = params.strategy {
|
||||
filter.insert("strategy", strategy);
|
||||
}
|
||||
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.pentest_sessions()
|
||||
.find(filter)
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let session = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(session);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Pentest Session ────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetPentestSessionParams {
|
||||
/// Pentest session ID (MongoDB ObjectId hex string)
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
pub async fn get_pentest_session(
|
||||
db: &Database,
|
||||
params: GetPentestSessionParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let oid = bson::oid::ObjectId::parse_str(¶ms.id)
|
||||
.map_err(|e| McpError::invalid_params(format!("invalid id: {e}"), None))?;
|
||||
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?
|
||||
.ok_or_else(|| McpError::invalid_params("session not found", None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&session)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Attack Chain ───────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetAttackChainParams {
|
||||
/// Pentest session ID to get the attack chain for
|
||||
pub session_id: String,
|
||||
/// Maximum number of nodes (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn get_attack_chain(
|
||||
db: &Database,
|
||||
params: GetAttackChainParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": ¶ms.session_id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let node = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(node);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Pentest Messages ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetPentestMessagesParams {
|
||||
/// Pentest session ID
|
||||
pub session_id: String,
|
||||
/// Maximum number of messages (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn get_pentest_messages(
|
||||
db: &Database,
|
||||
params: GetPentestMessagesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": ¶ms.session_id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let msg = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(msg);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Pentest Stats ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct PentestStatsParams {
|
||||
/// Filter stats by target ID
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn pentest_stats(
|
||||
db: &Database,
|
||||
params: PentestStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let mut base_filter = doc! {};
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
base_filter.insert("target_id", target_id);
|
||||
}
|
||||
|
||||
// Count running sessions
|
||||
let mut running_filter = base_filter.clone();
|
||||
running_filter.insert("status", "running");
|
||||
let running = db
|
||||
.pentest_sessions()
|
||||
.count_documents(running_filter)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Count total sessions
|
||||
let total_sessions = db
|
||||
.pentest_sessions()
|
||||
.count_documents(base_filter.clone())
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Get findings for these sessions — query DAST findings with session_id set
|
||||
let mut findings_filter = doc! { "session_id": { "$ne": null } };
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
findings_filter.insert("target_id", target_id);
|
||||
}
|
||||
let total_findings = db
|
||||
.dast_findings()
|
||||
.count_documents(findings_filter.clone())
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut exploitable_filter = findings_filter.clone();
|
||||
exploitable_filter.insert("exploitable", true);
|
||||
let exploitable = db
|
||||
.dast_findings()
|
||||
.count_documents(exploitable_filter)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Severity counts
|
||||
let mut severity = serde_json::Map::new();
|
||||
for sev in ["critical", "high", "medium", "low", "info"] {
|
||||
let mut sf = findings_filter.clone();
|
||||
sf.insert("severity", sev);
|
||||
let count = db
|
||||
.dast_findings()
|
||||
.count_documents(sf)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
severity.insert(sev.to_string(), serde_json::json!(count));
|
||||
}
|
||||
|
||||
let summary = serde_json::json!({
|
||||
"running_sessions": running,
|
||||
"total_sessions": total_sessions,
|
||||
"total_findings": total_findings,
|
||||
"exploitable_findings": exploitable,
|
||||
"severity_distribution": severity,
|
||||
});
|
||||
|
||||
let json = serde_json::to_string_pretty(&summary)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
Reference in New Issue
Block a user