All checks were successful
Complete pentest feature overhaul: SSE streaming, session-persistent browser tool (CDP), AES-256 credential encryption, auto-screenshots in reports, code-level remediation correlation, SAST triage chunking, context window optimization, test user cleanup (Keycloak/Auth0/Okta), wizard dropdowns, attack chain improvements, architecture docs with Mermaid diagrams. Co-authored-by: Sharang Parnerkar <parnerkarsharang@gmail.com> Reviewed-on: #16
159 lines
5.4 KiB
Rust
159 lines
5.4 KiB
Rust
use std::convert::Infallible;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use axum::extract::{Extension, Path};
|
|
use axum::http::StatusCode;
|
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
|
use futures_util::stream;
|
|
use mongodb::bson::doc;
|
|
use tokio_stream::wrappers::BroadcastStream;
|
|
use tokio_stream::StreamExt;
|
|
|
|
use compliance_core::models::pentest::*;
|
|
|
|
use crate::agent::ComplianceAgent;
|
|
|
|
use super::super::dto::collect_cursor_async;
|
|
|
|
/// Axum extractor for the shared [`ComplianceAgent`] handle; `session_stream`
/// uses it for database access (`agent.db`) and broadcast subscriptions
/// (`agent.subscribe_session`).
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
|
|
|
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
|
|
///
|
|
/// Replays stored messages/nodes as initial burst, then subscribes to the
|
|
/// broadcast channel for live updates. Sends keepalive comments every 15s.
|
|
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
|
pub async fn session_stream(
|
|
Extension(agent): AgentExt,
|
|
Path(id): Path<String>,
|
|
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
|
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
|
|
|
// Verify session exists
|
|
let _session = agent
|
|
.db
|
|
.pentest_sessions()
|
|
.find_one(doc! { "_id": oid })
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
|
.ok_or(StatusCode::NOT_FOUND)?;
|
|
|
|
// ── Initial burst: replay stored data ──────────────────────────
|
|
|
|
let mut initial_events: Vec<Result<Event, Infallible>> = Vec::new();
|
|
|
|
// Fetch recent messages for this session
|
|
let messages: Vec<PentestMessage> = match agent
|
|
.db
|
|
.pentest_messages()
|
|
.find(doc! { "session_id": &id })
|
|
.sort(doc! { "created_at": 1 })
|
|
.limit(100)
|
|
.await
|
|
{
|
|
Ok(cursor) => collect_cursor_async(cursor).await,
|
|
Err(_) => Vec::new(),
|
|
};
|
|
|
|
// Fetch recent attack chain nodes
|
|
let nodes: Vec<AttackChainNode> = match agent
|
|
.db
|
|
.attack_chain_nodes()
|
|
.find(doc! { "session_id": &id })
|
|
.sort(doc! { "started_at": 1 })
|
|
.limit(100)
|
|
.await
|
|
{
|
|
Ok(cursor) => collect_cursor_async(cursor).await,
|
|
Err(_) => Vec::new(),
|
|
};
|
|
|
|
for msg in &messages {
|
|
let event_data = serde_json::json!({
|
|
"type": "message",
|
|
"role": msg.role,
|
|
"content": msg.content,
|
|
"created_at": msg.created_at.to_rfc3339(),
|
|
});
|
|
if let Ok(data) = serde_json::to_string(&event_data) {
|
|
initial_events.push(Ok(Event::default().event("message").data(data)));
|
|
}
|
|
}
|
|
|
|
for node in &nodes {
|
|
let event_data = serde_json::json!({
|
|
"type": "tool_execution",
|
|
"node_id": node.node_id,
|
|
"tool_name": node.tool_name,
|
|
"status": node.status,
|
|
"findings_produced": node.findings_produced,
|
|
});
|
|
if let Ok(data) = serde_json::to_string(&event_data) {
|
|
initial_events.push(Ok(Event::default().event("tool").data(data)));
|
|
}
|
|
}
|
|
|
|
// Add current session status event
|
|
let session = agent
|
|
.db
|
|
.pentest_sessions()
|
|
.find_one(doc! { "_id": oid })
|
|
.await
|
|
.ok()
|
|
.flatten();
|
|
|
|
if let Some(s) = session {
|
|
let status_data = serde_json::json!({
|
|
"type": "status",
|
|
"status": s.status,
|
|
"findings_count": s.findings_count,
|
|
"tool_invocations": s.tool_invocations,
|
|
});
|
|
if let Ok(data) = serde_json::to_string(&status_data) {
|
|
initial_events.push(Ok(Event::default().event("status").data(data)));
|
|
}
|
|
}
|
|
|
|
// ── Live stream: subscribe to broadcast ────────────────────────
|
|
|
|
let live_stream = if let Some(rx) = agent.subscribe_session(&id) {
|
|
let broadcast = BroadcastStream::new(rx).filter_map(|result| match result {
|
|
Ok(event) => {
|
|
if let Ok(data) = serde_json::to_string(&event) {
|
|
let event_type = match &event {
|
|
PentestEvent::ToolStart { .. } => "tool_start",
|
|
PentestEvent::ToolComplete { .. } => "tool_complete",
|
|
PentestEvent::Finding { .. } => "finding",
|
|
PentestEvent::Message { .. } => "message",
|
|
PentestEvent::Complete { .. } => "complete",
|
|
PentestEvent::Error { .. } => "error",
|
|
PentestEvent::Thinking { .. } => "thinking",
|
|
PentestEvent::Paused => "paused",
|
|
PentestEvent::Resumed => "resumed",
|
|
};
|
|
Some(Ok(Event::default().event(event_type).data(data)))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
Err(_) => None,
|
|
});
|
|
// Box to unify types
|
|
Box::pin(broadcast)
|
|
as std::pin::Pin<Box<dyn futures_util::Stream<Item = Result<Event, Infallible>> + Send>>
|
|
} else {
|
|
// No active broadcast — return empty stream
|
|
Box::pin(stream::empty())
|
|
as std::pin::Pin<Box<dyn futures_util::Stream<Item = Result<Event, Infallible>> + Send>>
|
|
};
|
|
|
|
// Chain initial burst + live stream
|
|
let combined = stream::iter(initial_events).chain(live_stream);
|
|
|
|
Ok(Sse::new(combined).keep_alive(
|
|
KeepAlive::new()
|
|
.interval(Duration::from_secs(15))
|
|
.text("keepalive"),
|
|
))
|
|
}
|