use std::convert::Infallible; use std::sync::Arc; use std::time::Duration; use axum::extract::{Extension, Path}; use axum::http::StatusCode; use axum::response::sse::{Event, KeepAlive, Sse}; use futures_util::stream; use mongodb::bson::doc; use tokio_stream::wrappers::BroadcastStream; use tokio_stream::StreamExt; use compliance_core::models::pentest::*; use crate::agent::ComplianceAgent; use super::super::dto::collect_cursor_async; type AgentExt = Extension>; /// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events /// /// Replays stored messages/nodes as initial burst, then subscribes to the /// broadcast channel for live updates. Sends keepalive comments every 15s. #[tracing::instrument(skip_all, fields(session_id = %id))] pub async fn session_stream( Extension(agent): AgentExt, Path(id): Path, ) -> Result>>, StatusCode> { let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?; // Verify session exists let _session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .ok_or(StatusCode::NOT_FOUND)?; // ── Initial burst: replay stored data ────────────────────────── let mut initial_events: Vec> = Vec::new(); // Fetch recent messages for this session let messages: Vec = match agent .db .pentest_messages() .find(doc! { "session_id": &id }) .sort(doc! { "created_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; // Fetch recent attack chain nodes let nodes: Vec = match agent .db .attack_chain_nodes() .find(doc! { "session_id": &id }) .sort(doc! { "started_at": 1 }) .limit(100) .await { Ok(cursor) => collect_cursor_async(cursor).await, Err(_) => Vec::new(), }; for msg in &messages { let event_data = serde_json::json!({ "type": "message", "role": msg.role, "content": msg.content, "created_at": msg.created_at.to_rfc3339(), }); if let Ok(data) = serde_json::to_string(&event_data) { initial_events.push(Ok(Event::default().event("message").data(data))); } } for node in &nodes { let event_data = serde_json::json!({ "type": "tool_execution", "node_id": node.node_id, "tool_name": node.tool_name, "status": node.status, "findings_produced": node.findings_produced, }); if let Ok(data) = serde_json::to_string(&event_data) { initial_events.push(Ok(Event::default().event("tool").data(data))); } } // Add current session status event let session = agent .db .pentest_sessions() .find_one(doc! { "_id": oid }) .await .ok() .flatten(); if let Some(s) = session { let status_data = serde_json::json!({ "type": "status", "status": s.status, "findings_count": s.findings_count, "tool_invocations": s.tool_invocations, }); if let Ok(data) = serde_json::to_string(&status_data) { initial_events.push(Ok(Event::default().event("status").data(data))); } } // ── Live stream: subscribe to broadcast ──────────────────────── let live_stream = if let Some(rx) = agent.subscribe_session(&id) { let broadcast = BroadcastStream::new(rx).filter_map(|result| match result { Ok(event) => { if let Ok(data) = serde_json::to_string(&event) { let event_type = match &event { PentestEvent::ToolStart { .. } => "tool_start", PentestEvent::ToolComplete { .. } => "tool_complete", PentestEvent::Finding { .. } => "finding", PentestEvent::Message { .. } => "message", PentestEvent::Complete { .. } => "complete", PentestEvent::Error { .. } => "error", PentestEvent::Thinking { .. } => "thinking", PentestEvent::Paused => "paused", PentestEvent::Resumed => "resumed", }; Some(Ok(Event::default().event(event_type).data(data))) } else { None } } Err(_) => None, }); // Box to unify types Box::pin(broadcast) as std::pin::Pin> + Send>> } else { // No active broadcast — return empty stream Box::pin(stream::empty()) as std::pin::Pin> + Send>> }; // Chain initial burst + live stream let combined = stream::iter(initial_events).chain(live_stream); Ok(Sse::new(combined).keep_alive( KeepAlive::new() .interval(Duration::from_secs(15)) .text("keepalive"), )) }