feat: AI-driven automated penetration testing system

Add a complete AI pentest system where Claude autonomously drives security
testing via tool-calling. The LLM selects from 16 tools, chains results,
and builds an attack chain DAG.

Core:
- PentestTool trait (dyn-compatible) with PentestToolContext/Result
- PentestSession, AttackChainNode, PentestMessage, PentestEvent models
- 10 new DastVulnType variants (DNS, DMARC, TLS, cookies, CSP, CORS, etc.)
- LLM client chat_with_tools() for OpenAI-compatible tool calling

Tools (16 total):
- 5 agent wrappers: SQL injection, XSS, auth bypass, SSRF, API fuzzer
- 11 new infra tools: DNS checker, DMARC checker, TLS analyzer,
  security headers, cookie analyzer, CSP analyzer, rate limit tester,
  console log detector, CORS checker, OpenAPI parser, recon
- ToolRegistry for tool lookup and LLM definition generation

Orchestrator:
- PentestOrchestrator with iterative tool-calling loop (max 50 rounds)
- Attack chain node recording per tool invocation
- SSE event broadcasting for real-time progress
- Strategy-aware system prompts (quick/comprehensive/targeted/aggressive/stealth)

API (9 endpoints):
- POST/GET /pentest/sessions, GET /pentest/sessions/:id
- POST /pentest/sessions/:id/chat, GET /pentest/sessions/:id/stream
- GET /pentest/sessions/:id/attack-chain, messages, findings
- GET /pentest/stats

Dashboard:
- Pentest dashboard with stat cards, severity distribution, session list
- Chat-based session page with split layout (chat + findings/attack chain)
- Inline tool execution indicators, auto-polling, new session modal
- Sidebar navigation item

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-03-11 19:23:21 +01:00
parent 76260acc76
commit 71d8741e10
40 changed files with 7546 additions and 90 deletions

View File

@@ -1,6 +1,7 @@
pub mod chat;
pub mod dast;
pub mod graph;
pub mod pentest;
use std::sync::Arc;
@@ -1108,7 +1109,7 @@ pub async fn list_scan_runs(
}))
}
async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
use futures_util::StreamExt;

View File

@@ -0,0 +1,564 @@
use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
#[derive(Deserialize)]
pub struct CreateSessionRequest {
pub target_id: String,
#[serde(default = "default_strategy")]
pub strategy: String,
pub message: Option<String>,
}
fn default_strategy() -> String {
"comprehensive".to_string()
}
#[derive(Deserialize)]
pub struct SendMessageRequest {
pub message: String,
}
/// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator
#[tracing::instrument(skip_all)]
pub async fn create_session(
Extension(agent): AgentExt,
Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| {
(
StatusCode::BAD_REQUEST,
"Invalid target_id format".to_string(),
)
})?;
// Look up the target
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?;
// Parse strategy
let strategy = match req.strategy.as_str() {
"quick" => PentestStrategy::Quick,
"targeted" => PentestStrategy::Targeted,
"aggressive" => PentestStrategy::Aggressive,
"stealth" => PentestStrategy::Stealth,
_ => PentestStrategy::Comprehensive,
};
// Create session
let mut session = PentestSession::new(req.target_id.clone(), strategy);
session.repo_id = target.repo_id.clone();
agent
.db
.pentest_sessions()
.insert_one(&session)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create session: {e}"),
)
})?;
let initial_message = req.message.unwrap_or_else(|| {
format!(
"Begin a {} penetration test against {} ({}). \
Identify vulnerabilities and provide evidence for each finding.",
session.strategy, target.name, target.base_url,
)
});
// Spawn the orchestrator on a background task
let llm = agent.llm.clone();
let db = agent.db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator
.run_session(&session_clone, &target_clone, &initial_message)
.await
{
tracing::error!(
"Pentest orchestrator failed for session {}: {e}",
session_clone
.id
.map(|oid| oid.to_hex())
.unwrap_or_default()
);
}
});
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions — List pentest sessions
#[tracing::instrument(skip_all)]
pub async fn list_sessions(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.pentest_sessions()
.count_documents(doc! {})
.await
.unwrap_or(0);
let sessions = match db
.pentest_sessions()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest sessions: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: sessions,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/sessions/:id — Get a single pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
// Verify session exists and is running
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running && session.status != PentestStatus::Paused {
return Err((
StatusCode::BAD_REQUEST,
format!("Session is {}, cannot send messages", session.status),
));
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": target_oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Target for session not found".to_string(),
)
})?;
// Store user message
let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session
let llm = agent.llm.clone();
let db = agent.db.clone();
let message = req.message.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator.run_session(&session, &target, &message).await {
tracing::error!("Pentest orchestrator failed for session {session_id}: {e}");
}
});
Ok(Json(ApiResponse {
data: response_msg,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch attack chain nodes: {e}");
Vec::new()
}
};
let total = nodes.len() as u64;
Ok(Json(ApiResponse {
data: nodes,
total: Some(total),
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.pentest_messages()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let messages = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest messages: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: messages,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } };
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
let _ = pentest_filter; // used above inline
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.dast_findings()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let findings = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest session findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(params.page),
}))
}

