use std::sync::Arc; use axum::extract::{Extension, Path, Query}; use axum::http::StatusCode; use axum::response::sse::{Event, Sse}; use axum::Json; use futures_util::stream; use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::dast::DastFinding; use compliance_core::models::pentest::*; use crate::agent::ComplianceAgent; use crate::pentest::PentestOrchestrator; use super::{collect_cursor_async, ApiResponse, PaginationParams}; type AgentExt = Extension>; #[derive(Deserialize)] pub struct CreateSessionRequest { pub target_id: String, #[serde(default = "default_strategy")] pub strategy: String, pub message: Option, } fn default_strategy() -> String { "comprehensive".to_string() } #[derive(Deserialize)] pub struct SendMessageRequest { pub message: String, } /// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator #[tracing::instrument(skip_all)] pub async fn create_session( Extension(agent): AgentExt, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| { ( StatusCode::BAD_REQUEST, "Invalid target_id format".to_string(), ) })?; // Look up the target let target = agent .db .dast_targets() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; // Parse strategy let strategy = match req.strategy.as_str() { "quick" => PentestStrategy::Quick, "targeted" => PentestStrategy::Targeted, "aggressive" => PentestStrategy::Aggressive, "stealth" => PentestStrategy::Stealth, _ => PentestStrategy::Comprehensive, }; // Create session let mut session = PentestSession::new(req.target_id.clone(), strategy); session.repo_id = target.repo_id.clone(); agent .db .pentest_sessions() .insert_one(&session) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create session: {e}"), ) })?; let initial_message = req.message.unwrap_or_else(|| { format!( "Begin a {} penetration test against {} ({}). \ Identify vulnerabilities and provide evidence for each finding.", session.strategy, target.name, target.base_url, ) }); // Spawn the orchestrator on a background task let llm = agent.llm.clone(); let db = agent.db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; }); Ok(Json(ApiResponse { data: session, total: None, page: None, })) } /// GET /api/v1/pentest/sessions — List pentest sessions #[tracing::instrument(skip_all)] pub async fn list_sessions( Extension(agent): AgentExt, Query(params): Query, ) -> Result>>, StatusCode> { let db = &agent.db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .pentest_sessions() .count_documents(doc! {}) .await .unwrap_or(0); let sessions = match db .pentest_sessions() .find(doc! {}) .sort(doc! { "started_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest sessions: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: sessions, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/sessions/:id — Get a single pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; Ok(Json(ApiResponse { data: session, total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn send_message( Extension(agent): AgentExt, Path(id): Path, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; // Verify session exists and is running let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, cannot send messages", session.status), )); } // Look up the target let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, "Invalid target_id in session".to_string(), ) })?; let target = agent .db .dast_targets() .find_one(doc! { "_id": target_oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| { ( StatusCode::NOT_FOUND, "Target for session not found".to_string(), ) })?; // Store user message let session_id = id.clone(); let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); let _ = agent.db.pentest_messages().insert_one(&user_msg).await; let response_msg = user_msg.clone(); // Spawn orchestrator to continue the session let llm = agent.llm.clone(); let db = agent.db.clone(); let message = req.message.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db); orchestrator .run_session_guarded(&session, &target, &message) .await; }); Ok(Json(ApiResponse { data: response_msg, total: None, page: None, })) } /// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events /// /// Returns recent messages as SSE events (polling approach). /// True real-time streaming with broadcast channels will be added in a future iteration. #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn session_stream( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; // Verify session exists let _session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; // Fetch recent messages for this session let messages: Vec = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Fetch recent attack chain nodes let nodes: Vec = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Build SSE events from stored data let mut events: Vec> = Vec::new(); for msg in &messages { let event_data = serde_json::json!({ "type": "message", "role": msg.role, "content": msg.content, "created_at": msg.created_at.to_rfc3339(), }); if let Ok(data) = serde_json::to_string(&event_data) { events.push(Ok(Event::default().event("message").data(data))); } } for node in &nodes { let event_data = serde_json::json!({ "type": "tool_execution", "node_id": node.node_id, "tool_name": node.tool_name, "status": node.status, "findings_produced": node.findings_produced, }); if let Ok(data) = serde_json::to_string(&event_data) { events.push(Ok(Event::default().event("tool").data(data))); } } // Add session status event let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .ok() .flatten(); if let Some(s) = session { let status_data = serde_json::json!({ "type": "status", "status": s.status, "findings_count": s.findings_count, "tool_invocations": s.tool_invocations, }); if let Ok(data) = serde_json::to_string(&status_data) { events.push(Ok(Event::default().event("status").data(data))); } } Ok(Sse::new(stream::iter(events))) } /// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_attack_chain( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { // Verify the session ID is valid let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let nodes = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch attack chain nodes: {e}"); Vec::new() } }; let total = nodes.len() as u64; Ok(Json(ApiResponse { data: nodes, total: Some(total), page: None, })) } /// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_messages( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .pentest_messages() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let messages = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest messages: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: messages, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/stats — Aggregated pentest statistics #[tracing::instrument(skip_all)] pub async fn pentest_stats( Extension(agent): AgentExt, ) -> Result>, StatusCode> { let db = &agent.db; let running_sessions = db .pentest_sessions() .count_documents(doc! { "status": "running" }) .await .unwrap_or(0) as u32; // Count DAST findings from pentest sessions let total_vulnerabilities = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null } }) .await .unwrap_or(0) as u32; // Aggregate tool invocations from all sessions let sessions: Vec = match db.pentest_sessions().find(doc! {}).await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum(); let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum(); let tool_success_rate = if total_tool_invocations == 0 { 100.0 } else { (total_successes as f64 / total_tool_invocations as f64) * 100.0 }; // Severity distribution from pentest-related DAST findings let critical = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" }) .await .unwrap_or(0) as u32; let high = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" }) .await .unwrap_or(0) as u32; let medium = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" }) .await .unwrap_or(0) as u32; let low = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" }) .await .unwrap_or(0) as u32; let info = db .dast_findings() .count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" }) .await .unwrap_or(0) as u32; Ok(Json(ApiResponse { data: PentestStats { running_sessions, total_vulnerabilities, total_tool_invocations, tool_success_rate, severity_distribution: SeverityDistribution { critical, high, medium, low, info, }, }, total: None, page: None, })) } /// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session_findings( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .dast_findings() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let findings = match agent .db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest session findings: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: findings, total: Some(total), page: Some(params.page), })) }