feat: AI-driven automated penetration testing (#12)
Some checks failed
CI / Clippy (push) Failing after 1m51s
CI / Security Audit (push) Successful in 2m1s
CI / Tests (push) Has been skipped
CI / Detect Changes (push) Has been skipped
CI / Deploy Agent (push) Has been skipped
CI / Deploy Dashboard (push) Has been skipped
CI / Deploy Docs (push) Has been skipped
CI / Format (push) Failing after 42s
CI / Deploy MCP (push) Has been skipped

This commit was merged in pull request #12.
This commit is contained in:
2026-03-12 14:42:54 +00:00
parent 3ec1456b0d
commit acc5b86aa4
52 changed files with 11729 additions and 98 deletions

View File

@@ -0,0 +1,761 @@
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use mongodb::bson::doc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
/// Maximum wall-clock duration of a single pentest session.
///
/// `run_session_guarded` passes this to `tokio::time::timeout`; when it
/// elapses the session is marked as failed and an error event is broadcast.
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
/// Coordinates an LLM-driven pentest session: it asks the LLM which tools to
/// run, executes them through the tool registry, persists results to the
/// database and broadcasts progress as `PentestEvent`s.
pub struct PentestOrchestrator {
/// Registry the orchestrator queries for tool definitions, tool names and
/// tool lookup by name.
tool_registry: ToolRegistry,
/// Shared LLM client used for the tool-calling chat loop.
llm: Arc<LlmClient>,
/// Database handle used for sessions, messages, attack-chain nodes and findings.
db: Database,
/// Sending half of the broadcast channel on which session events are published.
event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
pub fn new(llm: Arc<LlmClient>, db: Database) -> Self {
let (event_tx, _) = broadcast::channel(256);
Self {
tool_registry: ToolRegistry::new(),
llm,
db,
event_tx,
}
}
/// Open a new receiver on the orchestrator's event channel.
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
    broadcast::Sender::subscribe(&self.event_tx)
}
/// Hand out another handle to the sending half of the event channel.
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
    broadcast::Sender::clone(&self.event_tx)
}
/// Run a pentest session with timeout and automatic failure marking on errors.
pub async fn run_session_guarded(
&self,
session: &PentestSession,
target: &DastTarget,
initial_message: &str,
) {
let session_id = session.id;
match tokio::time::timeout(
SESSION_TIMEOUT,
self.run_session(session, target, initial_message),
)
.await
{
Ok(Ok(())) => {
tracing::info!(?session_id, "Pentest session completed successfully");
}
Ok(Err(e)) => {
tracing::error!(?session_id, error = %e, "Pentest session failed");
self.mark_session_failed(session_id, &format!("Error: {e}"))
.await;
let _ = self.event_tx.send(PentestEvent::Error {
message: format!("Session failed: {e}"),
});
}
Err(_) => {
tracing::warn!(?session_id, "Pentest session timed out after 30 minutes");
self.mark_session_failed(session_id, "Session timed out after 30 minutes")
.await;
let _ = self.event_tx.send(PentestEvent::Error {
message: "Session timed out after 30 minutes".to_string(),
});
}
}
}
async fn mark_session_failed(
&self,
session_id: Option<mongodb::bson::oid::ObjectId>,
reason: &str,
) {
if let Some(sid) = session_id {
let _ = self
.db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
doc! { "$set": {
"status": "failed",
"completed_at": mongodb::bson::DateTime::now(),
"error_message": reason,
}},
)
.await;
}
}
async fn run_session(
&self,
session: &PentestSession,
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
// Gather code-awareness context from linked repo
let (sast_findings, sbom_entries, code_context) =
self.gather_repo_context(target).await;
// Build system prompt with code context
let system_prompt = self
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
.await;
// Build tool definitions for LLM
let tool_defs: Vec<ToolDefinition> = self
.tool_registry
.all_definitions()
.into_iter()
.map(|td| ToolDefinition {
name: td.name,
description: td.description,
parameters: td.input_schema,
})
.collect();
// Initialize messages
let mut messages = vec![
ChatMessage {
role: "system".to_string(),
content: Some(system_prompt),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some(initial_message.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
// Store user message
let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string());
let _ = self.db.pentest_messages().insert_one(&user_msg).await;
// Build tool context with real data
let tool_context = PentestToolContext {
target: target.clone(),
session_id: session_id.clone(),
sast_findings,
sbom_entries,
code_context,
rate_limit: target.rate_limit,
allow_destructive: target.allow_destructive,
};
let max_iterations = 50;
let mut total_findings = 0u32;
let mut total_tool_calls = 0u32;
let mut total_successes = 0u32;
let mut prev_node_ids: Vec<String> = Vec::new();
for _iteration in 0..max_iterations {
let response = self
.llm
.chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192))
.await?;
match response {
LlmResponse::Content(content) => {
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
});
messages.push(ChatMessage {
role: "assistant".to_string(),
content: Some(content.clone()),
tool_calls: None,
tool_call_id: None,
});
let done_indicators = [
"pentest complete",
"testing complete",
"scan complete",
"analysis complete",
"finished",
"that concludes",
];
let content_lower = content.to_lowercase();
if done_indicators
.iter()
.any(|ind| content_lower.contains(ind))
{
break;
}
break;
}
LlmResponse::ToolCalls { calls: tool_calls, reasoning } => {
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
id: tc.id.clone(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) },
tool_calls: Some(tc_requests),
tool_call_id: None,
});
let mut current_batch_node_ids: Vec<String> = Vec::new();
for tc in &tool_calls {
total_tool_calls += 1;
let node_id = uuid::Uuid::new_v4().to_string();
let mut node = AttackChainNode::new(
session_id.clone(),
node_id.clone(),
tc.name.clone(),
tc.arguments.clone(),
reasoning.clone(),
);
// Link to previous iteration's nodes
node.parent_node_ids = prev_node_ids.clone();
node.status = AttackNodeStatus::Running;
node.started_at = Some(chrono::Utc::now());
let _ = self.db.attack_chain_nodes().insert_one(&node).await;
current_batch_node_ids.push(node_id.clone());
let _ = self.event_tx.send(PentestEvent::ToolStart {
node_id: node_id.clone(),
tool_name: tc.name.clone(),
input: tc.arguments.clone(),
});
let result = if let Some(tool) = self.tool_registry.get(&tc.name) {
match tool.execute(tc.arguments.clone(), &tool_context).await {
Ok(result) => {
total_successes += 1;
let findings_count = result.findings.len() as u32;
total_findings += findings_count;
let mut finding_ids: Vec<String> = Vec::new();
for mut finding in result.findings {
finding.scan_run_id = session_id.clone();
finding.session_id = Some(session_id.clone());
let insert_result =
self.db.dast_findings().insert_one(&finding).await;
if let Ok(res) = &insert_result {
finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default());
}
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Compute risk score based on findings severity
let risk_score: Option<u8> = if findings_count > 0 {
Some(std::cmp::min(
100,
(findings_count as u8).saturating_mul(15).saturating_add(20),
))
} else {
None
};
let _ = self.event_tx.send(PentestEvent::ToolComplete {
node_id: node_id.clone(),
summary: result.summary.clone(),
findings_count,
});
let finding_ids_bson: Vec<mongodb::bson::Bson> = finding_ids
.iter()
.map(|id| mongodb::bson::Bson::String(id.clone()))
.collect();
let mut update_doc = doc! {
"status": "completed",
"tool_output": mongodb::bson::to_bson(&result.data)
.unwrap_or(mongodb::bson::Bson::Null),
"completed_at": mongodb::bson::DateTime::now(),
"findings_produced": finding_ids_bson,
};
if let Some(rs) = risk_score {
update_doc.insert("risk_score", rs as i32);
}
let _ = self
.db
.attack_chain_nodes()
.update_one(
doc! {
"session_id": &session_id,
"node_id": &node_id,
},
doc! { "$set": update_doc },
)
.await;
serde_json::json!({
"summary": result.summary,
"findings_count": findings_count,
"data": result.data,
})
.to_string()
}
Err(e) => {
let _ = self
.db
.attack_chain_nodes()
.update_one(
doc! {
"session_id": &session_id,
"node_id": &node_id,
},
doc! { "$set": {
"status": "failed",
"completed_at": mongodb::bson::DateTime::now(),
}},
)
.await;
format!("Tool execution failed: {e}")
}
}
} else {
format!("Unknown tool: {}", tc.name)
};
messages.push(ChatMessage {
role: "tool".to_string(),
content: Some(result),
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
// Advance parent links so next iteration's nodes connect to this batch
prev_node_ids = current_batch_node_ids;
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
doc! { "$set": {
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
}
}
}
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
doc! { "_id": sid },
doc! { "$set": {
"status": "completed",
"completed_at": mongodb::bson::DateTime::now(),
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
let _ = self.event_tx.send(PentestEvent::Complete {
summary: format!(
"Pentest complete. {} findings from {} tool invocations.",
total_findings, total_tool_calls
),
});
Ok(())
}
// ── Code-Awareness: Gather context from linked repo ─────────
/// Collect code-awareness context for the repository linked to a DAST target.
///
/// Returns `(sast_findings, sbom_entries, code_context)`. All three vectors
/// are empty when the target has no linked repository.
async fn gather_repo_context(
    &self,
    target: &DastTarget,
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
    let repo_id = match target.repo_id.as_ref() {
        Some(id) => id,
        None => return (Vec::new(), Vec::new(), Vec::new()),
    };

    let findings = self.fetch_sast_findings(repo_id).await;
    let vulnerable_deps = self.fetch_vulnerable_sbom(repo_id).await;
    // The entry-point hints are cross-referenced against the SAST findings,
    // so they are built last.
    let hints = self.fetch_code_context(repo_id, &findings).await;

    tracing::info!(
        repo_id,
        sast_findings = findings.len(),
        vulnerable_deps = vulnerable_deps.len(),
        code_hints = hints.len(),
        "Gathered code-awareness context for pentest"
    );

    (findings, vulnerable_deps, hints)
}
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
let cursor = self
.db
.findings()
.find(doc! {
"repo_id": repo_id,
"status": { "$in": ["open", "triaged"] },
})
.sort(doc! { "severity": -1 })
.limit(100)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(f)) = c.next().await {
results.push(f);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
Vec::new()
}
}
}
/// Fetch SBOM entries that have known vulnerabilities
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
let cursor = self
.db
.sbom_entries()
.find(doc! {
"repo_id": repo_id,
"known_vulnerabilities": { "$exists": true, "$ne": [] },
})
.limit(50)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(e)) = c.next().await {
results.push(e);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
Vec::new()
}
}
}
/// Build CodeContextHint objects from the code knowledge graph.
/// Maps entry points to their source files and links SAST findings.
async fn fetch_code_context(
&self,
repo_id: &str,
sast_findings: &[Finding],
) -> Vec<CodeContextHint> {
// Get entry point nodes from the code graph
let cursor = self
.db
.graph_nodes()
.find(doc! {
"repo_id": repo_id,
"is_entry_point": true,
})
.limit(50)
.await;
let nodes = match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(n)) = c.next().await {
results.push(n);
}
results
}
Err(_) => return Vec::new(),
};
// Build hints by matching graph nodes to SAST findings by file path
nodes
.into_iter()
.map(|node| {
// Find SAST findings in the same file
let linked_vulns: Vec<String> = sast_findings
.iter()
.filter(|f| {
f.file_path.as_deref() == Some(&node.file_path)
})
.map(|f| {
format!(
"[{}] {}: {} (line {})",
f.severity,
f.scanner,
f.title,
f.line_number.unwrap_or(0)
)
})
.collect();
CodeContextHint {
endpoint_pattern: node.qualified_name.clone(),
handler_function: node.name.clone(),
file_path: node.file_path.clone(),
code_snippet: String::new(), // Could fetch from embeddings
known_vulnerabilities: linked_vulns,
}
})
.collect()
}
// ── System Prompt Builder ───────────────────────────────────
/// Render the system prompt for a pentest session.
///
/// The prompt combines the target description, strategy-specific guidance,
/// summaries of the SAST findings, vulnerable dependencies and code entry
/// points, the names of the registered tools, and the fixed instructions.
/// Each summary lists only a leading slice of its input (20 findings,
/// 15 dependencies, 20 entry points) and states how many items were omitted.
///
/// Performs no I/O; it is `async` only so existing callers can keep
/// awaiting it.
async fn build_system_prompt(
    &self,
    session: &PentestSession,
    target: &DastTarget,
    sast_findings: &[Finding],
    sbom_entries: &[SbomEntry],
    code_context: &[CodeContextHint],
) -> String {
    let tool_names = self.tool_registry.list_names().join(", ");

    let strategy_guidance = match session.strategy {
        PentestStrategy::Quick => {
            "Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
        }
        PentestStrategy::Comprehensive => {
            "Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
        }
        PentestStrategy::Targeted => {
            "Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
        }
        PentestStrategy::Aggressive => {
            "Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
        }
        PentestStrategy::Stealth => {
            "Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
        }
    };

    // Build SAST findings section
    let sast_section = if sast_findings.is_empty() {
        String::from("No SAST findings available for this target.")
    } else {
        let critical = sast_findings
            .iter()
            .filter(|f| f.severity == Severity::Critical)
            .count();
        let high = sast_findings
            .iter()
            .filter(|f| f.severity == Severity::High)
            .count();

        let mut section = format!(
            "{} open findings ({} critical, {} high):\n",
            sast_findings.len(),
            critical,
            high
        );

        // List the first 20 findings in the order supplied by the caller;
        // this function does not re-sort them by severity.
        for f in sast_findings.iter().take(20) {
            let file_info = f
                .file_path
                .as_ref()
                .map(|p| format!(" in {}:{}", p, f.line_number.unwrap_or(0)))
                .unwrap_or_default();
            let status_note = match f.status {
                FindingStatus::Triaged => " [TRIAGED]",
                _ => "",
            };
            section.push_str(&format!(
                "- [{sev}] {title}{file}{status}\n",
                sev = f.severity,
                title = f.title,
                file = file_info,
                status = status_note,
            ));
            if let Some(cwe) = &f.cwe {
                section.push_str(&format!("  CWE: {cwe}\n"));
            }
        }
        if sast_findings.len() > 20 {
            section.push_str(&format!(
                "... and {} more findings\n",
                sast_findings.len() - 20
            ));
        }
        section
    };

    // Build SBOM/CVE section
    let sbom_section = if sbom_entries.is_empty() {
        String::from("No vulnerable dependencies identified.")
    } else {
        let mut section = format!(
            "{} dependencies with known vulnerabilities:\n",
            sbom_entries.len()
        );
        for entry in sbom_entries.iter().take(15) {
            let cve_ids: Vec<&str> = entry
                .known_vulnerabilities
                .iter()
                .map(|v| v.id.as_str())
                .collect();
            section.push_str(&format!(
                "- {} {} ({}): {}\n",
                entry.name,
                entry.version,
                entry.package_manager,
                cve_ids.join(", ")
            ));
        }
        if sbom_entries.len() > 15 {
            section.push_str(&format!(
                "... and {} more vulnerable dependencies\n",
                sbom_entries.len() - 15
            ));
        }
        section
    };

    // Build code context section
    let code_section = if code_context.is_empty() {
        String::from("No code knowledge graph available for this target.")
    } else {
        let with_vulns = code_context
            .iter()
            .filter(|c| !c.known_vulnerabilities.is_empty())
            .count();

        let mut section = format!(
            "{} entry points identified ({} with linked SAST findings):\n",
            code_context.len(),
            with_vulns
        );

        for hint in code_context.iter().take(20) {
            section.push_str(&format!(
                "- {} ({})\n",
                hint.endpoint_pattern, hint.file_path
            ));
            for vuln in &hint.known_vulnerabilities {
                section.push_str(&format!("  SAST: {vuln}\n"));
            }
        }
        // The header reports the full count, so say when the list below it
        // was cut short — the SAST and SBOM sections already do this.
        if code_context.len() > 20 {
            section.push_str(&format!(
                "... and {} more entry points\n",
                code_context.len() - 20
            ));
        }
        section
    };

    format!(
        r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
        target_name = target.name,
        base_url = target.base_url,
        target_type = target.target_type,
        rate_limit = target.rate_limit,
        allow_destructive = target.allow_destructive,
        repo_linked = target.repo_id.as_deref().unwrap_or("None"),
    )
}
}