View File

@@ -99,6 +99,36 @@ pub fn build_router() -> Router {
"/api/v1/chat/{repo_id}/status",
get(handlers::chat::embedding_status),
)
// Pentest API endpoints
.route(
"/api/v1/pentest/sessions",
get(handlers::pentest::list_sessions).post(handlers::pentest::create_session),
)
.route(
"/api/v1/pentest/sessions/{id}",
get(handlers::pentest::get_session),
)
.route(
"/api/v1/pentest/sessions/{id}/chat",
post(handlers::pentest::send_message),
)
.route(
"/api/v1/pentest/sessions/{id}/stream",
get(handlers::pentest::session_stream),
)
.route(
"/api/v1/pentest/sessions/{id}/attack-chain",
get(handlers::pentest::get_attack_chain),
)
.route(
"/api/v1/pentest/sessions/{id}/messages",
get(handlers::pentest::get_messages),
)
.route(
"/api/v1/pentest/sessions/{id}/findings",
get(handlers::pentest::get_session_findings),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",

View File

@@ -166,6 +166,38 @@ impl Database {
)
.await?;
// pentest_sessions: compound (target_id, started_at DESC)
self.pentest_sessions()
.create_index(
IndexModel::builder()
.keys(doc! { "target_id": 1, "started_at": -1 })
.build(),
)
.await?;
// pentest_sessions: status index
self.pentest_sessions()
.create_index(IndexModel::builder().keys(doc! { "status": 1 }).build())
.await?;
// attack_chain_nodes: compound (session_id, node_id)
self.attack_chain_nodes()
.create_index(
IndexModel::builder()
.keys(doc! { "session_id": 1, "node_id": 1 })
.build(),
)
.await?;
// pentest_messages: compound (session_id, created_at)
self.pentest_messages()
.create_index(
IndexModel::builder()
.keys(doc! { "session_id": 1, "created_at": 1 })
.build(),
)
.await?;
tracing::info!("Database indexes ensured");
Ok(())
}
@@ -235,6 +267,19 @@ impl Database {
self.inner.collection("embedding_builds")
}
// Pentest collections
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
self.inner.collection("pentest_sessions")
}
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
self.inner.collection("attack_chain_nodes")
}
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
self.inner.collection("pentest_messages")
}
#[allow(dead_code)]
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
self.inner.collection(name)

View File

