use std::sync::Arc; use axum::extract::{Extension, Path, Query}; use axum::http::StatusCode; use axum::Json; use mongodb::bson::doc; use serde::Deserialize; use compliance_core::models::pentest::*; use crate::agent::ComplianceAgent; use crate::pentest::PentestOrchestrator; use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams}; type AgentExt = Extension>; #[derive(Deserialize)] pub struct CreateSessionRequest { pub target_id: Option, #[serde(default = "default_strategy")] pub strategy: String, pub message: Option, /// Wizard configuration — if present, takes precedence over legacy fields pub config: Option, } fn default_strategy() -> String { "comprehensive".to_string() } #[derive(Deserialize)] pub struct SendMessageRequest { pub message: String, } #[derive(Deserialize)] pub struct LookupRepoQuery { pub url: String, } /// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator #[tracing::instrument(skip_all)] pub async fn create_session( Extension(agent): AgentExt, Json(req): Json, ) -> Result>, (StatusCode, String)> { // Try to acquire a concurrency permit let permit = agent .session_semaphore .clone() .try_acquire_owned() .map_err(|_| { ( StatusCode::TOO_MANY_REQUESTS, "Maximum concurrent pentest sessions reached. Try again later.".to_string(), ) })?; if let Some(ref config) = req.config { // ── Wizard path ────────────────────────────────────────────── if !config.disclaimer_accepted { return Err(( StatusCode::BAD_REQUEST, "Disclaimer must be accepted".to_string(), )); } // Look up or auto-create DastTarget by app_url let target = match agent .db .dast_targets() .find_one(doc! { "base_url": &config.app_url }) .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB error: {e}")))? { Some(t) => t, None => { use compliance_core::models::dast::{DastTarget, DastTargetType}; let mut t = DastTarget::new( config.app_url.clone(), config.app_url.clone(), DastTargetType::WebApp, ); if let Some(rl) = config.rate_limit { t.rate_limit = rl; } t.allow_destructive = config.allow_destructive; t.excluded_paths = config.scope_exclusions.clone(); let res = agent.db.dast_targets().insert_one(&t).await.map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create target: {e}"), ) })?; t.id = res.inserted_id.as_object_id(); t } }; let target_id = target.id.map(|oid| oid.to_hex()).unwrap_or_default(); // Parse strategy from config or request let strat_str = config.strategy.as_deref().unwrap_or(req.strategy.as_str()); let strategy = parse_strategy(strat_str); let mut session = PentestSession::new(target_id, strategy); session.config = Some(config.clone()); session.repo_id = target.repo_id.clone(); // Resolve repo_id from git_repo_url if provided if let Some(ref git_url) = config.git_repo_url { if let Ok(Some(repo)) = agent .db .repositories() .find_one(doc! { "git_url": git_url }) .await { session.repo_id = repo.id.map(|oid| oid.to_hex()); } } let insert_result = agent .db .pentest_sessions() .insert_one(&session) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create session: {e}"), ) })?; session.id = insert_result.inserted_id.as_object_id(); let session_id_str = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); // Register broadcast stream and pause control let event_tx = agent.register_session_stream(&session_id_str); let pause_rx = agent.register_pause_control(&session_id_str); // Merge server-default IMAP/email settings where wizard left blanks if let Some(ref mut cfg) = session.config { if cfg.auth.mode == AuthMode::AutoRegister { if cfg.auth.verification_email.is_none() { cfg.auth.verification_email = agent.config.pentest_verification_email.clone(); } if cfg.auth.imap_host.is_none() { cfg.auth.imap_host = agent.config.pentest_imap_host.clone(); } if cfg.auth.imap_port.is_none() { cfg.auth.imap_port = agent.config.pentest_imap_port; } if cfg.auth.imap_username.is_none() { cfg.auth.imap_username = agent.config.pentest_imap_username.clone(); } if cfg.auth.imap_password.is_none() { cfg.auth.imap_password = agent.config.pentest_imap_password.as_ref().map(|s| { use secrecy::ExposeSecret; s.expose_secret().to_string() }); } } } // Pre-populate test user record for auto-register sessions if let Some(ref cfg) = session.config { if cfg.auth.mode == AuthMode::AutoRegister { let verification_email = cfg.auth.verification_email.clone(); // Build plus-addressed email for this session let test_email = verification_email.as_deref().map(|email| { let parts: Vec<&str> = email.splitn(2, '@').collect(); if parts.len() == 2 { format!("{}+{}@{}", parts[0], session_id_str, parts[1]) } else { email.to_string() } }); // Detect identity provider from keycloak config let provider = if agent.config.keycloak_url.is_some() { Some(compliance_core::models::pentest::IdentityProvider::Keycloak) } else { None }; session.test_user = Some(compliance_core::models::pentest::TestUserRecord { username: None, // LLM will choose; updated after registration email: test_email, provider_user_id: None, provider, cleaned_up: false, }); } } // Encrypt credentials before they linger in memory let mut session_for_task = session.clone(); if let Some(ref mut cfg) = session_for_task.config { cfg.auth.username = cfg .auth .username .as_ref() .map(|u| crate::pentest::crypto::encrypt(u)); cfg.auth.password = cfg .auth .password .as_ref() .map(|p| crate::pentest::crypto::encrypt(p)); } // Persist encrypted credentials to DB if session_for_task.config.is_some() { if let Some(sid) = session.id { let _ = agent .db .pentest_sessions() .update_one( doc! { "_id": sid }, doc! { "$set": { "config.auth.username": session_for_task.config.as_ref() .and_then(|c| c.auth.username.as_deref()) .map(|s| mongodb::bson::Bson::String(s.to_string())) .unwrap_or(mongodb::bson::Bson::Null), "config.auth.password": session_for_task.config.as_ref() .and_then(|c| c.auth.password.as_deref()) .map(|s| mongodb::bson::Bson::String(s.to_string())) .unwrap_or(mongodb::bson::Bson::Null), }}, ) .await; } } let initial_message = config .initial_instructions .clone() .or(req.message.clone()) .unwrap_or_else(|| { format!( "Begin a {} penetration test against {} ({}). \ Identify vulnerabilities and provide evidence for each finding.", session.strategy, target.name, target.base_url, ) }); let llm = agent.llm.clone(); let db = agent.db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); let agent_ref = agent.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; // Clean up session resources agent_ref.cleanup_session(&session_id_str); // Release concurrency permit drop(permit); }); // Redact credentials in response let mut response_session = session; if let Some(ref mut cfg) = response_session.config { if cfg.auth.username.is_some() { cfg.auth.username = Some("********".to_string()); } if cfg.auth.password.is_some() { cfg.auth.password = Some("********".to_string()); } } Ok(Json(ApiResponse { data: response_session, total: None, page: None, })) } else { // ── Legacy path ────────────────────────────────────────────── let target_id = req.target_id.clone().ok_or_else(|| { ( StatusCode::BAD_REQUEST, "target_id is required for legacy creation".to_string(), ) })?; let oid = mongodb::bson::oid::ObjectId::parse_str(&target_id).map_err(|_| { ( StatusCode::BAD_REQUEST, "Invalid target_id format".to_string(), ) })?; let target = agent .db .dast_targets() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?; let strategy = parse_strategy(&req.strategy); let mut session = PentestSession::new(target_id, strategy); session.repo_id = target.repo_id.clone(); let insert_result = agent .db .pentest_sessions() .insert_one(&session) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create session: {e}"), ) })?; session.id = insert_result.inserted_id.as_object_id(); let session_id_str = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); // Register broadcast stream and pause control let event_tx = agent.register_session_stream(&session_id_str); let pause_rx = agent.register_pause_control(&session_id_str); let initial_message = req.message.unwrap_or_else(|| { format!( "Begin a {} penetration test against {} ({}). \ Identify vulnerabilities and provide evidence for each finding.", session.strategy, target.name, target.base_url, ) }); let llm = agent.llm.clone(); let db = agent.db.clone(); let session_clone = session.clone(); let target_clone = target.clone(); let agent_ref = agent.clone(); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db, event_tx, Some(pause_rx)); orchestrator .run_session_guarded(&session_clone, &target_clone, &initial_message) .await; agent_ref.cleanup_session(&session_id_str); drop(permit); }); Ok(Json(ApiResponse { data: session, total: None, page: None, })) } } fn parse_strategy(s: &str) -> PentestStrategy { match s { "quick" => PentestStrategy::Quick, "targeted" => PentestStrategy::Targeted, "aggressive" => PentestStrategy::Aggressive, "stealth" => PentestStrategy::Stealth, _ => PentestStrategy::Comprehensive, } } /// GET /api/v1/pentest/lookup-repo — Look up a tracked repository by git URL #[tracing::instrument(skip_all)] pub async fn lookup_repo( Extension(agent): AgentExt, Query(params): Query, ) -> Result>, StatusCode> { let repo = agent .db .repositories() .find_one(doc! { "git_url": ¶ms.url }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let data = match repo { Some(r) => serde_json::json!({ "name": r.name, "default_branch": r.default_branch, "last_scanned_commit": r.last_scanned_commit, }), None => serde_json::Value::Null, }; Ok(Json(ApiResponse { data, total: None, page: None, })) } /// GET /api/v1/pentest/sessions — List pentest sessions #[tracing::instrument(skip_all)] pub async fn list_sessions( Extension(agent): AgentExt, Query(params): Query, ) -> Result>>, StatusCode> { let db = &agent.db; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = db .pentest_sessions() .count_documents(doc! {}) .await .unwrap_or(0); let sessions = match db .pentest_sessions() .find(doc! {}) .sort(doc! { "started_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest sessions: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: sessions, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/sessions/:id — Get a single pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let mut session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; // Redact credentials in response if let Some(ref mut cfg) = session.config { if cfg.auth.username.is_some() { cfg.auth.username = Some("********".to_string()); } if cfg.auth.password.is_some() { cfg.auth.password = Some("********".to_string()); } } Ok(Json(ApiResponse { data: session, total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn send_message( Extension(agent): AgentExt, Path(id): Path, Json(req): Json, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; // Verify session exists and is running let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, cannot send messages", session.status), )); } // Look up the target let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, "Invalid target_id in session".to_string(), ) })?; let target = agent .db .dast_targets() .find_one(doc! { "_id": target_oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| { ( StatusCode::NOT_FOUND, "Target for session not found".to_string(), ) })?; // Store user message let session_id = id.clone(); let user_msg = PentestMessage::user(session_id.clone(), req.message.clone()); let _ = agent.db.pentest_messages().insert_one(&user_msg).await; let response_msg = user_msg.clone(); // Spawn orchestrator to continue the session let llm = agent.llm.clone(); let db = agent.db.clone(); let message = req.message.clone(); // Use existing broadcast sender if available, otherwise create a new one let event_tx = agent .subscribe_session(&session_id) .and_then(|_| { agent .session_streams .get(&session_id) .map(|entry| entry.value().clone()) }) .unwrap_or_else(|| agent.register_session_stream(&session_id)); tokio::spawn(async move { let orchestrator = PentestOrchestrator::new(llm, db, event_tx, None); orchestrator .run_session_guarded(&session, &target, &message) .await; }); Ok(Json(ApiResponse { data: response_msg, total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn stop_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running && session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, not running or paused", session.status), )); } agent .db .pentest_sessions() .update_one( doc! { "_id": oid }, doc! { "$set": { "status": "failed", "completed_at": mongodb::bson::DateTime::now(), "error_message": "Stopped by user", }}, ) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })?; // Clean up session resources agent.cleanup_session(&id); let updated = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| { ( StatusCode::NOT_FOUND, "Session not found after update".to_string(), ) })?; Ok(Json(ApiResponse { data: updated, total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/pause — Pause a running pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn pause_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Running { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, not running", session.status), )); } if !agent.pause_session(&id) { return Err(( StatusCode::INTERNAL_SERVER_ERROR, "Failed to send pause signal".to_string(), )); } Ok(Json(ApiResponse { data: serde_json::json!({ "status": "paused" }), total: None, page: None, })) } /// POST /api/v1/pentest/sessions/:id/resume — Resume a paused pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn resume_session( Extension(agent): AgentExt, Path(id): Path, ) -> Result>, (StatusCode, String)> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id) .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?; let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}"), ) })? .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?; if session.status != PentestStatus::Paused { return Err(( StatusCode::BAD_REQUEST, format!("Session is {}, not paused", session.status), )); } if !agent.resume_session(&id) { return Err(( StatusCode::INTERNAL_SERVER_ERROR, "Failed to send resume signal".to_string(), )); } Ok(Json(ApiResponse { data: serde_json::json!({ "status": "running" }), total: None, page: None, })) } /// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_attack_chain( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let nodes = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch attack chain nodes: {e}"); Vec::new() } }; let total = nodes.len() as u64; Ok(Json(ApiResponse { data: nodes, total: Some(total), page: None, })) } /// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_messages( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .pentest_messages() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let messages = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest messages: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: messages, total: Some(total), page: Some(params.page), })) } /// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn get_session_findings( Extension(agent): AgentExt, Path(id): Path, Query(params): Query, ) -> Result>>, StatusCode> { let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; let skip = (params.page.saturating_sub(1)) * params.limit as u64; let total = agent .db .dast_findings() .count_documents(doc! { "session_id": &id }) .await .unwrap_or(0); let findings = match agent .db .dast_findings() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": -1 }) .skip(skip) .limit(params.limit) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(e) => { tracing::warn!("Failed to fetch pentest session findings: {e}"); Vec::new() } }; Ok(Json(ApiResponse { data: findings, total: Some(total), page: Some(params.page), })) }