feat: AI-driven automated penetration testing system

Add a complete AI pentest system where Claude autonomously drives security
testing via tool-calling. The LLM selects from 16 tools, chains results,
and builds an attack chain DAG.

Core:
- PentestTool trait (dyn-compatible) with PentestToolContext/Result
- PentestSession, AttackChainNode, PentestMessage, PentestEvent models
- 10 new DastVulnType variants (DNS, DMARC, TLS, cookies, CSP, CORS, etc.)
- LLM client chat_with_tools() for OpenAI-compatible tool calling

Tools (16 total):
- 5 agent wrappers: SQL injection, XSS, auth bypass, SSRF, API fuzzer
- 11 new infra tools: DNS checker, DMARC checker, TLS analyzer,
  security headers, cookie analyzer, CSP analyzer, rate limit tester,
  console log detector, CORS checker, OpenAPI parser, recon
- ToolRegistry for tool lookup and LLM definition generation

Orchestrator:
- PentestOrchestrator with iterative tool-calling loop (max 50 rounds)
- Attack chain node recording per tool invocation
- SSE event broadcasting for real-time progress
- Strategy-aware system prompts (quick/comprehensive/targeted/aggressive/stealth)

API (9 endpoints):
- POST/GET /pentest/sessions, GET /pentest/sessions/:id
- POST /pentest/sessions/:id/chat, GET /pentest/sessions/:id/stream
- GET /pentest/sessions/:id/attack-chain, messages, findings
- GET /pentest/stats

Dashboard:
- Pentest dashboard with stat cards, severity distribution, session list
- Chat-based session page with split layout (chat + findings/attack chain)
- Inline tool execution indicators, auto-polling, new session modal
- Sidebar navigation item

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-03-11 19:23:21 +01:00
parent 76260acc76
commit 71d8741e10
40 changed files with 7546 additions and 90 deletions

View File

@@ -1,6 +1,7 @@
pub mod chat;
pub mod dast;
pub mod graph;
pub mod pentest;
use std::sync::Arc;
@@ -1108,7 +1109,7 @@ pub async fn list_scan_runs(
}))
}
async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
use futures_util::StreamExt;

View File

@@ -0,0 +1,564 @@
use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
#[derive(Deserialize)]
pub struct CreateSessionRequest {
pub target_id: String,
#[serde(default = "default_strategy")]
pub strategy: String,
pub message: Option<String>,
}
fn default_strategy() -> String {
"comprehensive".to_string()
}
#[derive(Deserialize)]
pub struct SendMessageRequest {
pub message: String,
}
/// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator
#[tracing::instrument(skip_all)]
pub async fn create_session(
Extension(agent): AgentExt,
Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| {
(
StatusCode::BAD_REQUEST,
"Invalid target_id format".to_string(),
)
})?;
// Look up the target
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?;
// Parse strategy
let strategy = match req.strategy.as_str() {
"quick" => PentestStrategy::Quick,
"targeted" => PentestStrategy::Targeted,
"aggressive" => PentestStrategy::Aggressive,
"stealth" => PentestStrategy::Stealth,
_ => PentestStrategy::Comprehensive,
};
// Create session
let mut session = PentestSession::new(req.target_id.clone(), strategy);
session.repo_id = target.repo_id.clone();
agent
.db
.pentest_sessions()
.insert_one(&session)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create session: {e}"),
)
})?;
let initial_message = req.message.unwrap_or_else(|| {
format!(
"Begin a {} penetration test against {} ({}). \
Identify vulnerabilities and provide evidence for each finding.",
session.strategy, target.name, target.base_url,
)
});
// Spawn the orchestrator on a background task
let llm = agent.llm.clone();
let db = agent.db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator
.run_session(&session_clone, &target_clone, &initial_message)
.await
{
tracing::error!(
"Pentest orchestrator failed for session {}: {e}",
session_clone
.id
.map(|oid| oid.to_hex())
.unwrap_or_default()
);
}
});
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions — List pentest sessions
#[tracing::instrument(skip_all)]
pub async fn list_sessions(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.pentest_sessions()
.count_documents(doc! {})
.await
.unwrap_or(0);
let sessions = match db
.pentest_sessions()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest sessions: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: sessions,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/sessions/:id — Get a single pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
// Verify session exists and is running
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running && session.status != PentestStatus::Paused {
return Err((
StatusCode::BAD_REQUEST,
format!("Session is {}, cannot send messages", session.status),
));
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": target_oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Target for session not found".to_string(),
)
})?;
// Store user message
let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session
let llm = agent.llm.clone();
let db = agent.db.clone();
let message = req.message.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator.run_session(&session, &target, &message).await {
tracing::error!("Pentest orchestrator failed for session {session_id}: {e}");
}
});
Ok(Json(ApiResponse {
data: response_msg,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch attack chain nodes: {e}");
Vec::new()
}
};
let total = nodes.len() as u64;
Ok(Json(ApiResponse {
data: nodes,
total: Some(total),
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.pentest_messages()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let messages = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest messages: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: messages,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } };
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
let _ = pentest_filter; // used above inline
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.dast_findings()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let findings = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest session findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(params.page),
}))
}

View File

@@ -99,6 +99,36 @@ pub fn build_router() -> Router {
"/api/v1/chat/{repo_id}/status",
get(handlers::chat::embedding_status),
)
// Pentest API endpoints
.route(
"/api/v1/pentest/sessions",
get(handlers::pentest::list_sessions).post(handlers::pentest::create_session),
)
.route(
"/api/v1/pentest/sessions/{id}",
get(handlers::pentest::get_session),
)
.route(
"/api/v1/pentest/sessions/{id}/chat",
post(handlers::pentest::send_message),
)
.route(
"/api/v1/pentest/sessions/{id}/stream",
get(handlers::pentest::session_stream),
)
.route(
"/api/v1/pentest/sessions/{id}/attack-chain",
get(handlers::pentest::get_attack_chain),
)
.route(
"/api/v1/pentest/sessions/{id}/messages",
get(handlers::pentest::get_messages),
)
.route(
"/api/v1/pentest/sessions/{id}/findings",
get(handlers::pentest::get_session_findings),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",