@@ -12,10 +12,16 @@ pub struct LlmClient {
http: reqwest::Client,
}
#[derive(Serialize)]
struct ChatMessage {
role: String,
content: String,
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
@@ -26,8 +32,25 @@ struct ChatCompletionRequest {
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
struct ToolDefinitionPayload {
r#type: String,
function: ToolFunctionPayload,
}
#[derive(Serialize)]
struct ToolFunctionPayload {
name: String,
description: String,
parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct ChatCompletionResponse {
choices: Vec<ChatChoice>,
@@ -40,29 +63,84 @@ struct ChatChoice {
#[derive(Deserialize)]
struct ChatResponseMessage {
content: String,
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ToolCallResponse>>,
}
/// Request body for the embeddings API
#[derive(Deserialize)]
struct ToolCallResponse {
id: String,
function: ToolCallFunction,
}
#[derive(Deserialize)]
struct ToolCallFunction {
name: String,
arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
ToolCalls(Vec<LlmToolCall>),
}
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
/// Response from the embeddings API
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
/// A single embedding result
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Implementation ─────────────────────────────────────────────
impl LlmClient {
pub fn new(
base_url: String,
@@ -83,98 +161,142 @@ impl LlmClient {
&self.embed_model
}
fn chat_url(&self) -> String {
format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
)
}
fn auth_header(&self) -> Option<String> {
let key = self.api_key.expose_secret();
if key.is_empty() {
None
} else {
Some(format!("Bearer {key}"))
}
}
/// Simple chat: system + user prompt → text response
pub async fn chat(
&self,
system_prompt: &str,
user_prompt: &str,
temperature: Option<f64>,
) -> Result<String, AgentError> {
let url = format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
);
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: Some(system_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some(user_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: vec![
ChatMessage {
role: "system".to_string(),
content: system_prompt.to_string(),
},
ChatMessage {
role: "user".to_string(),
content: user_prompt.to_string(),
},
],
messages,
temperature,
max_tokens: Some(4096),
tools: None,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"LiteLLM returned {status}: {body}"
)));
}
let body: ChatCompletionResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
body.choices
.first()
.map(|c| c.message.content.clone())
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls(_) => String::new(), // shouldn't happen without tools
}
})
}
/// Chat with a list of (role, content) messages → text response
#[allow(dead_code)]
pub async fn chat_with_messages(
&self,
messages: Vec<(String, String)>,
temperature: Option<f64>,
) -> Result<String, AgentError> {
let url = format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
);
let messages = messages
.into_iter()
.map(|(role, content)| ChatMessage {
role,
content: Some(content),
tool_calls: None,
tool_call_id: None,
})
.collect();
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: messages
.into_iter()
.map(|(role, content)| ChatMessage { role, content })
.collect(),
messages,
temperature,
max_tokens: Some(4096),
tools: None,
};
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls(_) => String::new(),
}
})
}
/// Chat with tool definitions — returns either content or tool calls.
/// Use this for the AI pentest orchestrator loop.
pub async fn chat_with_tools(
&self,
messages: Vec<ChatMessage>,
tools: &[ToolDefinition],
temperature: Option<f64>,
max_tokens: Option<u32>,
) -> Result<LlmResponse, AgentError> {
let tool_payloads: Vec<ToolDefinitionPayload> = tools
.iter()
.map(|t| ToolDefinitionPayload {
r#type: "function".to_string(),
function: ToolFunctionPayload {
name: t.name.clone(),
description: t.description.clone(),
parameters: t.parameters.clone(),
},
})
.collect();
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages,
temperature,
max_tokens: Some(max_tokens.unwrap_or(8192)),
tools: if tool_payloads.is_empty() {
None
} else {
Some(tool_payloads)
},
};
self.send_chat_request(&request_body).await
}
/// Internal method to send a chat completion request and parse the response
async fn send_chat_request(
&self,
request_body: &ChatCompletionRequest,
) -> Result<LlmResponse, AgentError> {
let mut req = self
.http
.post(&url)
.post(&self.chat_url())
.header("content-type", "application/json")
.json(&request_body);
.json(request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
@@ -195,10 +317,37 @@ impl LlmClient {
.await
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
body.choices
let choice = body
.choices
.first()
.map(|c| c.message.content.clone())
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?;
// Check for tool calls first
if let Some(tool_calls) = &choice.message.tool_calls {
if !tool_calls.is_empty() {
let calls: Vec<LlmToolCall> = tool_calls
.iter()
.map(|tc| {
let arguments = serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
LlmToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments,
}
})
.collect();
return Ok(LlmResponse::ToolCalls(calls));
}
}
// Otherwise return content
let content = choice
.message
.content
.clone()
.unwrap_or_default();
Ok(LlmResponse::Content(content))
}
/// Generate embeddings for a batch of texts
@@ -216,9 +365,8 @@ impl LlmClient {
.header("content-type", "application/json")
.json(&request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
@@ -239,7 +387,6 @@ impl LlmClient {
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
// Sort by index to maintain input order
let mut data = body.data;
data.sort_by_key(|d| d.index);

View File

@@ -4,6 +4,7 @@ mod config;
mod database;
mod error;
mod llm;
mod pentest;
mod pipeline;
mod rag;
mod scheduler;

View File

@@ -0,0 +1,3 @@
pub mod orchestrator;
pub use orchestrator::PentestOrchestrator;

View File

@@ -0,0 +1,393 @@
use std::sync::Arc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::pentest::*;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
db: Database,
event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
pub fn new(llm: Arc<LlmClient>, db: Database) -> Self {
let (event_tx, _) = broadcast::channel(256);
Self {
tool_registry: ToolRegistry::new(),
llm,
db,
event_tx,
}
}
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
self.event_tx.subscribe()
}
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
self.event_tx.clone()
}
pub async fn run_session(
&self,
session: &PentestSession,
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
// Build system prompt
let system_prompt = self.build_system_prompt(session, target);
// Build tool definitions for LLM
let tool_defs: Vec<ToolDefinition> = self
.tool_registry
.all_definitions()
.into_iter()
.map(|td| ToolDefinition {
name: td.name,
description: td.description,
parameters: td.input_schema,
})
.collect();
// Initialize messages
let mut messages = vec![
ChatMessage {
role: "system".to_string(),
content: Some(system_prompt),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some(initial_message.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
// Store user message
let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string());
let _ = self.db.pentest_messages().insert_one(&user_msg).await;
// Build tool context
let tool_context = PentestToolContext {
target: target.clone(),
session_id: session_id.clone(),
sast_findings: Vec::new(),
sbom_entries: Vec::new(),
code_context: Vec::new(),
rate_limit: target.rate_limit,
allow_destructive: target.allow_destructive,
};
let max_iterations = 50;
let mut total_findings = 0u32;
let mut total_tool_calls = 0u32;
let mut total_successes = 0u32;
for _iteration in 0..max_iterations {
// Call LLM with tools
let response = self
.llm
.chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192))
.await?;
match response {
LlmResponse::Content(content) => {
// Store assistant message
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
// Emit message event
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
});
// Add to messages
messages.push(ChatMessage {
role: "assistant".to_string(),
content: Some(content.clone()),
tool_calls: None,
tool_call_id: None,
});
// Check if the LLM considers itself done
let done_indicators = [
"pentest complete",
"testing complete",
"scan complete",
"analysis complete",
"finished",
"that concludes",
];
let content_lower = content.to_lowercase();
if done_indicators
.iter()
.any(|ind| content_lower.contains(ind))
{
break;
}
// If not done, break and wait for user input
break;
}
LlmResponse::ToolCalls(tool_calls) => {
// Build the assistant message with tool_calls
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
id: tc.id.clone(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: Some(tc_requests),
tool_call_id: None,
});
// Execute each tool call
for tc in &tool_calls {
total_tool_calls += 1;
let node_id = uuid::Uuid::new_v4().to_string();
// Create attack chain node
let mut node = AttackChainNode::new(
session_id.clone(),
node_id.clone(),
tc.name.clone(),
tc.arguments.clone(),
String::new(),
);
node.status = AttackNodeStatus::Running;
node.started_at = Some(chrono::Utc::now());
let _ = self.db.attack_chain_nodes().insert_one(&node).await;
// Emit tool start event
let _ = self.event_tx.send(PentestEvent::ToolStart {
node_id: node_id.clone(),
tool_name: tc.name.clone(),
input: tc.arguments.clone(),
});
// Execute the tool
let result = if let Some(tool) = self.tool_registry.get(&tc.name) {
match tool.execute(tc.arguments.clone(), &tool_context).await {
Ok(result) => {
total_successes += 1;
let findings_count = result.findings.len() as u32;
total_findings += findings_count;
// Store findings
for mut finding in result.findings {
finding.scan_run_id = session_id.clone();
finding.session_id = Some(session_id.clone());
let _ =
self.db.dast_findings().insert_one(&finding).await;
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Emit tool complete event
let _ = self.event_tx.send(PentestEvent::ToolComplete {
node_id: node_id.clone(),
summary: result.summary.clone(),
findings_count,
});
// Update attack chain node
let _ = self
.db
.attack_chain_nodes()
.update_one(
mongodb::bson::doc! {
"session_id": &session_id,
"node_id": &node_id,
},
mongodb::bson::doc! { "$set": {
"status": "completed",
"tool_output": mongodb::bson::to_bson(&result.data)
.unwrap_or(mongodb::bson::Bson::Null),
"completed_at": mongodb::bson::DateTime::now(),
}},
)
.await;
serde_json::json!({
"summary": result.summary,
"findings_count": findings_count,
"data": result.data,
})
.to_string()
}
Err(e) => {
// Update node as failed
let _ = self
.db
.attack_chain_nodes()
.update_one(
mongodb::bson::doc! {
"session_id": &session_id,
"node_id": &node_id,
},
mongodb::bson::doc! { "$set": {
"status": "failed",
"completed_at": mongodb::bson::DateTime::now(),
}},
)
.await;
format!("Tool execution failed: {e}")
}
}
} else {
format!("Unknown tool: {}", tc.name)
};
// Add tool result to messages
messages.push(ChatMessage {
role: "tool".to_string(),
content: Some(result),
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
// Update session stats
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
mongodb::bson::doc! { "_id": sid },
mongodb::bson::doc! { "$set": {
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
}
}
}
// Mark session as completed
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
mongodb::bson::doc! { "_id": sid },
mongodb::bson::doc! { "$set": {
"status": "completed",
"completed_at": mongodb::bson::DateTime::now(),
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
let _ = self.event_tx.send(PentestEvent::Complete {
summary: format!(
"Pentest complete. {} findings from {} tool invocations.",
total_findings, total_tool_calls
),
});
Ok(())
}
fn build_system_prompt(&self, session: &PentestSession, target: &DastTarget) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let strategy_guidance = match session.strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
};
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
## Strategy
{strategy_guidance}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance and crawling to understand the target.
2. Based on what you discover, select appropriate vulnerability scanning tools.
3. For each tool invocation, provide the discovered endpoints and parameters.
4. Analyze tool results and chain findings — if you find one vulnerability, explore whether it enables others.
5. When testing is complete, provide a summary of all findings with severity and remediation recommendations.
6. Always explain your reasoning before invoking each tool.
7. Focus on actionable findings with evidence. Avoid false positives.
8. When you have completed all relevant testing, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
)
}
}