use std::sync::Arc; use std::time::Duration; use mongodb::bson::doc; use tokio::sync::{broadcast, watch}; use compliance_core::models::dast::DastTarget; use compliance_core::models::pentest::*; use compliance_core::traits::pentest_tool::PentestToolContext; use compliance_dast::ToolRegistry; use crate::database::Database; use crate::llm::{ ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition, }; /// Maximum duration for a single pentest session before timeout const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes pub struct PentestOrchestrator { pub(crate) tool_registry: ToolRegistry, pub(crate) llm: Arc, pub(crate) db: Database, pub(crate) event_tx: broadcast::Sender, pub(crate) pause_rx: Option>, } impl PentestOrchestrator { /// Create a new orchestrator with an externally-provided broadcast sender /// and an optional pause receiver. pub fn new( llm: Arc, db: Database, event_tx: broadcast::Sender, pause_rx: Option>, ) -> Self { Self { tool_registry: ToolRegistry::new(), llm, db, event_tx, pause_rx, } } /// Run a pentest session with timeout and automatic failure marking on errors. pub async fn run_session_guarded( &self, session: &PentestSession, target: &DastTarget, initial_message: &str, ) { let session_id = session.id; // Use config-specified timeout or default let timeout_duration = session .config .as_ref() .and_then(|c| c.max_duration_minutes) .map(|m| Duration::from_secs(m as u64 * 60)) .unwrap_or(SESSION_TIMEOUT); let timeout_minutes = timeout_duration.as_secs() / 60; match tokio::time::timeout( timeout_duration, self.run_session(session, target, initial_message), ) .await { Ok(Ok(())) => { tracing::info!(?session_id, "Pentest session completed successfully"); } Ok(Err(e)) => { tracing::error!(?session_id, error = %e, "Pentest session failed"); self.mark_session_failed(session_id, &format!("Error: {e}")) .await; let _ = self.event_tx.send(PentestEvent::Error { message: format!("Session failed: {e}"), }); } Err(_) => { let msg = format!("Session timed out after {timeout_minutes} minutes"); tracing::warn!(?session_id, "{msg}"); self.mark_session_failed(session_id, &msg).await; let _ = self.event_tx.send(PentestEvent::Error { message: msg }); } } } async fn mark_session_failed( &self, session_id: Option, reason: &str, ) { if let Some(sid) = session_id { let _ = self .db .pentest_sessions() .update_one( doc! { "_id": sid }, doc! { "$set": { "status": "failed", "completed_at": mongodb::bson::DateTime::now(), "error_message": reason, }}, ) .await; } } /// Check if the session is paused; if so, update DB status and wait until resumed. async fn wait_if_paused(&self, session: &PentestSession) { let Some(ref pause_rx) = self.pause_rx else { return; }; let mut rx = pause_rx.clone(); if !*rx.borrow() { return; } // We are paused — update DB status if let Some(sid) = session.id { let _ = self .db .pentest_sessions() .update_one(doc! { "_id": sid }, doc! { "$set": { "status": "paused" }}) .await; } let _ = self.event_tx.send(PentestEvent::Paused); // Wait until unpaused while *rx.borrow() { if rx.changed().await.is_err() { break; } } // Resumed — update DB status back to running if let Some(sid) = session.id { let _ = self .db .pentest_sessions() .update_one(doc! { "_id": sid }, doc! { "$set": { "status": "running" }}) .await; } let _ = self.event_tx.send(PentestEvent::Resumed); } async fn run_session( &self, session: &PentestSession, target: &DastTarget, initial_message: &str, ) -> Result<(), crate::error::AgentError> { let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default(); // Gather code-awareness context from linked repo let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await; // Build system prompt with code context let system_prompt = self .build_system_prompt( session, target, &sast_findings, &sbom_entries, &code_context, ) .await; // Build tool definitions for LLM let tool_defs: Vec = self .tool_registry .all_definitions() .into_iter() .map(|td| ToolDefinition { name: td.name, description: td.description, parameters: td.input_schema, }) .collect(); // Initialize messages let mut messages = vec![ ChatMessage { role: "system".to_string(), content: Some(system_prompt), tool_calls: None, tool_call_id: None, }, ChatMessage { role: "user".to_string(), content: Some(initial_message.to_string()), tool_calls: None, tool_call_id: None, }, ]; // Store user message let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string()); let _ = self.db.pentest_messages().insert_one(&user_msg).await; // Build tool context with real data let tool_context = PentestToolContext { target: target.clone(), session_id: session_id.clone(), sast_findings, sbom_entries, code_context, rate_limit: target.rate_limit, allow_destructive: target.allow_destructive, }; let max_iterations = 50; let mut total_findings = 0u32; let mut total_tool_calls = 0u32; let mut total_successes = 0u32; let mut prev_node_ids: Vec = Vec::new(); for _iteration in 0..max_iterations { // Check pause state at top of each iteration self.wait_if_paused(session).await; let response = self .llm .chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192)) .await?; match response { LlmResponse::Content(content) => { let msg = PentestMessage::assistant(session_id.clone(), content.clone()); let _ = self.db.pentest_messages().insert_one(&msg).await; let _ = self.event_tx.send(PentestEvent::Message { content: content.clone(), }); messages.push(ChatMessage { role: "assistant".to_string(), content: Some(content.clone()), tool_calls: None, tool_call_id: None, }); let done_indicators = [ "pentest complete", "testing complete", "scan complete", "analysis complete", "finished", "that concludes", ]; let content_lower = content.to_lowercase(); if done_indicators .iter() .any(|ind| content_lower.contains(ind)) { break; } break; } LlmResponse::ToolCalls { calls: tool_calls, reasoning, } => { let tc_requests: Vec = tool_calls .iter() .map(|tc| ToolCallRequest { id: tc.id.clone(), r#type: "function".to_string(), function: ToolCallRequestFunction { name: tc.name.clone(), arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(), }, }) .collect(); messages.push(ChatMessage { role: "assistant".to_string(), content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) }, tool_calls: Some(tc_requests), tool_call_id: None, }); let mut current_batch_node_ids: Vec = Vec::new(); for tc in &tool_calls { total_tool_calls += 1; let node_id = uuid::Uuid::new_v4().to_string(); let mut node = AttackChainNode::new( session_id.clone(), node_id.clone(), tc.name.clone(), tc.arguments.clone(), reasoning.clone(), ); // Link to previous iteration's nodes node.parent_node_ids = prev_node_ids.clone(); node.status = AttackNodeStatus::Running; node.started_at = Some(chrono::Utc::now()); let _ = self.db.attack_chain_nodes().insert_one(&node).await; current_batch_node_ids.push(node_id.clone()); let _ = self.event_tx.send(PentestEvent::ToolStart { node_id: node_id.clone(), tool_name: tc.name.clone(), input: tc.arguments.clone(), }); let result = if let Some(tool) = self.tool_registry.get(&tc.name) { match tool.execute(tc.arguments.clone(), &tool_context).await { Ok(result) => { total_successes += 1; let findings_count = result.findings.len() as u32; total_findings += findings_count; let mut finding_ids: Vec = Vec::new(); for mut finding in result.findings { finding.scan_run_id = session_id.clone(); finding.session_id = Some(session_id.clone()); let insert_result = self.db.dast_findings().insert_one(&finding).await; if let Ok(res) = &insert_result { finding_ids.push( res.inserted_id .as_object_id() .map(|oid| oid.to_hex()) .unwrap_or_default(), ); } let _ = self.event_tx.send(PentestEvent::Finding { finding_id: finding .id .map(|oid| oid.to_hex()) .unwrap_or_default(), title: finding.title.clone(), severity: finding.severity.to_string(), }); } // Compute risk score based on findings severity let risk_score: Option = if findings_count > 0 { Some(std::cmp::min( 100, (findings_count as u8) .saturating_mul(15) .saturating_add(20), )) } else { None }; let _ = self.event_tx.send(PentestEvent::ToolComplete { node_id: node_id.clone(), summary: result.summary.clone(), findings_count, }); let finding_ids_bson: Vec = finding_ids .iter() .map(|id| mongodb::bson::Bson::String(id.clone())) .collect(); let mut update_doc = doc! { "status": "completed", "tool_output": mongodb::bson::to_bson(&result.data) .unwrap_or(mongodb::bson::Bson::Null), "completed_at": mongodb::bson::DateTime::now(), "findings_produced": finding_ids_bson, }; if let Some(rs) = risk_score { update_doc.insert("risk_score", rs as i32); } let _ = self .db .attack_chain_nodes() .update_one( doc! { "session_id": &session_id, "node_id": &node_id, }, doc! { "$set": update_doc }, ) .await; // Build LLM-facing summary: strip large fields // (screenshots, raw HTML) to save context window let llm_data = summarize_tool_output(&result.data); serde_json::json!({ "summary": result.summary, "findings_count": findings_count, "data": llm_data, }) .to_string() } Err(e) => { let _ = self .db .attack_chain_nodes() .update_one( doc! { "session_id": &session_id, "node_id": &node_id, }, doc! { "$set": { "status": "failed", "completed_at": mongodb::bson::DateTime::now(), }}, ) .await; format!("Tool execution failed: {e}") } } } else { format!("Unknown tool: {}", tc.name) }; messages.push(ChatMessage { role: "tool".to_string(), content: Some(result), tool_calls: None, tool_call_id: Some(tc.id.clone()), }); } // Advance parent links so next iteration's nodes connect to this batch prev_node_ids = current_batch_node_ids; if let Some(sid) = session.id { let _ = self .db .pentest_sessions() .update_one( doc! { "_id": sid }, doc! { "$set": { "tool_invocations": total_tool_calls as i64, "tool_successes": total_successes as i64, "findings_count": total_findings as i64, }}, ) .await; } } } } if let Some(sid) = session.id { let _ = self .db .pentest_sessions() .update_one( doc! { "_id": sid }, doc! { "$set": { "status": "completed", "completed_at": mongodb::bson::DateTime::now(), "tool_invocations": total_tool_calls as i64, "tool_successes": total_successes as i64, "findings_count": total_findings as i64, }}, ) .await; } // Clean up test user via identity provider API if requested if session .config .as_ref() .is_some_and(|c| c.auth.cleanup_test_user) { if let Some(ref test_user) = session.test_user { let http = reqwest::Client::new(); // We need the AgentConfig — read from env since orchestrator doesn't hold it let config = crate::config::load_config(); match config { Ok(cfg) => { match crate::pentest::cleanup::cleanup_test_user(test_user, &cfg, &http) .await { Ok(true) => { tracing::info!( username = test_user.username.as_deref(), "Test user cleaned up via provider API" ); // Mark as cleaned up in DB if let Some(sid) = session.id { let _ = self .db .pentest_sessions() .update_one( doc! { "_id": sid }, doc! { "$set": { "test_user.cleaned_up": true } }, ) .await; } } Ok(false) => { tracing::info!( "Test user cleanup skipped (no provider configured)" ); } Err(e) => { tracing::warn!(error = %e, "Test user cleanup failed"); let _ = self.event_tx.send(PentestEvent::Error { message: format!("Test user cleanup failed: {e}"), }); } } } Err(e) => { tracing::warn!(error = %e, "Could not load config for cleanup"); } } } } // Clean up the persistent browser session for this pentest compliance_dast::tools::browser::cleanup_browser_session(&session_id).await; let _ = self.event_tx.send(PentestEvent::Complete { summary: format!( "Pentest complete. {} findings from {} tool invocations.", total_findings, total_tool_calls ), }); Ok(()) } } /// Strip large fields from tool output before sending to the LLM. /// Screenshots, raw HTML, and other bulky data are replaced with short summaries. /// The full data is still stored in the DB for the report. fn summarize_tool_output(data: &serde_json::Value) -> serde_json::Value { let Some(obj) = data.as_object() else { return data.clone(); }; let mut summarized = serde_json::Map::new(); for (key, value) in obj { match key.as_str() { // Replace screenshot base64 with a placeholder "screenshot_base64" => { if let Some(s) = value.as_str() { if !s.is_empty() { summarized.insert( key.clone(), serde_json::Value::String( "[screenshot captured and saved to report]".to_string(), ), ); continue; } } summarized.insert(key.clone(), value.clone()); } // Truncate raw HTML content "html" => { if let Some(s) = value.as_str() { if s.len() > 2000 { summarized.insert( key.clone(), serde_json::Value::String(format!( "{}... [truncated, {} chars total]", &s[..2000], s.len() )), ); continue; } } summarized.insert(key.clone(), value.clone()); } // Truncate page text "text" if value.as_str().is_some_and(|s| s.len() > 1500) => { let s = value.as_str().unwrap_or_default(); summarized.insert( key.clone(), serde_json::Value::String(format!("{}... [truncated]", &s[..1500])), ); } // Trim large arrays (e.g., "elements", "links", "inputs") "elements" | "links" | "inputs" => { if let Some(arr) = value.as_array() { if arr.len() > 15 { let mut trimmed: Vec = arr[..15].to_vec(); trimmed.push(serde_json::json!(format!( "... and {} more", arr.len() - 15 ))); summarized.insert(key.clone(), serde_json::Value::Array(trimmed)); continue; } } summarized.insert(key.clone(), value.clone()); } // Recursively summarize nested objects (e.g., "page" in get_content) _ if value.is_object() => { summarized.insert(key.clone(), summarize_tool_output(value)); } // Keep everything else as-is _ => { summarized.insert(key.clone(), value.clone()); } } } serde_json::Value::Object(summarized) } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[test] fn test_summarize_strips_screenshot() { let input = json!({ "screenshot_base64": "iVBOR...", "url": "https://example.com" }); let result = summarize_tool_output(&input); assert_eq!( result["screenshot_base64"], "[screenshot captured and saved to report]" ); assert_eq!(result["url"], "https://example.com"); } #[test] fn test_summarize_truncates_html() { let long_html = "x".repeat(3000); let input = json!({ "html": long_html }); let result = summarize_tool_output(&input); let s = result["html"].as_str().unwrap_or_default(); assert!(s.contains("[truncated, 3000 chars total]")); assert!(s.starts_with(&"x".repeat(2000))); assert!(s.len() < 3000); } #[test] fn test_summarize_truncates_text() { let long_text = "a".repeat(2000); let input = json!({ "text": long_text }); let result = summarize_tool_output(&input); let s = result["text"].as_str().unwrap_or_default(); assert!(s.contains("[truncated]")); assert!(s.starts_with(&"a".repeat(1500))); assert!(s.len() < 2000); } #[test] fn test_summarize_trims_large_arrays() { let elements: Vec = (0..20).map(|i| json!(format!("el-{i}"))).collect(); let input = json!({ "elements": elements }); let result = summarize_tool_output(&input); let arr = result["elements"].as_array(); assert!(arr.is_some()); if let Some(arr) = arr { // 15 kept + 1 summary entry assert_eq!(arr.len(), 16); assert_eq!(arr[15], json!("... and 5 more")); } } #[test] fn test_summarize_preserves_small_data() { let input = json!({ "url": "https://example.com", "status": 200, "title": "Example" }); let result = summarize_tool_output(&input); assert_eq!(result, input); } #[test] fn test_summarize_recursive() { let input = json!({ "page": { "screenshot_base64": "iVBORw0KGgoAAAA...", "url": "https://example.com" } }); let result = summarize_tool_output(&input); assert_eq!( result["page"]["screenshot_base64"], "[screenshot captured and saved to report]" ); assert_eq!(result["page"]["url"], "https://example.com"); } #[test] fn test_summarize_non_object() { let string_val = json!("just a string"); assert_eq!(summarize_tool_output(&string_val), string_val); let num_val = json!(42); assert_eq!(summarize_tool_output(&num_val), num_val); } }