use std::sync::Arc; use axum::extract::{Extension, Path, Query}; use axum::http::StatusCode; use axum::response::sse::{Event, Sse}; use axum::response::IntoResponse; use axum::Json; use futures_util::stream; use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::dast::DastFinding; use compliance_core::models::pentest::*; use crate::agent::ComplianceAgent; use crate::pentest::PentestOrchestrator; use super::{collect_cursor_async, ApiResponse, PaginationParams}; type AgentExt = Extension>; #[derive(Deserialize)] pub struct CreateSessionRequest { pub target_id: String, #[serde(default = "default_strategy")] pub strategy: String, pub message: Option, } fn default_strategy() -> String { "comprehensive".to_string() } #[derive(Deserialize)] pub struct SendMessageRequest { pub message: String, } /// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator #[tracing::instrument(skip_all)] pub async fn create_session( Extension(agent): AgentExt, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| { ( StatusCode::BAD_REQUEST, "Invalid target_id format".to_string(), ) })?; // Look up the target let target = agent .db .dast_targets() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; // Parse strategy let strategy = match req.strategy.as_str() { "quick" => PentestStrategy::Quick, "targeted" => PentestStrategy::Targeted, "aggressive" => PentestStrategy::Aggressive, "stealth" => PentestStrategy::Stealth, _ => PentestStrategy::Comprehensive, }; // Create session let mut session = PentestSession::new(req.target_id.clone(), strategy); session.repo_id = target.repo_id.clone(); let insert_result = agent .db .pentest_sessions() .insert_one(&session) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create session: {e}"), ) })?; // Set the generated ID back on the session so the orchestrator has it session.id = insert_result.inserted_id.as_object_id(); let initial_message = req.message.unwrap_or_else(|| { format!( "Begin a {} penetration test against {} ({}). \ Identify vulnerabilities and provide evidence for each finding.", session.strategy, target.name, target.base_url, ) }); // Spawn the orchestrator on a background task let llm = agent.llm.clone(); let db = agent.db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; }); Ok(Json(ApiResponse { data: session, total: None, page: None, })) } /// GET /api/v1/pentest/sessions — List pentest sessions #[tracing::instrument(skip_all)] pub async fn list_sessions( Extension(agent): AgentExt, Query(params): Query, ) -> Result>>, StatusCode> { let db = &agent.db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .pentest_sessions() .count_documents(doc! {}) .await .unwrap_or(0); let sessions = match db .pentest_sessions() .find(doc! {}) .sort(doc! { "started_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest sessions: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: sessions, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/sessions/:id — Get a single pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; Ok(Json(ApiResponse { data: session, total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn send_message( Extension(agent): AgentExt, Path(id): Path, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; // Verify session exists and is running let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, cannot send messages", session.status), )); } // Look up the target let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, "Invalid target_id in session".to_string(), ) })?; let target = agent .db .dast_targets() .find_one(doc! { "_id": target_oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| { ( StatusCode::NOT_FOUND, "Target for session not found".to_string(), ) })?; // Store user message let session_id = id.clone(); let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); let _ = agent.db.pentest_messages().insert_one(&user_msg).await; let response_msg = user_msg.clone(); // Spawn orchestrator to continue the session let llm = agent.llm.clone(); let db = agent.db.clone(); let message = req.message.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); orchestrator .run_session_guarded(&session, &target, &message) .await; }); Ok(Json(ApiResponse { data: response_msg, total: None, page: None, })) } /// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events /// /// Returns recent messages as SSE events (polling approach). /// True real-time streaming with broadcast channels will be added in a future iteration. #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn session_stream( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; // Verify session exists let _session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; // Fetch recent messages for this session let messages: Vec = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Fetch recent attack chain nodes let nodes: Vec = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Build SSE events from stored data let mut events: Vec> = Vec::new(); for msg in &messages { let event_data = serde_json::json!({ "type": "message", "role": msg.role, "content": msg.content, "created_at": msg.created_at.to_rfc3339(), }); if let Ok(data) = serde_json::to_string(&event_data) { events.push(Ok(Event::default().event("message").data(data))); } } for node in &nodes { let event_data = serde_json::json!({ "type": "tool_execution", "node_id": node.node_id, "tool_name": node.tool_name, "status": node.status, "findings_produced": node.findings_produced, }); if let Ok(data) = serde_json::to_string(&event_data) { events.push(Ok(Event::default().event("tool").data(data))); } } // Add session status event let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .ok() .flatten(); if let Some(s) = session { let status_data = serde_json::json!({ "type": "status", "status": s.status, "findings_count": s.findings_count, "tool_invocations": s.tool_invocations, }); if let Ok(data) = serde_json::to_string(&status_data) { events.push(Ok(Event::default().event("status").data(data))); } } Ok(Sse::new(stream::iter(events))) } /// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_attack_chain( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { // Verify the session ID is valid let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let nodes = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch attack chain nodes: {e}"); Vec::new() } }; let total = nodes.len() as u64; Ok(Json(ApiResponse { data: nodes, total: Some(total), page: None, })) } /// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_messages( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .pentest_messages() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let messages = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest messages: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: messages, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/stats — Aggregated pentest statistics #[tracing::instrument(skip_all)] pub async fn pentest_stats( Extension(agent): AgentExt, ) -> Result>, StatusCode> { let db = &agent.db; let running_sessions = db .pentest_sessions() .count_documents(doc! { "status": "running" }) .await .unwrap_or(0) as u32; // Count DAST findings from pentest sessions let total_vulnerabilities = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null } }) .await .unwrap_or(0) as u32; // Aggregate tool invocations from all sessions let sessions: Vec = match db.pentest_sessions().find(doc! {}).await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum(); let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum(); let tool_success_rate = if total_tool_invocations == 0 { 100.0 } else { (total_successes as f64 / total_tool_invocations as f64) * 100.0 }; // Severity distribution from pentest-related DAST findings let critical = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }) .await .unwrap_or(0) as u32; let high = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" }) .await .unwrap_or(0) as u32; let medium = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" }) .await .unwrap_or(0) as u32; let low = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" }) .await .unwrap_or(0) as u32; let info = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" }) .await .unwrap_or(0) as u32; Ok(Json(ApiResponse { data: PentestStats { running_sessions, total_vulnerabilities, total_tool_invocations, tool_success_rate, severity_distribution: SeverityDistribution { critical, high, medium, low, info, }, }, total: None, page: None, })) } /// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session_findings( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .dast_findings() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let findings = match agent .db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest session findings: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: findings, total: Some(total), page: Some(params.page), })) } #[derive(Deserialize)] pub struct ExportParams { #[serde(default = "default_export_format")] pub format: String, } fn default_export_format() -> String { "json".to_string() } /// GET /api/v1/pentest/sessions/:id/export?format=json|markdown — Export a session report #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn export_session_report( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; // Fetch session let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; // Fetch messages let messages: Vec = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Fetch attack chain nodes let nodes: Vec = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Fetch DAST findings for this session let findings: Vec = match agent .db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": -1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Compute severity counts let critical = findings.iter().filter(|f| f.severity.to_string() == "critical").count(); let high = findings.iter().filter(|f| f.severity.to_string() == "high").count(); let medium = findings.iter().filter(|f| f.severity.to_string() == "medium").count(); let low = findings.iter().filter(|f| f.severity.to_string() == "low").count(); let info = findings.iter().filter(|f| f.severity.to_string() == "info").count(); match params.format.as_str() { "markdown" => { let mut md = String::new(); md.push_str("# Penetration Test Report\n\n"); // Executive summary md.push_str("## Executive Summary\n\n"); md.push_str(&format!("| Field | Value |\n")); md.push_str("| --- | --- |\n"); md.push_str(&format!("| **Session ID** | {} |\n", id)); md.push_str(&format!("| **Status** | {} |\n", session.status)); md.push_str(&format!("| **Strategy** | {} |\n", session.strategy)); md.push_str(&format!("| **Target ID** | {} |\n", session.target_id)); md.push_str(&format!( "| **Started** | {} |\n", session.started_at.to_rfc3339() )); if let Some(ref completed) = session.completed_at { md.push_str(&format!( "| **Completed** | {} |\n", completed.to_rfc3339() )); } md.push_str(&format!( "| **Tool Invocations** | {} |\n", session.tool_invocations )); md.push_str(&format!( "| **Success Rate** | {:.1}% |\n", session.success_rate() )); md.push('\n'); // Findings by severity md.push_str("## Findings Summary\n\n"); md.push_str(&format!( "| Severity | Count |\n| --- | --- |\n| Critical | {} |\n| High | {} |\n| Medium | {} |\n| Low | {} |\n| Info | {} |\n| **Total** | **{}** |\n\n", critical, high, medium, low, info, findings.len() )); // Findings table if !findings.is_empty() { md.push_str("## Findings Detail\n\n"); md.push_str("| # | Severity | Title | Endpoint | Exploitable |\n"); md.push_str("| --- | --- | --- | --- | --- |\n"); for (i, f) in findings.iter().enumerate() { md.push_str(&format!( "| {} | {} | {} | {} {} | {} |\n", i + 1, f.severity, f.title, f.method, f.endpoint, if f.exploitable { "Yes" } else { "No" }, )); } md.push('\n'); } // Attack chain timeline if !nodes.is_empty() { md.push_str("## Attack Chain Timeline\n\n"); md.push_str("| # | Tool | Status | Findings | Reasoning |\n"); md.push_str("| --- | --- | --- | --- | --- |\n"); for (i, node) in nodes.iter().enumerate() { let reasoning_short = if node.llm_reasoning.len() > 80 { format!("{}...", &node.llm_reasoning[..80]) } else { node.llm_reasoning.clone() }; md.push_str(&format!( "| {} | {} | {} | {} | {} |\n", i + 1, node.tool_name, format!("{:?}", node.status).to_lowercase(), node.findings_produced.len(), reasoning_short, )); } md.push('\n'); } // Statistics md.push_str("## Statistics\n\n"); md.push_str(&format!("- **Total Findings:** {}\n", findings.len())); md.push_str(&format!("- **Exploitable Findings:** {}\n", session.exploitable_count)); md.push_str(&format!("- **Attack Chain Steps:** {}\n", nodes.len())); md.push_str(&format!("- **Messages Exchanged:** {}\n", messages.len())); md.push_str(&format!("- **Tool Invocations:** {}\n", session.tool_invocations)); md.push_str(&format!("- **Tool Success Rate:** {:.1}%\n", session.success_rate())); Ok(( StatusCode::OK, [ (axum::http::header::CONTENT_TYPE, "text/markdown; charset=utf-8"), ], md, ) .into_response()) } _ => { // JSON format let report = serde_json::json!({ "session": { "id": id, "target_id": session.target_id, "repo_id": session.repo_id, "status": session.status, "strategy": session.strategy, "started_at": session.started_at.to_rfc3339(), "completed_at": session.completed_at.map(|d| d.to_rfc3339()), "tool_invocations": session.tool_invocations, "tool_successes": session.tool_successes, "success_rate": session.success_rate(), "findings_count": session.findings_count, "exploitable_count": session.exploitable_count, }, "findings": findings, "attack_chain": nodes, "messages": messages, "summary": { "total_findings": findings.len(), "severity_distribution": { "critical": critical, "high": high, "medium": medium, "low": low, "info": info, }, "attack_chain_steps": nodes.len(), "messages_exchanged": messages.len(), }, }); Ok(Json(report).into_response()) } } }