use std::sync::Arc; use dashmap::DashMap; use tokio::sync::{broadcast, watch, Semaphore}; use compliance_core::models::pentest::PentestEvent; use compliance_core::AgentConfig; use crate::database::Database; use crate::llm::LlmClient; use crate::pipeline::orchestrator::PipelineOrchestrator; /// Default maximum concurrent pentest sessions. const DEFAULT_MAX_CONCURRENT_SESSIONS: usize = 5; #[derive(Clone)] pub struct ComplianceAgent { pub config: AgentConfig, pub db: Database, pub llm: Arc, pub http: reqwest::Client, /// Per-session broadcast senders for SSE streaming. pub session_streams: Arc>>, /// Per-session pause controls (true = paused). pub session_pause: Arc>>, /// Semaphore limiting concurrent pentest sessions. pub session_semaphore: Arc, } impl ComplianceAgent { pub fn new(config: AgentConfig, db: Database) -> Self { let llm = Arc::new(LlmClient::new( config.litellm_url.clone(), config.litellm_api_key.clone(), config.litellm_model.clone(), config.litellm_embed_model.clone(), )); Self { config, db, llm, http: reqwest::Client::new(), session_streams: Arc::new(DashMap::new()), session_pause: Arc::new(DashMap::new()), session_semaphore: Arc::new(Semaphore::new(DEFAULT_MAX_CONCURRENT_SESSIONS)), } } pub async fn run_scan( &self, repo_id: &str, trigger: compliance_core::models::ScanTrigger, ) -> Result<(), crate::error::AgentError> { let orchestrator = PipelineOrchestrator::new( self.config.clone(), self.db.clone(), self.llm.clone(), self.http.clone(), ); orchestrator.run(repo_id, trigger).await } /// Run a PR review: scan the diff and post review comments. pub async fn run_pr_review( &self, repo_id: &str, pr_number: u64, base_sha: &str, head_sha: &str, ) -> Result<(), crate::error::AgentError> { let repo = self .db .repositories() .find_one(mongodb::bson::doc! { "_id": mongodb::bson::oid::ObjectId::parse_str(repo_id) .map_err(|e| crate::error::AgentError::Other(e.to_string()))? }) .await? .ok_or_else(|| { crate::error::AgentError::Other(format!("Repository {repo_id} not found")) })?; let orchestrator = PipelineOrchestrator::new( self.config.clone(), self.db.clone(), self.llm.clone(), self.http.clone(), ); orchestrator .run_pr_review(&repo, repo_id, pr_number, base_sha, head_sha) .await } // ── Session stream management ────────────────────────────────── /// Register a broadcast sender for a session. Returns the sender. pub fn register_session_stream(&self, session_id: &str) -> broadcast::Sender { let (tx, _) = broadcast::channel(256); self.session_streams .insert(session_id.to_string(), tx.clone()); tx } /// Subscribe to a session's broadcast stream. pub fn subscribe_session(&self, session_id: &str) -> Option> { self.session_streams .get(session_id) .map(|tx| tx.subscribe()) } // ── Session pause/resume management ──────────────────────────── /// Register a pause control for a session. Returns the watch receiver. pub fn register_pause_control(&self, session_id: &str) -> watch::Receiver { let (tx, rx) = watch::channel(false); self.session_pause.insert(session_id.to_string(), tx); rx } /// Pause a session. pub fn pause_session(&self, session_id: &str) -> bool { if let Some(tx) = self.session_pause.get(session_id) { tx.send(true).is_ok() } else { false } } /// Resume a session. pub fn resume_session(&self, session_id: &str) -> bool { if let Some(tx) = self.session_pause.get(session_id) { tx.send(false).is_ok() } else { false } } /// Clean up all per-session resources. pub fn cleanup_session(&self, session_id: &str) { self.session_streams.remove(session_id); self.session_pause.remove(session_id); } }