refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s

This commit was merged in pull request #13.
This commit is contained in:
2026-03-13 08:03:45 +00:00
parent acc5b86aa4
commit 3bb690e5bb
89 changed files with 11884 additions and 6046 deletions

View File

@@ -1,31 +1,27 @@
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use mongodb::bson::doc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
use crate::llm::{
ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
/// Maximum duration for a single pentest session before timeout
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
db: Database,
event_tx: broadcast::Sender<PentestEvent>,
pub(crate) tool_registry: ToolRegistry,
pub(crate) llm: Arc<LlmClient>,
pub(crate) db: Database,
pub(crate) event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
@@ -39,10 +35,12 @@ impl PentestOrchestrator {
}
}
#[allow(dead_code)]
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
self.event_tx.subscribe()
}
#[allow(dead_code)]
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
self.event_tx.clone()
}
@@ -111,18 +109,20 @@ impl PentestOrchestrator {
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default();
// Gather code-awareness context from linked repo
let (sast_findings, sbom_entries, code_context) =
self.gather_repo_context(target).await;
let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await;
// Build system prompt with code context
let system_prompt = self
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
.build_system_prompt(
session,
target,
&sast_findings,
&sbom_entries,
&code_context,
)
.await;
// Build tool definitions for LLM
@@ -182,8 +182,7 @@ impl PentestOrchestrator {
match response {
LlmResponse::Content(content) => {
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let msg = PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
@@ -213,7 +212,10 @@ impl PentestOrchestrator {
}
break;
}
LlmResponse::ToolCalls { calls: tool_calls, reasoning } => {
LlmResponse::ToolCalls {
calls: tool_calls,
reasoning,
} => {
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
@@ -221,15 +223,18 @@ impl PentestOrchestrator {
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) },
content: if reasoning.is_empty() {
None
} else {
Some(reasoning.clone())
},
tool_calls: Some(tc_requests),
tool_call_id: None,
});
@@ -274,24 +279,30 @@ impl PentestOrchestrator {
let insert_result =
self.db.dast_findings().insert_one(&finding).await;
if let Ok(res) = &insert_result {
finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default());
}
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
finding_ids.push(
res.inserted_id
.as_object_id()
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
);
}
let _ = self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Compute risk score based on findings severity
let risk_score: Option<u8> = if findings_count > 0 {
Some(std::cmp::min(
100,
(findings_count as u8).saturating_mul(15).saturating_add(20),
(findings_count as u8)
.saturating_mul(15)
.saturating_add(20),
))
} else {
None
@@ -415,347 +426,4 @@ impl PentestOrchestrator {
Ok(())
}
// ── Code-Awareness: Gather context from linked repo ─────────
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
/// for the repo linked to this DAST target.
async fn gather_repo_context(
&self,
target: &DastTarget,
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
let Some(repo_id) = &target.repo_id else {
return (Vec::new(), Vec::new(), Vec::new());
};
let sast_findings = self.fetch_sast_findings(repo_id).await;
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
tracing::info!(
repo_id,
sast_findings = sast_findings.len(),
vulnerable_deps = sbom_entries.len(),
code_hints = code_context.len(),
"Gathered code-awareness context for pentest"
);
(sast_findings, sbom_entries, code_context)
}
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
let cursor = self
.db
.findings()
.find(doc! {
"repo_id": repo_id,
"status": { "$in": ["open", "triaged"] },
})
.sort(doc! { "severity": -1 })
.limit(100)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(f)) = c.next().await {
results.push(f);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
Vec::new()
}
}
}
/// Fetch SBOM entries that have known vulnerabilities
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
let cursor = self
.db
.sbom_entries()
.find(doc! {
"repo_id": repo_id,
"known_vulnerabilities": { "$exists": true, "$ne": [] },
})
.limit(50)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(e)) = c.next().await {
results.push(e);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
Vec::new()
}
}
}
/// Build CodeContextHint objects from the code knowledge graph.
/// Maps entry points to their source files and links SAST findings.
async fn fetch_code_context(
&self,
repo_id: &str,
sast_findings: &[Finding],
) -> Vec<CodeContextHint> {
// Get entry point nodes from the code graph
let cursor = self
.db
.graph_nodes()
.find(doc! {
"repo_id": repo_id,
"is_entry_point": true,
})
.limit(50)
.await;
let nodes = match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(n)) = c.next().await {
results.push(n);
}
results
}
Err(_) => return Vec::new(),
};
// Build hints by matching graph nodes to SAST findings by file path
nodes
.into_iter()
.map(|node| {
// Find SAST findings in the same file
let linked_vulns: Vec<String> = sast_findings
.iter()
.filter(|f| {
f.file_path.as_deref() == Some(&node.file_path)
})
.map(|f| {
format!(
"[{}] {}: {} (line {})",
f.severity,
f.scanner,
f.title,
f.line_number.unwrap_or(0)
)
})
.collect();
CodeContextHint {
endpoint_pattern: node.qualified_name.clone(),
handler_function: node.name.clone(),
file_path: node.file_path.clone(),
code_snippet: String::new(), // Could fetch from embeddings
known_vulnerabilities: linked_vulns,
}
})
.collect()
}
// ── System Prompt Builder ───────────────────────────────────
async fn build_system_prompt(
&self,
session: &PentestSession,
target: &DastTarget,
sast_findings: &[Finding],
sbom_entries: &[SbomEntry],
code_context: &[CodeContextHint],
) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let strategy_guidance = match session.strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
};
// Build SAST findings section
let sast_section = if sast_findings.is_empty() {
String::from("No SAST findings available for this target.")
} else {
let critical = sast_findings
.iter()
.filter(|f| f.severity == Severity::Critical)
.count();
let high = sast_findings
.iter()
.filter(|f| f.severity == Severity::High)
.count();
let mut section = format!(
"{} open findings ({} critical, {} high):\n",
sast_findings.len(),
critical,
high
);
// List the most important findings (critical/high first, up to 20)
for f in sast_findings.iter().take(20) {
let file_info = f
.file_path
.as_ref()
.map(|p| {
format!(
" in {}:{}",
p,
f.line_number.unwrap_or(0)
)
})
.unwrap_or_default();
let status_note = match f.status {
FindingStatus::Triaged => " [TRIAGED]",
_ => "",
};
section.push_str(&format!(
"- [{sev}] {title}{file}{status}\n",
sev = f.severity,
title = f.title,
file = file_info,
status = status_note,
));
if let Some(cwe) = &f.cwe {
section.push_str(&format!(" CWE: {cwe}\n"));
}
}
if sast_findings.len() > 20 {
section.push_str(&format!(
"... and {} more findings\n",
sast_findings.len() - 20
));
}
section
};
// Build SBOM/CVE section
let sbom_section = if sbom_entries.is_empty() {
String::from("No vulnerable dependencies identified.")
} else {
let mut section = format!(
"{} dependencies with known vulnerabilities:\n",
sbom_entries.len()
);
for entry in sbom_entries.iter().take(15) {
let cve_ids: Vec<&str> = entry
.known_vulnerabilities
.iter()
.map(|v| v.id.as_str())
.collect();
section.push_str(&format!(
"- {} {} ({}): {}\n",
entry.name,
entry.version,
entry.package_manager,
cve_ids.join(", ")
));
}
if sbom_entries.len() > 15 {
section.push_str(&format!(
"... and {} more vulnerable dependencies\n",
sbom_entries.len() - 15
));
}
section
};
// Build code context section
let code_section = if code_context.is_empty() {
String::from("No code knowledge graph available for this target.")
} else {
let with_vulns = code_context
.iter()
.filter(|c| !c.known_vulnerabilities.is_empty())
.count();
let mut section = format!(
"{} entry points identified ({} with linked SAST findings):\n",
code_context.len(),
with_vulns
);
for hint in code_context.iter().take(20) {
section.push_str(&format!(
"- {} ({})\n",
hint.endpoint_pattern, hint.file_path
));
for vuln in &hint.known_vulnerabilities {
section.push_str(&format!(" SAST: {vuln}\n"));
}
}
section
};
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
)
}
}