feat: AI-driven automated penetration testing system

Add a complete AI pentest system where Claude autonomously drives security
testing via tool-calling. The LLM selects from 16 tools, chains results,
and builds an attack chain DAG.

Core:
- PentestTool trait (dyn-compatible) with PentestToolContext/Result
- PentestSession, AttackChainNode, PentestMessage, PentestEvent models
- 10 new DastVulnType variants (DNS, DMARC, TLS, cookies, CSP, CORS, etc.)
- LLM client chat_with_tools() for OpenAI-compatible tool calling

Tools (16 total):
- 5 agent wrappers: SQL injection, XSS, auth bypass, SSRF, API fuzzer
- 11 new infra tools: DNS checker, DMARC checker, TLS analyzer,
  security headers, cookie analyzer, CSP analyzer, rate limit tester,
  console log detector, CORS checker, OpenAPI parser, recon
- ToolRegistry for tool lookup and LLM definition generation

Orchestrator:
- PentestOrchestrator with iterative tool-calling loop (max 50 rounds)
- Attack chain node recording per tool invocation
- SSE event broadcasting for real-time progress
- Strategy-aware system prompts (quick/comprehensive/targeted/aggressive/stealth)

API (9 endpoints):
- POST/GET /pentest/sessions, GET /pentest/sessions/:id
- POST /pentest/sessions/:id/chat, GET /pentest/sessions/:id/stream
- GET /pentest/sessions/:id/attack-chain, messages, findings
- GET /pentest/stats

Dashboard:
- Pentest dashboard with stat cards, severity distribution, session list
- Chat-based session page with split layout (chat + findings/attack chain)
- Inline tool execution indicators, auto-polling, new session modal
- Sidebar navigation item

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-03-11 19:23:21 +01:00
parent 76260acc76
commit 71d8741e10
40 changed files with 7546 additions and 90 deletions

89
Cargo.lock generated
View File

@@ -680,12 +680,14 @@ dependencies = [
"chrono",
"compliance-core",
"mongodb",
"native-tls",
"reqwest",
"scraper",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tokio-native-tls",
"tracing",
"url",
"uuid",
@@ -1994,6 +1996,21 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
@@ -2824,15 +2841,6 @@ dependencies = [
"serde",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
@@ -3399,6 +3407,23 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b"
[[package]]
name = "native-tls"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe 0.2.1",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndk"
version = "0.9.0"
@@ -3578,6 +3603,32 @@ version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107"
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags",
"cfg-if",
"foreign-types",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
@@ -3949,7 +4000,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.12.1",
"itertools",
"proc-macro2",
"quote",
"syn",
@@ -4899,7 +4950,7 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
dependencies = [
"heck 0.4.1",
"heck 0.5.0",
"proc-macro2",
"quote",
"syn",
@@ -5116,7 +5167,7 @@ dependencies = [
"fs4",
"htmlescape",
"hyperloglogplus",
"itertools 0.14.0",
"itertools",
"levenshtein_automata",
"log",
"lru 0.12.5",
@@ -5164,7 +5215,7 @@ checksum = "8b628488ae936c83e92b5c4056833054ca56f76c0e616aee8339e24ac89119cd"
dependencies = [
"downcast-rs",
"fastdivide",
"itertools 0.14.0",
"itertools",
"serde",
"tantivy-bitpacker",
"tantivy-common",
@@ -5214,7 +5265,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8292095d1a8a2c2b36380ec455f910ab52dde516af36321af332c93f20ab7d5"
dependencies = [
"futures-util",
"itertools 0.14.0",
"itertools",
"tantivy-bitpacker",
"tantivy-common",
"tantivy-fst",
@@ -5428,6 +5479,16 @@ dependencies = [
"syn",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"

View File

@@ -1,6 +1,7 @@
pub mod chat;
pub mod dast;
pub mod graph;
pub mod pentest;
use std::sync::Arc;
@@ -1108,7 +1109,7 @@ pub async fn list_scan_runs(
}))
}
async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
use futures_util::StreamExt;

View File

@@ -0,0 +1,564 @@
use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
#[derive(Deserialize)]
pub struct CreateSessionRequest {
pub target_id: String,
#[serde(default = "default_strategy")]
pub strategy: String,
pub message: Option<String>,
}
fn default_strategy() -> String {
"comprehensive".to_string()
}
#[derive(Deserialize)]
pub struct SendMessageRequest {
pub message: String,
}
/// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator
#[tracing::instrument(skip_all)]
pub async fn create_session(
Extension(agent): AgentExt,
Json(req): Json<CreateSessionRequest>,
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| {
(
StatusCode::BAD_REQUEST,
"Invalid target_id format".to_string(),
)
})?;
// Look up the target
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?;
// Parse strategy
let strategy = match req.strategy.as_str() {
"quick" => PentestStrategy::Quick,
"targeted" => PentestStrategy::Targeted,
"aggressive" => PentestStrategy::Aggressive,
"stealth" => PentestStrategy::Stealth,
_ => PentestStrategy::Comprehensive,
};
// Create session
let mut session = PentestSession::new(req.target_id.clone(), strategy);
session.repo_id = target.repo_id.clone();
agent
.db
.pentest_sessions()
.insert_one(&session)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create session: {e}"),
)
})?;
let initial_message = req.message.unwrap_or_else(|| {
format!(
"Begin a {} penetration test against {} ({}). \
Identify vulnerabilities and provide evidence for each finding.",
session.strategy, target.name, target.base_url,
)
});
// Spawn the orchestrator on a background task
let llm = agent.llm.clone();
let db = agent.db.clone();
let session_clone = session.clone();
let target_clone = target.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator
.run_session(&session_clone, &target_clone, &initial_message)
.await
{
tracing::error!(
"Pentest orchestrator failed for session {}: {e}",
session_clone
.id
.map(|oid| oid.to_hex())
.unwrap_or_default()
);
}
});
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions — List pentest sessions
#[tracing::instrument(skip_all)]
pub async fn list_sessions(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.pentest_sessions()
.count_documents(doc! {})
.await
.unwrap_or(0);
let sessions = match db
.pentest_sessions()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest sessions: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: sessions,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/sessions/:id — Get a single pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(ApiResponse {
data: session,
total: None,
page: None,
}))
}
/// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn send_message(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<SendMessageRequest>,
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
// Verify session exists and is running
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running && session.status != PentestStatus::Paused {
return Err((
StatusCode::BAD_REQUEST,
format!("Session is {}, cannot send messages", session.status),
));
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
.dast_targets()
.find_one(doc! { "_id": target_oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Target for session not found".to_string(),
)
})?;
// Store user message
let session_id = id.clone();
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
let response_msg = user_msg.clone();
// Spawn orchestrator to continue the session
let llm = agent.llm.clone();
let db = agent.db.clone();
let message = req.message.clone();
tokio::spawn(async move {
let orchestrator = PentestOrchestrator::new(llm, db);
if let Err(e) = orchestrator.run_session(&session, &target, &message).await {
tracing::error!("Pentest orchestrator failed for session {session_id}: {e}");
}
});
Ok(Json(ApiResponse {
data: response_msg,
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch attack chain nodes: {e}");
Vec::new()
}
};
let total = nodes.len() as u64;
Ok(Json(ApiResponse {
data: nodes,
total: Some(total),
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_messages(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.pentest_messages()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let messages = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest messages: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: messages,
total: Some(total),
page: Some(params.page),
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let pentest_filter = doc! { "session_id": { "$exists": true, "$ne": null } };
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
let _ = pentest_filter; // used above inline
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
.db
.dast_findings()
.count_documents(doc! { "session_id": &id })
.await
.unwrap_or(0);
let findings = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch pentest session findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(params.page),
}))
}

View File

@@ -99,6 +99,36 @@ pub fn build_router() -> Router {
"/api/v1/chat/{repo_id}/status",
get(handlers::chat::embedding_status),
)
// Pentest API endpoints
.route(
"/api/v1/pentest/sessions",
get(handlers::pentest::list_sessions).post(handlers::pentest::create_session),
)
.route(
"/api/v1/pentest/sessions/{id}",
get(handlers::pentest::get_session),
)
.route(
"/api/v1/pentest/sessions/{id}/chat",
post(handlers::pentest::send_message),
)
.route(
"/api/v1/pentest/sessions/{id}/stream",
get(handlers::pentest::session_stream),
)
.route(
"/api/v1/pentest/sessions/{id}/attack-chain",
get(handlers::pentest::get_attack_chain),
)
.route(
"/api/v1/pentest/sessions/{id}/messages",
get(handlers::pentest::get_messages),
)
.route(
"/api/v1/pentest/sessions/{id}/findings",
get(handlers::pentest::get_session_findings),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",

View File

@@ -166,6 +166,38 @@ impl Database {
)
.await?;
// pentest_sessions: compound (target_id, started_at DESC)
self.pentest_sessions()
.create_index(
IndexModel::builder()
.keys(doc! { "target_id": 1, "started_at": -1 })
.build(),
)
.await?;
// pentest_sessions: status index
self.pentest_sessions()
.create_index(IndexModel::builder().keys(doc! { "status": 1 }).build())
.await?;
// attack_chain_nodes: compound (session_id, node_id)
self.attack_chain_nodes()
.create_index(
IndexModel::builder()
.keys(doc! { "session_id": 1, "node_id": 1 })
.build(),
)
.await?;
// pentest_messages: compound (session_id, created_at)
self.pentest_messages()
.create_index(
IndexModel::builder()
.keys(doc! { "session_id": 1, "created_at": 1 })
.build(),
)
.await?;
tracing::info!("Database indexes ensured");
Ok(())
}
@@ -235,6 +267,19 @@ impl Database {
self.inner.collection("embedding_builds")
}
// Pentest collections
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
self.inner.collection("pentest_sessions")
}
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
self.inner.collection("attack_chain_nodes")
}
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
self.inner.collection("pentest_messages")
}
#[allow(dead_code)]
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
self.inner.collection(name)

View File

@@ -12,10 +12,16 @@ pub struct LlmClient {
http: reqwest::Client,
}
#[derive(Serialize)]
struct ChatMessage {
role: String,
content: String,
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
@@ -26,8 +32,25 @@ struct ChatCompletionRequest {
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
struct ToolDefinitionPayload {
r#type: String,
function: ToolFunctionPayload,
}
#[derive(Serialize)]
struct ToolFunctionPayload {
name: String,
description: String,
parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct ChatCompletionResponse {
choices: Vec<ChatChoice>,
@@ -40,29 +63,84 @@ struct ChatChoice {
#[derive(Deserialize)]
struct ChatResponseMessage {
content: String,
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ToolCallResponse>>,
}
/// Request body for the embeddings API
#[derive(Deserialize)]
struct ToolCallResponse {
id: String,
function: ToolCallFunction,
}
#[derive(Deserialize)]
struct ToolCallFunction {
name: String,
arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
ToolCalls(Vec<LlmToolCall>),
}
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
/// Response from the embeddings API
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
/// A single embedding result
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Implementation ─────────────────────────────────────────────
impl LlmClient {
pub fn new(
base_url: String,
@@ -83,98 +161,142 @@ impl LlmClient {
&self.embed_model
}
fn chat_url(&self) -> String {
format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
)
}
fn auth_header(&self) -> Option<String> {
let key = self.api_key.expose_secret();
if key.is_empty() {
None
} else {
Some(format!("Bearer {key}"))
}
}
/// Simple chat: system + user prompt → text response
pub async fn chat(
&self,
system_prompt: &str,
user_prompt: &str,
temperature: Option<f64>,
) -> Result<String, AgentError> {
let url = format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
);
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: Some(system_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some(user_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: vec![
ChatMessage {
role: "system".to_string(),
content: system_prompt.to_string(),
},
ChatMessage {
role: "user".to_string(),
content: user_prompt.to_string(),
},
],
messages,
temperature,
max_tokens: Some(4096),
tools: None,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"LiteLLM returned {status}: {body}"
)));
}
let body: ChatCompletionResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
body.choices
.first()
.map(|c| c.message.content.clone())
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls(_) => String::new(), // shouldn't happen without tools
}
})
}
/// Chat with a list of (role, content) messages → text response
#[allow(dead_code)]
pub async fn chat_with_messages(
&self,
messages: Vec<(String, String)>,
temperature: Option<f64>,
) -> Result<String, AgentError> {
let url = format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
);
let messages = messages
.into_iter()
.map(|(role, content)| ChatMessage {
role,
content: Some(content),
tool_calls: None,
tool_call_id: None,
})
.collect();
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages: messages
.into_iter()
.map(|(role, content)| ChatMessage { role, content })
.collect(),
messages,
temperature,
max_tokens: Some(4096),
tools: None,
};
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls(_) => String::new(),
}
})
}
/// Chat with tool definitions — returns either content or tool calls.
/// Use this for the AI pentest orchestrator loop.
pub async fn chat_with_tools(
&self,
messages: Vec<ChatMessage>,
tools: &[ToolDefinition],
temperature: Option<f64>,
max_tokens: Option<u32>,
) -> Result<LlmResponse, AgentError> {
let tool_payloads: Vec<ToolDefinitionPayload> = tools
.iter()
.map(|t| ToolDefinitionPayload {
r#type: "function".to_string(),
function: ToolFunctionPayload {
name: t.name.clone(),
description: t.description.clone(),
parameters: t.parameters.clone(),
},
})
.collect();
let request_body = ChatCompletionRequest {
model: self.model.clone(),
messages,
temperature,
max_tokens: Some(max_tokens.unwrap_or(8192)),
tools: if tool_payloads.is_empty() {
None
} else {
Some(tool_payloads)
},
};
self.send_chat_request(&request_body).await
}
/// Internal method to send a chat completion request and parse the response
async fn send_chat_request(
&self,
request_body: &ChatCompletionRequest,
) -> Result<LlmResponse, AgentError> {
let mut req = self
.http
.post(&url)
.post(&self.chat_url())
.header("content-type", "application/json")
.json(&request_body);
.json(request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
@@ -195,10 +317,37 @@ impl LlmClient {
.await
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
body.choices
let choice = body
.choices
.first()
.map(|c| c.message.content.clone())
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?;
// Check for tool calls first
if let Some(tool_calls) = &choice.message.tool_calls {
if !tool_calls.is_empty() {
let calls: Vec<LlmToolCall> = tool_calls
.iter()
.map(|tc| {
let arguments = serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
LlmToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments,
}
})
.collect();
return Ok(LlmResponse::ToolCalls(calls));
}
}
// Otherwise return content
let content = choice
.message
.content
.clone()
.unwrap_or_default();
Ok(LlmResponse::Content(content))
}
/// Generate embeddings for a batch of texts
@@ -216,9 +365,8 @@ impl LlmClient {
.header("content-type", "application/json")
.json(&request_body);
let key = self.api_key.expose_secret();
if !key.is_empty() {
req = req.header("Authorization", format!("Bearer {key}"));
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
@@ -239,7 +387,6 @@ impl LlmClient {
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
// Sort by index to maintain input order
let mut data = body.data;
data.sort_by_key(|d| d.index);

View File

@@ -4,6 +4,7 @@ mod config;
mod database;
mod error;
mod llm;
mod pentest;
mod pipeline;
mod rag;
mod scheduler;

View File

@@ -0,0 +1,3 @@
pub mod orchestrator;
pub use orchestrator::PentestOrchestrator;

View File

@@ -0,0 +1,393 @@
use std::sync::Arc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::pentest::*;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
db: Database,
event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
pub fn new(llm: Arc<LlmClient>, db: Database) -> Self {
let (event_tx, _) = broadcast::channel(256);
Self {
tool_registry: ToolRegistry::new(),
llm,
db,
event_tx,
}
}
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
self.event_tx.subscribe()
}
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
self.event_tx.clone()
}
pub async fn run_session(
&self,
session: &PentestSession,
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
// Build system prompt
let system_prompt = self.build_system_prompt(session, target);
// Build tool definitions for LLM
let tool_defs: Vec<ToolDefinition> = self
.tool_registry
.all_definitions()
.into_iter()
.map(|td| ToolDefinition {
name: td.name,
description: td.description,
parameters: td.input_schema,
})
.collect();
// Initialize messages
let mut messages = vec![
ChatMessage {
role: "system".to_string(),
content: Some(system_prompt),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".to_string(),
content: Some(initial_message.to_string()),
tool_calls: None,
tool_call_id: None,
},
];
// Store user message
let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string());
let _ = self.db.pentest_messages().insert_one(&user_msg).await;
// Build tool context
let tool_context = PentestToolContext {
target: target.clone(),
session_id: session_id.clone(),
sast_findings: Vec::new(),
sbom_entries: Vec::new(),
code_context: Vec::new(),
rate_limit: target.rate_limit,
allow_destructive: target.allow_destructive,
};
let max_iterations = 50;
let mut total_findings = 0u32;
let mut total_tool_calls = 0u32;
let mut total_successes = 0u32;
for _iteration in 0..max_iterations {
// Call LLM with tools
let response = self
.llm
.chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192))
.await?;
match response {
LlmResponse::Content(content) => {
// Store assistant message
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
// Emit message event
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
});
// Add to messages
messages.push(ChatMessage {
role: "assistant".to_string(),
content: Some(content.clone()),
tool_calls: None,
tool_call_id: None,
});
// Check if the LLM considers itself done
let done_indicators = [
"pentest complete",
"testing complete",
"scan complete",
"analysis complete",
"finished",
"that concludes",
];
let content_lower = content.to_lowercase();
if done_indicators
.iter()
.any(|ind| content_lower.contains(ind))
{
break;
}
// If not done, break and wait for user input
break;
}
LlmResponse::ToolCalls(tool_calls) => {
// Build the assistant message with tool_calls
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
id: tc.id.clone(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: Some(tc_requests),
tool_call_id: None,
});
// Execute each tool call
for tc in &tool_calls {
total_tool_calls += 1;
let node_id = uuid::Uuid::new_v4().to_string();
// Create attack chain node
let mut node = AttackChainNode::new(
session_id.clone(),
node_id.clone(),
tc.name.clone(),
tc.arguments.clone(),
String::new(),
);
node.status = AttackNodeStatus::Running;
node.started_at = Some(chrono::Utc::now());
let _ = self.db.attack_chain_nodes().insert_one(&node).await;
// Emit tool start event
let _ = self.event_tx.send(PentestEvent::ToolStart {
node_id: node_id.clone(),
tool_name: tc.name.clone(),
input: tc.arguments.clone(),
});
// Execute the tool
let result = if let Some(tool) = self.tool_registry.get(&tc.name) {
match tool.execute(tc.arguments.clone(), &tool_context).await {
Ok(result) => {
total_successes += 1;
let findings_count = result.findings.len() as u32;
total_findings += findings_count;
// Store findings
for mut finding in result.findings {
finding.scan_run_id = session_id.clone();
finding.session_id = Some(session_id.clone());
let _ =
self.db.dast_findings().insert_one(&finding).await;
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Emit tool complete event
let _ = self.event_tx.send(PentestEvent::ToolComplete {
node_id: node_id.clone(),
summary: result.summary.clone(),
findings_count,
});
// Update attack chain node
let _ = self
.db
.attack_chain_nodes()
.update_one(
mongodb::bson::doc! {
"session_id": &session_id,
"node_id": &node_id,
},
mongodb::bson::doc! { "$set": {
"status": "completed",
"tool_output": mongodb::bson::to_bson(&result.data)
.unwrap_or(mongodb::bson::Bson::Null),
"completed_at": mongodb::bson::DateTime::now(),
}},
)
.await;
serde_json::json!({
"summary": result.summary,
"findings_count": findings_count,
"data": result.data,
})
.to_string()
}
Err(e) => {
// Update node as failed
let _ = self
.db
.attack_chain_nodes()
.update_one(
mongodb::bson::doc! {
"session_id": &session_id,
"node_id": &node_id,
},
mongodb::bson::doc! { "$set": {
"status": "failed",
"completed_at": mongodb::bson::DateTime::now(),
}},
)
.await;
format!("Tool execution failed: {e}")
}
}
} else {
format!("Unknown tool: {}", tc.name)
};
// Add tool result to messages
messages.push(ChatMessage {
role: "tool".to_string(),
content: Some(result),
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
// Update session stats
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
mongodb::bson::doc! { "_id": sid },
mongodb::bson::doc! { "$set": {
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
}
}
}
// Mark session as completed
if let Some(sid) = session.id {
let _ = self
.db
.pentest_sessions()
.update_one(
mongodb::bson::doc! { "_id": sid },
mongodb::bson::doc! { "$set": {
"status": "completed",
"completed_at": mongodb::bson::DateTime::now(),
"tool_invocations": total_tool_calls as i64,
"tool_successes": total_successes as i64,
"findings_count": total_findings as i64,
}},
)
.await;
}
let _ = self.event_tx.send(PentestEvent::Complete {
summary: format!(
"Pentest complete. {} findings from {} tool invocations.",
total_findings, total_tool_calls
),
});
Ok(())
}
fn build_system_prompt(&self, session: &PentestSession, target: &DastTarget) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let strategy_guidance = match session.strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
};
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
## Strategy
{strategy_guidance}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance and crawling to understand the target.
2. Based on what you discover, select appropriate vulnerability scanning tools.
3. For each tool invocation, provide the discovered endpoints and parameters.
4. Analyze tool results and chain findings — if you find one vulnerability, explore whether it enables others.
5. When testing is complete, provide a summary of all findings with severity and remediation recommendations.
6. Always explain your reasoning before invoking each tool.
7. Focus on actionable findings with evidence. Avoid false positives.
8. When you have completed all relevant testing, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
)
}
}

View File

@@ -176,6 +176,16 @@ pub enum DastVulnType {
InformationDisclosure,
SecurityMisconfiguration,
BrokenAuth,
DnsMisconfiguration,
EmailSecurity,
TlsMisconfiguration,
CookieSecurity,
CspIssue,
CorsMisconfiguration,
RateLimitAbsent,
ConsoleLogLeakage,
SecurityHeaderMissing,
KnownCveExploit,
Other,
}
@@ -192,6 +202,16 @@ impl std::fmt::Display for DastVulnType {
Self::InformationDisclosure => write!(f, "information_disclosure"),
Self::SecurityMisconfiguration => write!(f, "security_misconfiguration"),
Self::BrokenAuth => write!(f, "broken_auth"),
Self::DnsMisconfiguration => write!(f, "dns_misconfiguration"),
Self::EmailSecurity => write!(f, "email_security"),
Self::TlsMisconfiguration => write!(f, "tls_misconfiguration"),
Self::CookieSecurity => write!(f, "cookie_security"),
Self::CspIssue => write!(f, "csp_issue"),
Self::CorsMisconfiguration => write!(f, "cors_misconfiguration"),
Self::RateLimitAbsent => write!(f, "rate_limit_absent"),
Self::ConsoleLogLeakage => write!(f, "console_log_leakage"),
Self::SecurityHeaderMissing => write!(f, "security_header_missing"),
Self::KnownCveExploit => write!(f, "known_cve_exploit"),
Self::Other => write!(f, "other"),
}
}
@@ -244,6 +264,8 @@ pub struct DastFinding {
pub remediation: Option<String>,
/// Linked SAST finding ID (if correlated)
pub linked_sast_finding_id: Option<String>,
/// Pentest session that produced this finding (if AI-driven)
pub session_id: Option<String>,
#[serde(with = "super::serde_helpers::bson_datetime")]
pub created_at: DateTime<Utc>,
}
@@ -276,6 +298,7 @@ impl DastFinding {
evidence: Vec::new(),
remediation: None,
linked_sast_finding_id: None,
session_id: None,
created_at: Utc::now(),
}
}

View File

@@ -7,6 +7,7 @@ pub mod finding;
pub mod graph;
pub mod issue;
pub mod mcp;
pub mod pentest;
pub mod repository;
pub mod sbom;
pub mod scan;
@@ -26,6 +27,11 @@ pub use graph::{
};
pub use issue::{IssueStatus, TrackerIssue, TrackerType};
pub use mcp::{McpServerConfig, McpServerStatus, McpTransport};
pub use pentest::{
AttackChainNode, AttackNodeStatus, CodeContextHint, PentestEvent, PentestMessage,
PentestSession, PentestStats, PentestStatus, PentestStrategy, SeverityDistribution,
ToolCallRecord,
};
pub use repository::{ScanTrigger, TrackedRepository};
pub use sbom::{SbomEntry, VulnRef};
pub use scan::{ScanPhase, ScanRun, ScanRunStatus, ScanType};

View File

@@ -0,0 +1,294 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Status of a pentest session
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PentestStatus {
Running,
Paused,
Completed,
Failed,
}
impl std::fmt::Display for PentestStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Running => write!(f, "running"),
Self::Paused => write!(f, "paused"),
Self::Completed => write!(f, "completed"),
Self::Failed => write!(f, "failed"),
}
}
}
/// Strategy for the AI pentest orchestrator
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PentestStrategy {
/// Quick scan focusing on common vulnerabilities
Quick,
/// Standard comprehensive scan
Comprehensive,
/// Focus on specific vulnerability types guided by SAST/SBOM
Targeted,
/// Aggressive testing with more payloads and deeper exploitation
Aggressive,
/// Stealth mode with slower rate and fewer noisy payloads
Stealth,
}
impl std::fmt::Display for PentestStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Quick => write!(f, "quick"),
Self::Comprehensive => write!(f, "comprehensive"),
Self::Targeted => write!(f, "targeted"),
Self::Aggressive => write!(f, "aggressive"),
Self::Stealth => write!(f, "stealth"),
}
}
}
/// A pentest session initiated via the chat interface
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PentestSession {
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
pub id: Option<bson::oid::ObjectId>,
pub target_id: String,
/// Linked repository for code-aware testing
pub repo_id: Option<String>,
pub status: PentestStatus,
pub strategy: PentestStrategy,
pub created_by: Option<String>,
/// Total number of tool invocations in this session
pub tool_invocations: u32,
/// Total successful tool invocations
pub tool_successes: u32,
/// Number of findings discovered
pub findings_count: u32,
/// Number of confirmed exploitable findings
pub exploitable_count: u32,
#[serde(with = "super::serde_helpers::bson_datetime")]
pub started_at: DateTime<Utc>,
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
pub completed_at: Option<DateTime<Utc>>,
}
impl PentestSession {
pub fn new(target_id: String, strategy: PentestStrategy) -> Self {
Self {
id: None,
target_id,
repo_id: None,
status: PentestStatus::Running,
strategy,
created_by: None,
tool_invocations: 0,
tool_successes: 0,
findings_count: 0,
exploitable_count: 0,
started_at: Utc::now(),
completed_at: None,
}
}
pub fn success_rate(&self) -> f64 {
if self.tool_invocations == 0 {
return 100.0;
}
(self.tool_successes as f64 / self.tool_invocations as f64) * 100.0
}
}
/// Status of a node in the attack chain
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AttackNodeStatus {
Pending,
Running,
Completed,
Failed,
Skipped,
}
/// A single step in the LLM-driven attack chain DAG
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttackChainNode {
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
pub id: Option<bson::oid::ObjectId>,
pub session_id: String,
/// Unique ID for DAG references
pub node_id: String,
/// Parent node IDs (multiple for merge nodes)
pub parent_node_ids: Vec<String>,
/// Tool that was invoked
pub tool_name: String,
/// Input parameters passed to the tool
pub tool_input: serde_json::Value,
/// Output from the tool
pub tool_output: Option<serde_json::Value>,
pub status: AttackNodeStatus,
/// LLM's reasoning for choosing this action
pub llm_reasoning: String,
/// IDs of DastFindings produced by this step
pub findings_produced: Vec<String>,
/// Risk score (0-100) assigned by the LLM
pub risk_score: Option<u8>,
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
pub started_at: Option<DateTime<Utc>>,
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
pub completed_at: Option<DateTime<Utc>>,
}
impl AttackChainNode {
pub fn new(
session_id: String,
node_id: String,
tool_name: String,
tool_input: serde_json::Value,
llm_reasoning: String,
) -> Self {
Self {
id: None,
session_id,
node_id,
parent_node_ids: Vec::new(),
tool_name,
tool_input,
tool_output: None,
status: AttackNodeStatus::Pending,
llm_reasoning,
findings_produced: Vec::new(),
risk_score: None,
started_at: None,
completed_at: None,
}
}
}
/// Chat message within a pentest session
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PentestMessage {
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
pub id: Option<bson::oid::ObjectId>,
pub session_id: String,
/// "user", "assistant", "tool_result", "system"
pub role: String,
pub content: String,
/// Tool calls made by the assistant in this message
pub tool_calls: Option<Vec<ToolCallRecord>>,
/// Link to the attack chain node (for tool_result messages)
pub attack_node_id: Option<String>,
#[serde(with = "super::serde_helpers::bson_datetime")]
pub created_at: DateTime<Utc>,
}
impl PentestMessage {
pub fn user(session_id: String, content: String) -> Self {
Self {
id: None,
session_id,
role: "user".to_string(),
content,
tool_calls: None,
attack_node_id: None,
created_at: Utc::now(),
}
}
pub fn assistant(session_id: String, content: String) -> Self {
Self {
id: None,
session_id,
role: "assistant".to_string(),
content,
tool_calls: None,
attack_node_id: None,
created_at: Utc::now(),
}
}
pub fn tool_result(session_id: String, content: String, node_id: String) -> Self {
Self {
id: None,
session_id,
role: "tool_result".to_string(),
content,
tool_calls: None,
attack_node_id: Some(node_id),
created_at: Utc::now(),
}
}
}
/// Record of a tool call made by the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
pub call_id: String,
pub tool_name: String,
pub arguments: serde_json::Value,
pub result: Option<serde_json::Value>,
}
/// SSE event types for real-time pentest streaming
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PentestEvent {
/// LLM is thinking/reasoning
Thinking { reasoning: String },
/// A tool execution has started
ToolStart {
node_id: String,
tool_name: String,
input: serde_json::Value,
},
/// A tool execution completed
ToolComplete {
node_id: String,
summary: String,
findings_count: u32,
},
/// A new finding was discovered
Finding { finding_id: String, title: String, severity: String },
/// Assistant message (streaming text)
Message { content: String },
/// Session completed
Complete { summary: String },
/// Error occurred
Error { message: String },
}
/// Aggregated stats for the pentest dashboard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PentestStats {
pub running_sessions: u32,
pub total_vulnerabilities: u32,
pub total_tool_invocations: u32,
pub tool_success_rate: f64,
pub severity_distribution: SeverityDistribution,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeverityDistribution {
pub critical: u32,
pub high: u32,
pub medium: u32,
pub low: u32,
pub info: u32,
}
/// Code context hint linking a discovered endpoint to source code
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeContextHint {
/// HTTP route pattern (e.g., "GET /api/users/:id")
pub endpoint_pattern: String,
/// Handler function name
pub handler_function: String,
/// Source file path
pub file_path: String,
/// Relevant code snippet
pub code_snippet: String,
/// SAST findings associated with this code
pub known_vulnerabilities: Vec<String>,
}

View File

@@ -1,9 +1,11 @@
pub mod dast_agent;
pub mod graph_builder;
pub mod issue_tracker;
pub mod pentest_tool;
pub mod scanner;
pub use dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
pub use graph_builder::{LanguageParser, ParseOutput};
pub use issue_tracker::IssueTracker;
pub use pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
pub use scanner::{ScanOutput, Scanner};

View File

@@ -0,0 +1,63 @@
use std::future::Future;
use std::pin::Pin;
use crate::error::CoreError;
use crate::models::dast::{DastFinding, DastTarget};
use crate::models::finding::Finding;
use crate::models::pentest::CodeContextHint;
use crate::models::sbom::SbomEntry;
/// Context passed to pentest tools during execution.
///
/// The HTTP client is not included here because `compliance-core` does not
/// depend on `reqwest`. Tools that need HTTP should hold their own client
/// or receive one via the `compliance-dast` orchestrator.
pub struct PentestToolContext {
/// The DAST target being tested
pub target: DastTarget,
/// Session ID for this pentest run
pub session_id: String,
/// SAST findings for the linked repo (if any)
pub sast_findings: Vec<Finding>,
/// SBOM entries with known CVEs (if any)
pub sbom_entries: Vec<SbomEntry>,
/// Code knowledge graph hints mapping endpoints to source code
pub code_context: Vec<CodeContextHint>,
/// Rate limit (requests per second)
pub rate_limit: u32,
/// Whether destructive operations are allowed
pub allow_destructive: bool,
}
/// Result from a pentest tool execution
pub struct PentestToolResult {
/// Human-readable summary of what the tool found
pub summary: String,
/// DAST findings produced by this tool
pub findings: Vec<DastFinding>,
/// Tool-specific structured output data
pub data: serde_json::Value,
}
/// A tool that the LLM pentest orchestrator can invoke.
///
/// Each tool represents a specific security testing capability
/// (e.g., SQL injection scanner, DNS checker, TLS analyzer).
/// Uses boxed futures for dyn-compatibility.
pub trait PentestTool: Send + Sync {
/// Tool name for LLM tool_use (e.g., "sql_injection_scanner")
fn name(&self) -> &str;
/// Human-readable description for the LLM system prompt
fn description(&self) -> &str;
/// JSON Schema for the tool's input parameters
fn input_schema(&self) -> serde_json::Value;
/// Execute the tool with the given input
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> Pin<Box<dyn Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>>;
}

View File

@@ -38,6 +38,10 @@ pub enum Route {
DastFindingsPage {},
#[route("/dast/findings/:id")]
DastFindingDetailPage { id: String },
#[route("/pentest")]
PentestDashboardPage {},
#[route("/pentest/:session_id")]
PentestSessionPage { session_id: String },
#[route("/mcp-servers")]
McpServersPage {},
#[route("/settings")]

View File

@@ -47,6 +47,11 @@ pub fn Sidebar() -> Element {
route: Route::DastOverviewPage {},
icon: rsx! { Icon { icon: BsBug, width: 18, height: 18 } },
},
NavItem {
label: "Pentest",
route: Route::PentestDashboardPage {},
icon: rsx! { Icon { icon: BsLightningCharge, width: 18, height: 18 } },
},
NavItem {
label: "Settings",
route: Route::SettingsPage {},
@@ -78,6 +83,7 @@ pub fn Sidebar() -> Element {
(Route::DastTargetsPage {}, Route::DastOverviewPage {}) => true,
(Route::DastFindingsPage {}, Route::DastOverviewPage {}) => true,
(Route::DastFindingDetailPage { .. }, Route::DastOverviewPage {}) => true,
(Route::PentestSessionPage { .. }, Route::PentestDashboardPage {}) => true,
(a, b) => a == b,
};
let class = if is_active { "nav-item active" } else { "nav-item" };

View File

@@ -7,6 +7,7 @@ pub mod findings;
pub mod graph;
pub mod issues;
pub mod mcp;
pub mod pentest;
#[allow(clippy::too_many_arguments)]
pub mod repositories;
pub mod sbom;

View File

@@ -0,0 +1,190 @@
use dioxus::prelude::*;
use serde::{Deserialize, Serialize};
use super::dast::DastFindingsResponse;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PentestSessionsResponse {
pub data: Vec<serde_json::Value>,
pub total: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PentestSessionResponse {
pub data: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PentestMessagesResponse {
pub data: Vec<serde_json::Value>,
pub total: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PentestStatsResponse {
pub data: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AttackChainResponse {
pub data: Vec<serde_json::Value>,
}
#[server]
pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestSessionsResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions/{id}", state.agent_api_url);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestSessionResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn fetch_pentest_messages(
session_id: String,
) -> Result<PentestMessagesResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/messages",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/stats", state.agent_api_url);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestStatsResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn fetch_attack_chain(
session_id: String,
) -> Result<AttackChainResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/attack-chain",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: AttackChainResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn create_pentest_session(
target_id: String,
strategy: String,
message: String,
) -> Result<PentestSessionResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({
"target_id": target_id,
"strategy": strategy,
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestSessionResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn send_pentest_message(
session_id: String,
message: String,
) -> Result<PentestMessagesResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/messages",
state.agent_api_url
);
let client = reqwest::Client::new();
let resp = client
.post(&url)
.json(&serde_json::json!({
"message": message,
}))
.send()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: PentestMessagesResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}
#[server]
pub async fn fetch_pentest_findings(
session_id: String,
) -> Result<DastFindingsResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(
"{}/api/v1/pentest/sessions/{session_id}/findings",
state.agent_api_url
);
let resp = reqwest::get(&url)
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
let body: DastFindingsResponse = resp
.json()
.await
.map_err(|e| ServerFnError::new(e.to_string()))?;
Ok(body)
}

View File

@@ -12,6 +12,8 @@ pub mod impact_analysis;
pub mod issues;
pub mod mcp_servers;
pub mod overview;
pub mod pentest_dashboard;
pub mod pentest_session;
pub mod repositories;
pub mod sbom;
pub mod settings;
@@ -30,6 +32,8 @@ pub use impact_analysis::ImpactAnalysisPage;
pub use issues::IssuesPage;
pub use mcp_servers::McpServersPage;
pub use overview::OverviewPage;
pub use pentest_dashboard::PentestDashboardPage;
pub use pentest_session::PentestSessionPage;
pub use repositories::RepositoriesPage;
pub use sbom::SbomPage;
pub use settings::SettingsPage;

View File

@@ -0,0 +1,396 @@
use dioxus::prelude::*;
use dioxus_free_icons::icons::bs_icons::*;
use dioxus_free_icons::Icon;
use crate::app::Route;
use crate::components::page_header::PageHeader;
use crate::infrastructure::dast::fetch_dast_targets;
use crate::infrastructure::pentest::{
create_pentest_session, fetch_pentest_sessions, fetch_pentest_stats,
};
#[component]
pub fn PentestDashboardPage() -> Element {
let mut sessions = use_resource(|| async { fetch_pentest_sessions().await.ok() });
let stats = use_resource(|| async { fetch_pentest_stats().await.ok() });
let targets = use_resource(|| async { fetch_dast_targets().await.ok() });
let mut show_modal = use_signal(|| false);
let mut new_target_id = use_signal(String::new);
let mut new_strategy = use_signal(|| "comprehensive".to_string());
let mut new_message = use_signal(String::new);
let mut creating = use_signal(|| false);
let on_create = move |_| {
let tid = new_target_id.read().clone();
let strat = new_strategy.read().clone();
let msg = new_message.read().clone();
if tid.is_empty() || msg.is_empty() {
return;
}
creating.set(true);
spawn(async move {
match create_pentest_session(tid, strat, msg).await {
Ok(resp) => {
let session_id = resp
.data
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
creating.set(false);
show_modal.set(false);
new_target_id.set(String::new());
new_message.set(String::new());
if !session_id.is_empty() {
navigator().push(Route::PentestSessionPage {
session_id: session_id.clone(),
});
} else {
sessions.restart();
}
}
Err(_) => {
creating.set(false);
}
}
});
};
// Extract stats values
let running_sessions = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("running_sessions")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let total_vulns = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("total_vulnerabilities")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let tool_invocations = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("tool_invocations")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let success_rate = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("success_rate")
.and_then(|v| v.as_f64())
.unwrap_or(0.0),
_ => 0.0,
}
};
// Severity counts from stats
let severity_critical = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("severity_critical")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let severity_high = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("severity_high")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let severity_medium = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("severity_medium")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let severity_low = {
let s = stats.read();
match &*s {
Some(Some(data)) => data
.data
.get("severity_low")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
rsx! {
PageHeader {
title: "Pentest Dashboard",
description: "AI-powered penetration testing sessions — autonomous security assessment",
}
// Stat cards
div { class: "stat-cards", style: "margin-bottom: 24px;",
div { class: "stat-card-item",
div { class: "stat-card-value", "{running_sessions}" }
div { class: "stat-card-label",
Icon { icon: BsPlayCircle, width: 14, height: 14 }
" Running Sessions"
}
}
div { class: "stat-card-item",
div { class: "stat-card-value", "{total_vulns}" }
div { class: "stat-card-label",
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
" Total Vulnerabilities"
}
}
div { class: "stat-card-item",
div { class: "stat-card-value", "{tool_invocations}" }
div { class: "stat-card-label",
Icon { icon: BsWrench, width: 14, height: 14 }
" Tool Invocations"
}
}
div { class: "stat-card-item",
div { class: "stat-card-value", "{success_rate:.0}%" }
div { class: "stat-card-label",
Icon { icon: BsCheckCircle, width: 14, height: 14 }
" Success Rate"
}
}
}
// Severity distribution
div { class: "card", style: "margin-bottom: 24px; padding: 16px;",
div { style: "display: flex; align-items: center; gap: 16px; flex-wrap: wrap;",
span { style: "font-weight: 600; color: var(--text-secondary); font-size: 0.85rem;", "Severity Distribution" }
span {
class: "badge",
style: "background: #dc2626; color: #fff;",
"Critical: {severity_critical}"
}
span {
class: "badge",
style: "background: #ea580c; color: #fff;",
"High: {severity_high}"
}
span {
class: "badge",
style: "background: #d97706; color: #fff;",
"Medium: {severity_medium}"
}
span {
class: "badge",
style: "background: #2563eb; color: #fff;",
"Low: {severity_low}"
}
}
}
// Actions row
div { style: "display: flex; gap: 12px; margin-bottom: 24px;",
button {
class: "btn btn-primary",
onclick: move |_| show_modal.set(true),
Icon { icon: BsPlusCircle, width: 14, height: 14 }
" New Pentest"
}
}
// Sessions list
div { class: "card",
div { class: "card-header", "Recent Pentest Sessions" }
match &*sessions.read() {
Some(Some(data)) => {
let sess_list = &data.data;
if sess_list.is_empty() {
rsx! {
div { style: "padding: 32px; text-align: center; color: var(--text-secondary);",
p { "No pentest sessions yet. Start one to begin autonomous security testing." }
}
}
} else {
rsx! {
div { style: "display: grid; gap: 12px; padding: 16px;",
for session in sess_list {
{
let id = session.get("id").and_then(|v| v.as_str()).unwrap_or("-").to_string();
let target_name = session.get("target_name").and_then(|v| v.as_str()).unwrap_or("Unknown Target").to_string();
let status = session.get("status").and_then(|v| v.as_str()).unwrap_or("unknown").to_string();
let strategy = session.get("strategy").and_then(|v| v.as_str()).unwrap_or("-").to_string();
let findings_count = session.get("findings_count").and_then(|v| v.as_u64()).unwrap_or(0);
let tool_count = session.get("tool_invocations").and_then(|v| v.as_u64()).unwrap_or(0);
let created_at = session.get("created_at").and_then(|v| v.as_str()).unwrap_or("-").to_string();
let status_style = match status.as_str() {
"running" => "background: #16a34a; color: #fff;",
"completed" => "background: #2563eb; color: #fff;",
"failed" => "background: #dc2626; color: #fff;",
"paused" => "background: #d97706; color: #fff;",
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
};
rsx! {
Link {
to: Route::PentestSessionPage { session_id: id.clone() },
class: "card",
style: "padding: 16px; text-decoration: none; cursor: pointer; transition: border-color 0.15s;",
div { style: "display: flex; justify-content: space-between; align-items: flex-start;",
div {
div { style: "font-weight: 600; font-size: 1rem; margin-bottom: 4px; color: var(--text-primary);",
"{target_name}"
}
div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;",
span {
class: "badge",
style: "{status_style}",
"{status}"
}
span {
class: "badge",
style: "background: var(--bg-tertiary); color: var(--text-secondary);",
"{strategy}"
}
}
}
div { style: "text-align: right; font-size: 0.85rem; color: var(--text-secondary);",
div { style: "margin-bottom: 4px;",
Icon { icon: BsShieldExclamation, width: 12, height: 12 }
" {findings_count} findings"
}
div { style: "margin-bottom: 4px;",
Icon { icon: BsWrench, width: 12, height: 12 }
" {tool_count} tools"
}
div { "{created_at}" }
}
}
}
}
}
}
}
}
}
},
Some(None) => rsx! { p { style: "padding: 16px;", "Failed to load sessions." } },
None => rsx! { p { style: "padding: 16px;", "Loading..." } },
}
}
// New Pentest Modal
if *show_modal.read() {
div {
style: "position: fixed; inset: 0; background: rgba(0,0,0,0.6); display: flex; align-items: center; justify-content: center; z-index: 1000;",
onclick: move |_| show_modal.set(false),
div {
style: "background: var(--bg-secondary); border: 1px solid var(--border-color); border-radius: 12px; padding: 24px; width: 480px; max-width: 90vw;",
onclick: move |e| e.stop_propagation(),
h3 { style: "margin: 0 0 16px 0;", "New Pentest Session" }
// Target selection
div { style: "margin-bottom: 12px;",
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
"Target"
}
select {
class: "chat-input",
style: "width: 100%; padding: 8px; resize: none; height: auto;",
value: "{new_target_id}",
onchange: move |e| new_target_id.set(e.value()),
option { value: "", "Select a target..." }
match &*targets.read() {
Some(Some(data)) => {
rsx! {
for target in &data.data {
{
let tid = target.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string();
let tname = target.get("name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
let turl = target.get("base_url").and_then(|v| v.as_str()).unwrap_or("").to_string();
rsx! {
option { value: "{tid}", "{tname} ({turl})" }
}
}
}
}
},
_ => rsx! {},
}
}
}
// Strategy selection
div { style: "margin-bottom: 12px;",
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
"Strategy"
}
select {
class: "chat-input",
style: "width: 100%; padding: 8px; resize: none; height: auto;",
value: "{new_strategy}",
onchange: move |e| new_strategy.set(e.value()),
option { value: "comprehensive", "Comprehensive" }
option { value: "quick", "Quick Scan" }
option { value: "owasp_top_10", "OWASP Top 10" }
option { value: "api_focused", "API Focused" }
option { value: "authentication", "Authentication" }
}
}
// Initial message
div { style: "margin-bottom: 16px;",
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
"Initial Instructions"
}
textarea {
class: "chat-input",
style: "width: 100%; min-height: 80px;",
placeholder: "Describe the scope and goals of this pentest...",
value: "{new_message}",
oninput: move |e| new_message.set(e.value()),
}
}
div { style: "display: flex; justify-content: flex-end; gap: 8px;",
button {
class: "btn btn-ghost",
onclick: move |_| show_modal.set(false),
"Cancel"
}
button {
class: "btn btn-primary",
disabled: *creating.read() || new_target_id.read().is_empty() || new_message.read().is_empty(),
onclick: on_create,
if *creating.read() { "Creating..." } else { "Start Pentest" }
}
}
}
}
}
}
}

View File

@@ -0,0 +1,445 @@
use dioxus::prelude::*;
use dioxus_free_icons::icons::bs_icons::*;
use dioxus_free_icons::Icon;
use crate::app::Route;
use crate::infrastructure::pentest::{
fetch_attack_chain, fetch_pentest_findings, fetch_pentest_messages, fetch_pentest_session,
send_pentest_message,
};
#[component]
pub fn PentestSessionPage(session_id: String) -> Element {
let sid = session_id.clone();
let sid_for_session = session_id.clone();
let sid_for_findings = session_id.clone();
let sid_for_chain = session_id.clone();
let mut session = use_resource(move || {
let id = sid_for_session.clone();
async move { fetch_pentest_session(id).await.ok() }
});
let mut messages_res = use_resource(move || {
let id = sid.clone();
async move { fetch_pentest_messages(id).await.ok() }
});
let mut findings = use_resource(move || {
let id = sid_for_findings.clone();
async move { fetch_pentest_findings(id).await.ok() }
});
let mut attack_chain = use_resource(move || {
let id = sid_for_chain.clone();
async move { fetch_attack_chain(id).await.ok() }
});
let mut input_text = use_signal(String::new);
let mut sending = use_signal(|| false);
let mut right_tab = use_signal(|| "findings".to_string());
// Auto-poll messages every 3s when session is running
let session_status = {
let s = session.read();
match &*s {
Some(Some(resp)) => resp
.data
.get("status")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string(),
_ => "unknown".to_string(),
}
};
let is_running = session_status == "running";
let sid_for_poll = session_id.clone();
use_effect(move || {
if is_running {
let _sid = sid_for_poll.clone();
spawn(async move {
#[cfg(feature = "web")]
gloo_timers::future::TimeoutFuture::new(3_000).await;
#[cfg(not(feature = "web"))]
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
messages_res.restart();
findings.restart();
attack_chain.restart();
session.restart();
});
}
});
// Send message handler
let sid_for_send = session_id.clone();
let mut do_send = move || {
let text = input_text.read().trim().to_string();
if text.is_empty() || *sending.read() {
return;
}
let sid = sid_for_send.clone();
input_text.set(String::new());
sending.set(true);
spawn(async move {
let _ = send_pentest_message(sid, text).await;
sending.set(false);
messages_res.restart();
});
};
let mut do_send_click = do_send.clone();
// Session header info
let target_name = {
let s = session.read();
match &*s {
Some(Some(resp)) => resp
.data
.get("target_name")
.and_then(|v| v.as_str())
.unwrap_or("Pentest Session")
.to_string(),
_ => "Pentest Session".to_string(),
}
};
let strategy = {
let s = session.read();
match &*s {
Some(Some(resp)) => resp
.data
.get("strategy")
.and_then(|v| v.as_str())
.unwrap_or("-")
.to_string(),
_ => "-".to_string(),
}
};
let header_tool_count = {
let s = session.read();
match &*s {
Some(Some(resp)) => resp
.data
.get("tool_invocations")
.and_then(|v| v.as_u64())
.unwrap_or(0),
_ => 0,
}
};
let header_findings_count = {
let f = findings.read();
match &*f {
Some(Some(data)) => data.total.unwrap_or(0),
_ => 0,
}
};
let status_style = match session_status.as_str() {
"running" => "background: #16a34a; color: #fff;",
"completed" => "background: #2563eb; color: #fff;",
"failed" => "background: #dc2626; color: #fff;",
"paused" => "background: #d97706; color: #fff;",
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
};
rsx! {
div { class: "back-nav",
Link {
to: Route::PentestDashboardPage {},
class: "btn btn-ghost btn-back",
Icon { icon: BsArrowLeft, width: 16, height: 16 }
"Back to Pentest Dashboard"
}
}
// Session header
div { style: "display: flex; align-items: center; justify-content: space-between; margin-bottom: 16px; flex-wrap: wrap; gap: 8px;",
div {
h2 { style: "margin: 0 0 4px 0;", "{target_name}" }
div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;",
span { class: "badge", style: "{status_style}", "{session_status}" }
span { class: "badge", style: "background: var(--bg-tertiary); color: var(--text-secondary);",
"{strategy}"
}
}
}
div { style: "display: flex; gap: 16px; font-size: 0.85rem; color: var(--text-secondary);",
span {
Icon { icon: BsWrench, width: 14, height: 14 }
" {header_tool_count} tools"
}
span {
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
" {header_findings_count} findings"
}
}
}
// Split layout: chat left, findings/chain right
div { style: "display: grid; grid-template-columns: 1fr 380px; gap: 16px; height: calc(100vh - 220px); min-height: 400px;",
// Left: Chat area
div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;",
div { class: "card-header", style: "flex-shrink: 0;", "Chat" }
// Messages
div {
style: "flex: 1; overflow-y: auto; padding: 16px; display: flex; flex-direction: column; gap: 12px;",
match &*messages_res.read() {
Some(Some(data)) => {
let msgs = &data.data;
if msgs.is_empty() {
rsx! {
div { style: "text-align: center; color: var(--text-secondary); padding: 32px;",
h3 { style: "margin-bottom: 8px;", "Start the conversation" }
p { "Send a message to guide the pentest agent." }
}
}
} else {
rsx! {
for (i, msg) in msgs.iter().enumerate() {
{
let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("assistant").to_string();
let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
let msg_type = msg.get("type").and_then(|v| v.as_str()).unwrap_or("text").to_string();
let tool_name = msg.get("tool_name").and_then(|v| v.as_str()).unwrap_or("").to_string();
let tool_status = msg.get("tool_status").and_then(|v| v.as_str()).unwrap_or("").to_string();
if msg_type == "tool_call" || msg_type == "tool_result" {
// Tool invocation indicator
let tool_icon_style = match tool_status.as_str() {
"success" => "color: #16a34a;",
"error" => "color: #dc2626;",
"running" => "color: #d97706;",
_ => "color: var(--text-secondary);",
};
rsx! {
div {
key: "{i}",
style: "display: flex; align-items: center; gap: 8px; padding: 6px 12px; background: var(--bg-tertiary); border-radius: 6px; font-size: 0.8rem; color: var(--text-secondary);",
span { style: "{tool_icon_style}",
Icon { icon: BsWrench, width: 12, height: 12 }
}
span { style: "font-family: monospace;", "{tool_name}" }
if !tool_status.is_empty() {
span { class: "badge", style: "font-size: 0.7rem;", "{tool_status}" }
}
if !content.is_empty() {
details { style: "margin-left: auto; cursor: pointer;",
summary { style: "font-size: 0.75rem;", "details" }
pre { style: "margin-top: 4px; padding: 8px; background: var(--bg-primary); border-radius: 4px; font-size: 0.75rem; overflow-x: auto; max-height: 200px; white-space: pre-wrap;",
"{content}"
}
}
}
}
}
} else if role == "user" {
// User message - right aligned
rsx! {
div {
key: "{i}",
style: "display: flex; justify-content: flex-end;",
div {
style: "max-width: 80%; padding: 10px 14px; background: #2563eb; color: #fff; border-radius: 12px 12px 2px 12px; font-size: 0.9rem; line-height: 1.5; white-space: pre-wrap;",
"{content}"
}
}
}
} else {
// Assistant message - left aligned
rsx! {
div {
key: "{i}",
style: "display: flex; gap: 8px; align-items: flex-start;",
div {
style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;",
Icon { icon: BsCpu, width: 14, height: 14 }
}
div {
style: "max-width: 80%; padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; line-height: 1.5; white-space: pre-wrap;",
"{content}"
}
}
}
}
}
}
}
}
},
Some(None) => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Failed to load messages." } },
None => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Loading messages..." } },
}
if *sending.read() {
div { style: "display: flex; gap: 8px; align-items: flex-start;",
div {
style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;",
Icon { icon: BsCpu, width: 14, height: 14 }
}
div {
style: "padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; color: var(--text-secondary);",
"Thinking..."
}
}
}
}
// Input area
div { style: "flex-shrink: 0; padding: 12px; border-top: 1px solid var(--border-color); display: flex; gap: 8px;",
textarea {
class: "chat-input",
style: "flex: 1;",
placeholder: "Guide the pentest agent...",
value: "{input_text}",
oninput: move |e| input_text.set(e.value()),
onkeydown: move |e: Event<KeyboardData>| {
if e.key() == Key::Enter && !e.modifiers().shift() {
e.prevent_default();
do_send();
}
},
}
button {
class: "btn btn-primary",
style: "align-self: flex-end;",
disabled: *sending.read(),
onclick: move |_| do_send_click(),
"Send"
}
}
}
// Right: Findings / Attack Chain tabs
div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;",
// Tab bar
div { style: "display: flex; border-bottom: 1px solid var(--border-color); flex-shrink: 0;",
button {
style: if *right_tab.read() == "findings" {
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;"
} else {
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;"
},
onclick: move |_| right_tab.set("findings".to_string()),
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
" Findings ({header_findings_count})"
}
button {
style: if *right_tab.read() == "chain" {
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;"
} else {
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;"
},
onclick: move |_| right_tab.set("chain".to_string()),
Icon { icon: BsDiagram3, width: 14, height: 14 }
" Attack Chain"
}
}
// Tab content
div { style: "flex: 1; overflow-y: auto; padding: 12px;",
if *right_tab.read() == "findings" {
// Findings tab
match &*findings.read() {
Some(Some(data)) => {
let finding_list = &data.data;
if finding_list.is_empty() {
rsx! {
div { style: "text-align: center; color: var(--text-secondary); padding: 24px;",
p { "No findings yet." }
}
}
} else {
rsx! {
div { style: "display: flex; flex-direction: column; gap: 8px;",
for finding in finding_list {
{
let title = finding.get("title").and_then(|v| v.as_str()).unwrap_or("Untitled").to_string();
let severity = finding.get("severity").and_then(|v| v.as_str()).unwrap_or("info").to_string();
let vuln_type = finding.get("vulnerability_type").and_then(|v| v.as_str()).unwrap_or("-").to_string();
let sev_style = match severity.as_str() {
"critical" => "background: #dc2626; color: #fff;",
"high" => "background: #ea580c; color: #fff;",
"medium" => "background: #d97706; color: #fff;",
"low" => "background: #2563eb; color: #fff;",
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
};
rsx! {
div { style: "padding: 10px; background: var(--bg-tertiary); border-radius: 8px;",
div { style: "display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;",
span { style: "font-weight: 600; font-size: 0.85rem;", "{title}" }
span { class: "badge", style: "{sev_style}", "{severity}" }
}
div { style: "font-size: 0.8rem; color: var(--text-secondary);", "{vuln_type}" }
}
}
}
}
}
}
}
},
Some(None) => rsx! { p { style: "color: var(--text-secondary);", "Failed to load findings." } },
None => rsx! { p { style: "color: var(--text-secondary);", "Loading..." } },
}
} else {
// Attack chain tab
match &*attack_chain.read() {
Some(Some(data)) => {
let steps = &data.data;
if steps.is_empty() {
rsx! {
div { style: "text-align: center; color: var(--text-secondary); padding: 24px;",
p { "No attack chain steps yet." }
}
}
} else {
rsx! {
div { style: "display: flex; flex-direction: column; gap: 4px;",
for (i, step) in steps.iter().enumerate() {
{
let step_name = step.get("name").and_then(|v| v.as_str()).unwrap_or("Step").to_string();
let step_status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
let description = step.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string();
let step_num = i + 1;
let dot_color = match step_status.as_str() {
"completed" => "#16a34a",
"running" => "#d97706",
"failed" => "#dc2626",
_ => "var(--text-secondary)",
};
rsx! {
div { style: "display: flex; gap: 10px; padding: 8px 0;",
div { style: "display: flex; flex-direction: column; align-items: center;",
div { style: "width: 10px; height: 10px; border-radius: 50%; background: {dot_color}; flex-shrink: 0;" }
if i < steps.len() - 1 {
div { style: "width: 2px; flex: 1; background: var(--border-color); margin-top: 4px;" }
}
}
div {
div { style: "font-size: 0.85rem; font-weight: 600;", "{step_num}. {step_name}" }
if !description.is_empty() {
div { style: "font-size: 0.8rem; color: var(--text-secondary); margin-top: 2px;",
"{description}"
}
}
}
}
}
}
}
}
}
}
},
Some(None) => rsx! { p { style: "color: var(--text-secondary);", "Failed to load attack chain." } },
None => rsx! { p { style: "color: var(--text-secondary);", "Loading..." } },
}
}
}
}
}
}
}

View File

@@ -27,6 +27,10 @@ chromiumoxide = { version = "0.7", features = ["tokio-runtime"], default-feature
# Docker sandboxing
bollard = "0.18"
# TLS analysis
native-tls = "0.2"
tokio-native-tls = "0.3"
# Serialization
bson = { version = "2", features = ["chrono-0_4"] }
url = "2"

View File

@@ -2,5 +2,7 @@ pub mod agents;
pub mod crawler;
pub mod orchestrator;
pub mod recon;
pub mod tools;
pub use orchestrator::DastOrchestrator;
pub use tools::ToolRegistry;

View File

@@ -0,0 +1,146 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use crate::agents::api_fuzzer::ApiFuzzerAgent;
/// PentestTool wrapper around the existing ApiFuzzerAgent.
pub struct ApiFuzzerTool {
http: reqwest::Client,
agent: ApiFuzzerAgent,
}
impl ApiFuzzerTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = ApiFuzzerAgent::new(http.clone());
Self { http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
});
}
}
endpoints.push(DiscoveredEndpoint {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
});
}
}
endpoints
}
}
impl PentestTool for ApiFuzzerTool {
fn name(&self) -> &str {
"api_fuzzer"
}
fn description(&self) -> &str {
"Fuzzes API endpoints to discover misconfigurations, information disclosure, and hidden \
endpoints. Probes common sensitive paths and tests for verbose error messages."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"endpoints": {
"type": "array",
"description": "Known endpoints to fuzz",
"items": {
"type": "object",
"properties": {
"url": { "type": "string" },
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"location": { "type": "string" },
"param_type": { "type": "string" },
"example_value": { "type": "string" }
},
"required": ["name"]
}
}
},
"required": ["url"]
}
},
"base_url": {
"type": "string",
"description": "Base URL to probe for common sensitive paths (used if no endpoints provided)"
}
}
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let mut endpoints = Self::parse_endpoints(&input);
// If a base_url is provided but no endpoints, create a default endpoint
if endpoints.is_empty() {
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
endpoints.push(DiscoveredEndpoint {
url: base.to_string(),
method: "GET".to_string(),
parameters: Vec::new(),
content_type: None,
requires_auth: false,
});
}
}
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints or base_url provided to fuzz.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} API misconfigurations or information disclosures.")
} else {
"No API misconfigurations detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -0,0 +1,130 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use crate::agents::auth_bypass::AuthBypassAgent;
/// PentestTool wrapper around the existing AuthBypassAgent.
pub struct AuthBypassTool {
http: reqwest::Client,
agent: AuthBypassAgent,
}
impl AuthBypassTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = AuthBypassAgent::new(http.clone());
Self { http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
});
}
}
endpoints.push(DiscoveredEndpoint {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
});
}
}
endpoints
}
}
impl PentestTool for AuthBypassTool {
fn name(&self) -> &str {
"auth_bypass_scanner"
}
fn description(&self) -> &str {
"Tests endpoints for authentication bypass vulnerabilities. Tries accessing protected \
endpoints without credentials, with manipulated tokens, and with common default credentials."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"endpoints": {
"type": "array",
"description": "Endpoints to test for authentication bypass",
"items": {
"type": "object",
"properties": {
"url": { "type": "string" },
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"location": { "type": "string" },
"param_type": { "type": "string" },
"example_value": { "type": "string" }
},
"required": ["name"]
}
},
"requires_auth": { "type": "boolean", "description": "Whether this endpoint requires authentication" }
},
"required": ["url", "method"]
}
}
},
"required": ["endpoints"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} authentication bypass vulnerabilities.")
} else {
"No authentication bypass vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -0,0 +1,326 @@
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
/// Tool that detects console.log and similar debug statements in frontend JavaScript.
pub struct ConsoleLogDetectorTool {
http: reqwest::Client,
}
/// A detected console statement with its context.
#[derive(Debug)]
struct ConsoleMatch {
pattern: String,
file_url: String,
line_snippet: String,
line_number: Option<usize>,
}
impl ConsoleLogDetectorTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Patterns that indicate debug/logging statements left in production code.
fn patterns() -> Vec<&'static str> {
vec![
"console.log(",
"console.debug(",
"console.error(",
"console.warn(",
"console.info(",
"console.trace(",
"console.dir(",
"console.table(",
"debugger;",
"alert(",
]
}
/// Extract JavaScript file URLs from an HTML page body.
fn extract_js_urls(html: &str, base_url: &str) -> Vec<String> {
let mut urls = Vec::new();
let base = url::Url::parse(base_url).ok();
// Simple regex-free extraction of <script src="...">
let mut search_from = 0;
while let Some(start) = html[search_from..].find("src=") {
let abs_start = search_from + start + 4;
if abs_start >= html.len() {
break;
}
let quote = html.as_bytes().get(abs_start).copied();
let (open, close) = match quote {
Some(b'"') => ('"', '"'),
Some(b'\'') => ('\'', '\''),
_ => {
search_from = abs_start + 1;
continue;
}
};
let val_start = abs_start + 1;
if let Some(end) = html[val_start..].find(close) {
let src = &html[val_start..val_start + end];
if src.ends_with(".js") || src.contains(".js?") || src.contains("/js/") {
let full_url = if src.starts_with("http://") || src.starts_with("https://") {
src.to_string()
} else if src.starts_with("//") {
format!("https:{src}")
} else if let Some(ref base) = base {
base.join(src).map(|u| u.to_string()).unwrap_or_default()
} else {
format!("{base_url}/{}", src.trim_start_matches('/'))
};
if !full_url.is_empty() {
urls.push(full_url);
}
}
search_from = val_start + end + 1;
} else {
break;
}
}
urls
}
/// Search a JS file's contents for console/debug patterns.
fn scan_js_content(content: &str, file_url: &str) -> Vec<ConsoleMatch> {
let mut matches = Vec::new();
for (line_num, line) in content.lines().enumerate() {
let trimmed = line.trim();
// Skip comments (basic heuristic)
if trimmed.starts_with("//") || trimmed.starts_with('*') || trimmed.starts_with("/*") {
continue;
}
for pattern in Self::patterns() {
if line.contains(pattern) {
let snippet = if line.len() > 200 {
format!("{}...", &line[..200])
} else {
line.to_string()
};
matches.push(ConsoleMatch {
pattern: pattern.trim_end_matches('(').to_string(),
file_url: file_url.to_string(),
line_snippet: snippet.trim().to_string(),
line_number: Some(line_num + 1),
});
break; // One match per line is enough
}
}
}
matches
}
}
impl PentestTool for ConsoleLogDetectorTool {
fn name(&self) -> &str {
"console_log_detector"
}
fn description(&self) -> &str {
"Detects console.log, console.debug, console.error, debugger, and similar debug \
statements left in production JavaScript. Fetches the HTML page and referenced JS files."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL of the page to check for console.log leakage"
},
"additional_js_urls": {
"type": "array",
"description": "Optional additional JavaScript file URLs to scan",
"items": { "type": "string" }
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_js: Vec<String> = input
.get("additional_js_urls")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Fetch the main page
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let html = response.text().await.unwrap_or_default();
// Scan inline scripts in the HTML
let mut all_matches = Vec::new();
let inline_matches = Self::scan_js_content(&html, url);
all_matches.extend(inline_matches);
// Extract JS file URLs from the HTML
let mut js_urls = Self::extract_js_urls(&html, url);
js_urls.extend(additional_js);
js_urls.dedup();
// Fetch and scan each JS file
for js_url in &js_urls {
match self.http.get(js_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
let js_content = resp.text().await.unwrap_or_default();
// Only scan non-minified-looking files or files where we can still
// find patterns (minifiers typically strip console calls, but not always)
let file_matches = Self::scan_js_content(&js_content, js_url);
all_matches.extend(file_matches);
}
}
Err(_) => continue,
}
}
let mut findings = Vec::new();
let match_data: Vec<serde_json::Value> = all_matches
.iter()
.map(|m| {
json!({
"pattern": m.pattern,
"file": m.file_url,
"line": m.line_number,
"snippet": m.line_snippet,
})
})
.collect();
if !all_matches.is_empty() {
// Group by file for the finding
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
std::collections::HashMap::new();
for m in &all_matches {
by_file.entry(&m.file_url).or_default().push(m);
}
for (file_url, matches) in &by_file {
let pattern_summary: Vec<String> = matches
.iter()
.take(5)
.map(|m| {
format!(
" Line {}: {} - {}",
m.line_number.unwrap_or(0),
m.pattern,
if m.line_snippet.len() > 80 {
format!("{}...", &m.line_snippet[..80])
} else {
m.line_snippet.clone()
}
)
})
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: file_url.to_string(),
request_headers: None,
request_body: None,
response_status: 200,
response_headers: None,
response_snippet: Some(pattern_summary.join("\n")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let total = matches.len();
let extra = if total > 5 {
format!(" (and {} more)", total - 5)
} else {
String::new()
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::ConsoleLogLeakage,
format!("Console/debug statements in {}", file_url),
format!(
"Found {total} console/debug statements in {file_url}{extra}. \
These can leak sensitive information such as API responses, user data, \
or internal state to anyone with browser developer tools open."
),
Severity::Low,
file_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-532".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove console.log/debug/error statements from production code. \
Use a build step (e.g., babel plugin, terser) to strip console calls \
during the production build."
.to_string(),
);
findings.push(finding);
}
}
let total_matches = all_matches.len();
let count = findings.len();
info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete");
Ok(PentestToolResult {
summary: if total_matches > 0 {
format!(
"Found {total_matches} console/debug statements across {} files.",
count
)
} else {
format!(
"No console/debug statements found in HTML or {} JS files.",
js_urls.len()
)
},
findings,
data: json!({
"total_matches": total_matches,
"js_files_scanned": js_urls.len(),
"matches": match_data,
}),
})
})
}
}

View File

@@ -0,0 +1,401 @@
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
/// Tool that inspects cookies set by a target for security issues.
pub struct CookieAnalyzerTool {
http: reqwest::Client,
}
/// Parsed attributes from a Set-Cookie header.
#[derive(Debug)]
struct ParsedCookie {
name: String,
value: String,
secure: bool,
http_only: bool,
same_site: Option<String>,
domain: Option<String>,
path: Option<String>,
raw: String,
}
impl CookieAnalyzerTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Parse a Set-Cookie header value into a structured representation.
fn parse_set_cookie(header: &str) -> ParsedCookie {
let raw = header.to_string();
let parts: Vec<&str> = header.split(';').collect();
let (name, value) = if let Some(kv) = parts.first() {
let mut kv_split = kv.splitn(2, '=');
let k = kv_split.next().unwrap_or("").trim().to_string();
let v = kv_split.next().unwrap_or("").trim().to_string();
(k, v)
} else {
(String::new(), String::new())
};
let mut secure = false;
let mut http_only = false;
let mut same_site = None;
let mut domain = None;
let mut path = None;
for part in parts.iter().skip(1) {
let trimmed = part.trim().to_lowercase();
if trimmed == "secure" {
secure = true;
} else if trimmed == "httponly" {
http_only = true;
} else if let Some(ss) = trimmed.strip_prefix("samesite=") {
same_site = Some(ss.trim().to_string());
} else if let Some(d) = trimmed.strip_prefix("domain=") {
domain = Some(d.trim().to_string());
} else if let Some(p) = trimmed.strip_prefix("path=") {
path = Some(p.trim().to_string());
}
}
ParsedCookie {
name,
value,
secure,
http_only,
same_site,
domain,
path,
raw,
}
}
/// Heuristic: does this cookie name suggest it's a session / auth cookie?
fn is_sensitive_cookie(name: &str) -> bool {
let lower = name.to_lowercase();
lower.contains("session")
|| lower.contains("sess")
|| lower.contains("token")
|| lower.contains("auth")
|| lower.contains("jwt")
|| lower.contains("csrf")
|| lower.contains("sid")
|| lower == "connect.sid"
|| lower == "phpsessid"
|| lower == "jsessionid"
|| lower == "asp.net_sessionid"
}
}
impl PentestTool for CookieAnalyzerTool {
fn name(&self) -> &str {
"cookie_analyzer"
}
fn description(&self) -> &str {
"Analyzes cookies set by a target URL. Checks for Secure, HttpOnly, SameSite attributes \
and overly broad Domain/Path settings. Focuses on session and authentication cookies."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to fetch and analyze cookies from"
},
"login_url": {
"type": "string",
"description": "Optional login URL to also check (may set auth cookies)"
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let login_url = input.get("login_url").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut cookie_data = Vec::new();
// Collect Set-Cookie headers from the main URL and optional login URL
let urls_to_check: Vec<&str> = std::iter::once(url)
.chain(login_url.into_iter())
.collect();
for check_url in &urls_to_check {
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
let no_redirect_client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(15))
.build()
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
let response = match no_redirect_client.get(*check_url).send().await {
Ok(r) => r,
Err(e) => {
// Try with the main client that follows redirects
match self.http.get(*check_url).send().await {
Ok(r) => r,
Err(_) => continue,
}
}
};
let status = response.status().as_u16();
let set_cookie_headers: Vec<String> = response
.headers()
.get_all("set-cookie")
.iter()
.filter_map(|v| v.to_str().ok().map(String::from))
.collect();
for raw_cookie in &set_cookie_headers {
let cookie = Self::parse_set_cookie(raw_cookie);
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
let is_https = check_url.starts_with("https://");
let cookie_info = json!({
"name": cookie.name,
"secure": cookie.secure,
"http_only": cookie.http_only,
"same_site": cookie.same_site,
"domain": cookie.domain,
"path": cookie.path,
"is_sensitive": is_sensitive,
"url": check_url,
});
cookie_data.push(cookie_info);
// Check: missing Secure flag
if !cookie.secure && (is_https || is_sensitive) {
let severity = if is_sensitive {
Severity::High
} else {
Severity::Medium
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing Secure flag", cookie.name),
format!(
"The cookie '{}' does not have the Secure attribute set. \
Without this flag, the cookie can be transmitted over unencrypted HTTP connections.",
cookie.name
),
severity,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-614".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
cookie is only sent over HTTPS connections."
.to_string(),
);
findings.push(finding);
}
// Check: missing HttpOnly flag on sensitive cookies
if !cookie.http_only && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing HttpOnly flag", cookie.name),
format!(
"The session/auth cookie '{}' does not have the HttpOnly attribute. \
This makes it accessible to JavaScript, increasing the impact of XSS attacks.",
cookie.name
),
Severity::High,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
JavaScript access to the cookie."
.to_string(),
);
findings.push(finding);
}
// Check: missing or weak SameSite
if is_sensitive {
let weak_same_site = match &cookie.same_site {
None => true,
Some(ss) => ss == "none",
};
if weak_same_site {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let desc = if cookie.same_site.is_none() {
format!(
"The session/auth cookie '{}' does not have a SameSite attribute. \
This may allow cross-site request forgery (CSRF) attacks.",
cookie.name
)
} else {
format!(
"The session/auth cookie '{}' has SameSite=None, which allows it \
to be sent in cross-site requests, enabling CSRF attacks.",
cookie.name
)
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing or weak SameSite", cookie.name),
desc,
Severity::Medium,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1275".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
to prevent cross-site request inclusion."
.to_string(),
);
findings.push(finding);
}
}
// Check: overly broad domain
if let Some(ref domain) = cookie.domain {
// A domain starting with a dot applies to all subdomains
let dot_domain = domain.starts_with('.');
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
if dot_domain && parts.len() <= 2 && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' has overly broad domain", cookie.name),
format!(
"The cookie '{}' is scoped to domain '{}' which includes all \
subdomains. If any subdomain is compromised, the attacker can \
access this cookie.",
cookie.name, domain
),
Severity::Low,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Restrict the cookie domain to the specific subdomain that needs it \
rather than the entire parent domain."
.to_string(),
);
findings.push(finding);
}
}
}
}
let count = findings.len();
info!(url, findings = count, "Cookie analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} cookie security issues.")
} else if cookie_data.is_empty() {
"No cookies were set by the target.".to_string()
} else {
"All cookies have proper security attributes.".to_string()
},
findings,
data: json!({
"cookies": cookie_data,
"total_cookies": cookie_data.len(),
}),
})
})
}
}

View File

@@ -0,0 +1,410 @@
use std::collections::HashMap;
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::{info, warn};
/// Tool that checks CORS configuration for security issues.
pub struct CorsCheckerTool {
http: reqwest::Client,
}
impl CorsCheckerTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Origins to test against the target.
fn test_origins(target_host: &str) -> Vec<(&'static str, String)> {
vec![
("null_origin", "null".to_string()),
("evil_domain", "https://evil.com".to_string()),
(
"subdomain_spoof",
format!("https://{target_host}.evil.com"),
),
(
"prefix_spoof",
format!("https://evil-{target_host}"),
),
("http_downgrade", format!("http://{target_host}")),
]
}
}
impl PentestTool for CorsCheckerTool {
fn name(&self) -> &str {
"cors_checker"
}
fn description(&self) -> &str {
"Checks CORS configuration by sending requests with various Origin headers. Tests for \
wildcard origins, reflected origins, null origin acceptance, and dangerous \
Access-Control-Allow-Credentials combinations."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to test CORS configuration on"
},
"additional_origins": {
"type": "array",
"description": "Optional additional origin values to test",
"items": { "type": "string" }
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_origins: Vec<String> = input
.get("additional_origins")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_host = url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
.unwrap_or_else(|| url.to_string());
let mut findings = Vec::new();
let mut cors_data: Vec<serde_json::Value> = Vec::new();
// First, send a request without Origin to get baseline
let baseline = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let baseline_acao = baseline
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"origin": null,
"acao": baseline_acao,
}));
// Check for wildcard + credentials (dangerous combo)
if let Some(ref acao) = baseline_acao {
if acao == "*" {
let acac = baseline
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if acac.to_lowercase() == "true" {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: baseline.status().as_u16(),
response_headers: None,
response_snippet: Some(format!(
"Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS wildcard with credentials".to_string(),
format!(
"The endpoint {url} returns Access-Control-Allow-Origin: * with \
Access-Control-Allow-Credentials: true. While browsers should block this \
combination, it indicates a serious CORS misconfiguration."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Never combine Access-Control-Allow-Origin: * with \
Access-Control-Allow-Credentials: true. Specify explicit allowed origins."
.to_string(),
);
findings.push(finding);
}
}
}
// Test with various Origin headers
let mut test_origins = Self::test_origins(&target_host);
for origin in &additional_origins {
test_origins.push(("custom", origin.clone()));
}
for (test_name, origin) in &test_origins {
let resp = match self
.http
.get(url)
.header("Origin", origin.as_str())
.send()
.await
{
Ok(r) => r,
Err(_) => continue,
};
let status = resp.status().as_u16();
let acao = resp
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acac = resp
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": test_name,
"origin": origin,
"acao": acao,
"acac": acac,
"acam": acam,
"status": status,
}));
// Check if the origin was reflected back
if let Some(ref acao_val) = acao {
let origin_reflected = acao_val == origin;
let credentials_allowed = acac
.as_ref()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
if origin_reflected && *test_name != "http_downgrade" {
let severity = if credentials_allowed {
Severity::Critical
} else {
Severity::High
};
let resp_headers: HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: Some(resp_headers),
response_snippet: Some(format!(
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
Access-Control-Allow-Credentials: {}",
acac.as_deref().unwrap_or("not set")
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let title = match *test_name {
"null_origin" => {
"CORS accepts null origin".to_string()
}
"evil_domain" => {
"CORS reflects arbitrary origin".to_string()
}
"subdomain_spoof" => {
"CORS vulnerable to subdomain spoofing".to_string()
}
"prefix_spoof" => {
"CORS vulnerable to prefix spoofing".to_string()
}
_ => format!("CORS reflects untrusted origin ({test_name})"),
};
let cred_note = if credentials_allowed {
" Combined with Access-Control-Allow-Credentials: true, this allows \
the attacker to steal authenticated data."
} else {
""
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
title,
format!(
"The endpoint {url} reflects the Origin header '{origin}' back in \
Access-Control-Allow-Origin, allowing cross-origin requests from \
untrusted domains.{cred_note}"
),
severity,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.exploitable = credentials_allowed;
finding.evidence = vec![evidence];
finding.remediation = Some(
"Validate the Origin header against a whitelist of trusted origins. \
Do not reflect the Origin header value directly. Use specific allowed \
origins instead of wildcards."
.to_string(),
);
findings.push(finding);
warn!(
url,
test_name,
origin,
credentials = credentials_allowed,
"CORS misconfiguration detected"
);
}
// Special case: HTTP downgrade
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS allows HTTP origin with credentials".to_string(),
format!(
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
credentials. An attacker performing a man-in-the-middle attack on \
the HTTP version could steal authenticated data."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
origin validation enforces the https:// scheme."
.to_string(),
);
findings.push(finding);
}
}
}
// Also send a preflight OPTIONS request
if let Ok(resp) = self
.http
.request(reqwest::Method::OPTIONS, url)
.header("Origin", "https://evil.com")
.header("Access-Control-Request-Method", "POST")
.header("Access-Control-Request-Headers", "Authorization, Content-Type")
.send()
.await
{
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acah = resp
.headers()
.get("access-control-allow-headers")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": "preflight",
"status": resp.status().as_u16(),
"allow_methods": acam,
"allow_headers": acah,
}));
}
let count = findings.len();
info!(url, findings = count, "CORS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CORS misconfiguration issues for {url}.")
} else {
format!("CORS configuration appears secure for {url}.")
},
findings,
data: json!({
"tests": cors_data,
}),
})
})
}
}

View File

@@ -0,0 +1,447 @@
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
/// Tool that analyzes Content-Security-Policy headers.
pub struct CspAnalyzerTool {
http: reqwest::Client,
}
/// A parsed CSP directive.
#[derive(Debug)]
struct CspDirective {
name: String,
values: Vec<String>,
}
impl CspAnalyzerTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Parse a CSP header string into directives.
fn parse_csp(csp: &str) -> Vec<CspDirective> {
let mut directives = Vec::new();
for part in csp.split(';') {
let trimmed = part.trim();
if trimmed.is_empty() {
continue;
}
let tokens: Vec<&str> = trimmed.split_whitespace().collect();
if let Some((name, values)) = tokens.split_first() {
directives.push(CspDirective {
name: name.to_lowercase(),
values: values.iter().map(|v| v.to_string()).collect(),
});
}
}
directives
}
/// Check a CSP for common issues and return findings.
fn analyze_directives(
directives: &[CspDirective],
url: &str,
target_id: &str,
status: u16,
csp_raw: &str,
) -> Vec<DastFinding> {
let mut findings = Vec::new();
let make_evidence = |snippet: String| DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(snippet),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// Check for unsafe-inline in script-src
for d in directives {
if (d.name == "script-src" || d.name == "default-src")
&& d.values.iter().any(|v| v == "'unsafe-inline'")
{
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP allows 'unsafe-inline' in {}", d.name),
format!(
"The Content-Security-Policy directive '{}' includes 'unsafe-inline', \
which defeats the purpose of CSP by allowing inline scripts that \
could be exploited via XSS.",
d.name
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-79".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove 'unsafe-inline' from script-src. Use nonces or hashes for \
legitimate inline scripts instead."
.to_string(),
);
findings.push(finding);
}
// Check for unsafe-eval
if (d.name == "script-src" || d.name == "default-src")
&& d.values.iter().any(|v| v == "'unsafe-eval'")
{
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP allows 'unsafe-eval' in {}", d.name),
format!(
"The Content-Security-Policy directive '{}' includes 'unsafe-eval', \
which allows the use of eval() and similar dynamic code execution \
that can be exploited via XSS.",
d.name
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-79".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove 'unsafe-eval' from script-src. Refactor code to avoid eval(), \
Function(), and similar constructs."
.to_string(),
);
findings.push(finding);
}
// Check for wildcard sources
if d.values.iter().any(|v| v == "*") {
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP wildcard source in {}", d.name),
format!(
"The Content-Security-Policy directive '{}' uses a wildcard '*' source, \
which allows loading resources from any origin and largely negates CSP protection.",
d.name
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Replace the wildcard '*' in {} with specific allowed origins.",
d.name
));
findings.push(finding);
}
// Check for http: sources (non-HTTPS)
if d.values.iter().any(|v| v == "http:") {
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP allows HTTP sources in {}", d.name),
format!(
"The Content-Security-Policy directive '{}' allows loading resources \
over unencrypted HTTP, which can be exploited via man-in-the-middle attacks.",
d.name
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Replace 'http:' with 'https:' in {} to enforce encrypted resource loading.",
d.name
));
findings.push(finding);
}
// Check for data: in script-src (can be used to bypass CSP)
if (d.name == "script-src" || d.name == "default-src")
&& d.values.iter().any(|v| v == "data:")
{
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP allows data: URIs in {}", d.name),
format!(
"The Content-Security-Policy directive '{}' allows 'data:' URIs, \
which can be used to bypass CSP and execute arbitrary scripts.",
d.name
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-79".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove 'data:' from {}. If data URIs are needed, restrict them to \
non-executable content types only (e.g., img-src).",
d.name
));
findings.push(finding);
}
}
// Check for missing important directives
let directive_names: Vec<&str> = directives.iter().map(|d| d.name.as_str()).collect();
let has_default_src = directive_names.contains(&"default-src");
let important_directives = [
("script-src", "Controls which scripts can execute"),
("object-src", "Controls plugins like Flash"),
("base-uri", "Controls the base URL for relative URLs"),
("form-action", "Controls where forms can submit to"),
("frame-ancestors", "Controls who can embed this page in iframes"),
];
for (dir_name, desc) in &important_directives {
if !directive_names.contains(dir_name)
&& !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action")
{
let evidence = make_evidence(format!("CSP missing directive: {dir_name}"));
let mut finding = DastFinding::new(
String::new(),
target_id.to_string(),
DastVulnType::CspIssue,
format!("CSP missing '{}' directive", dir_name),
format!(
"The Content-Security-Policy is missing the '{}' directive. {}. \
Without this directive{}, the browser may fall back to less restrictive defaults.",
dir_name,
desc,
if has_default_src && (*dir_name == "frame-ancestors" || *dir_name == "base-uri" || *dir_name == "form-action") {
" (not covered by default-src)"
} else {
""
}
),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Add '{}: 'none'' or an appropriate restrictive value to your CSP.",
dir_name
));
findings.push(finding);
}
}
findings
}
}
impl PentestTool for CspAnalyzerTool {
fn name(&self) -> &str {
"csp_analyzer"
}
fn description(&self) -> &str {
"Analyzes Content-Security-Policy headers. Checks for unsafe-inline, unsafe-eval, \
wildcard sources, data: URIs in script-src, missing directives, and other CSP weaknesses."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to fetch and analyze CSP from"
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
// Check for CSP header
let csp_header = response
.headers()
.get("content-security-policy")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Also check for report-only variant
let csp_report_only = response
.headers()
.get("content-security-policy-report-only")
.and_then(|v| v.to_str().ok())
.map(String::from);
let mut findings = Vec::new();
let mut csp_data = json!({});
match &csp_header {
Some(csp) => {
let directives = Self::parse_csp(csp);
let directive_map: serde_json::Value = directives
.iter()
.map(|d| (d.name.clone(), json!(d.values)))
.collect::<serde_json::Map<String, serde_json::Value>>()
.into();
csp_data["csp_header"] = json!(csp);
csp_data["directives"] = directive_map;
findings.extend(Self::analyze_directives(
&directives,
url,
&target_id,
status,
csp,
));
}
None => {
csp_data["csp_header"] = json!(null);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some("Content-Security-Policy header is missing".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"Missing Content-Security-Policy header".to_string(),
format!(
"No Content-Security-Policy header is present on {url}. \
Without CSP, the browser has no instructions on which sources are \
trusted, making XSS exploitation much easier."
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a Content-Security-Policy header. Start with a restrictive policy like \
\"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \
object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"."
.to_string(),
);
findings.push(finding);
}
}
if let Some(ref report_only) = csp_report_only {
csp_data["csp_report_only"] = json!(report_only);
// If ONLY report-only exists (no enforcing CSP), warn
if csp_header.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"Content-Security-Policy-Report-Only: {}",
report_only
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"CSP is report-only, not enforcing".to_string(),
"A Content-Security-Policy-Report-Only header is present but no enforcing \
Content-Security-Policy header exists. Report-only mode only logs violations \
but does not block them."
.to_string(),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Once you have verified the CSP policy works correctly in report-only mode, \
deploy it as an enforcing Content-Security-Policy header."
.to_string(),
);
findings.push(finding);
}
}
let count = findings.len();
info!(url, findings = count, "CSP analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CSP issues for {url}.")
} else {
format!("Content-Security-Policy looks good for {url}.")
},
findings,
data: csp_data,
})
})
}
}

View File

@@ -0,0 +1,401 @@
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::{info, warn};
/// Tool that checks email security configuration (DMARC and SPF records).
pub struct DmarcCheckerTool;
impl DmarcCheckerTool {
pub fn new() -> Self {
Self
}
/// Query TXT records for a given name using `dig`.
async fn query_txt(name: &str) -> Result<Vec<String>, CoreError> {
let output = tokio::process::Command::new("dig")
.args(["+short", "TXT", name])
.output()
.await
.map_err(|e| CoreError::Dast(format!("dig command failed: {e}")))?;
let stdout = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = stdout
.lines()
.map(|l| l.trim().trim_matches('"').to_string())
.filter(|l| !l.is_empty())
.collect();
Ok(lines)
}
/// Parse a DMARC record string and return the policy value.
fn parse_dmarc_policy(record: &str) -> Option<String> {
for part in record.split(';') {
let part = part.trim();
if let Some(val) = part.strip_prefix("p=") {
return Some(val.trim().to_lowercase());
}
}
None
}
/// Parse DMARC record for sub-domain policy (sp=).
fn parse_dmarc_subdomain_policy(record: &str) -> Option<String> {
for part in record.split(';') {
let part = part.trim();
if let Some(val) = part.strip_prefix("sp=") {
return Some(val.trim().to_lowercase());
}
}
None
}
/// Parse DMARC record for reporting URI (rua=).
fn parse_dmarc_rua(record: &str) -> Option<String> {
for part in record.split(';') {
let part = part.trim();
if let Some(val) = part.strip_prefix("rua=") {
return Some(val.trim().to_string());
}
}
None
}
/// Check if an SPF record is present and parse the policy.
fn is_spf_record(record: &str) -> bool {
record.starts_with("v=spf1")
}
/// Evaluate SPF record strength.
fn spf_uses_soft_fail(record: &str) -> bool {
record.contains("~all")
}
fn spf_allows_all(record: &str) -> bool {
record.contains("+all")
}
}
impl PentestTool for DmarcCheckerTool {
fn name(&self) -> &str {
"dmarc_checker"
}
fn description(&self) -> &str {
"Checks email security configuration for a domain. Queries DMARC and SPF records and \
evaluates policy strength. Reports missing or weak email authentication settings."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": "The domain to check (e.g., 'example.com')"
}
},
"required": ["domain"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut email_data = json!({});
// ---- DMARC check ----
let dmarc_domain = format!("_dmarc.{domain}");
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
match dmarc_record {
Some(record) => {
email_data["dmarc_record"] = json!(record);
let policy = Self::parse_dmarc_policy(record);
let sp = Self::parse_dmarc_subdomain_policy(record);
let rua = Self::parse_dmarc_rua(record);
email_data["dmarc_policy"] = json!(policy);
email_data["dmarc_subdomain_policy"] = json!(sp);
email_data["dmarc_rua"] = json!(rua);
// Warn on weak policy
if let Some(ref p) = policy {
if p == "none" {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Weak DMARC policy for {domain}"),
format!(
"The DMARC policy for {domain} is set to 'none', which only monitors \
but does not enforce email authentication. Attackers can spoof emails \
from this domain."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
'p=reject' after verifying legitimate email flows are properly \
authenticated."
.to_string(),
);
findings.push(finding);
warn!(domain, "DMARC policy is 'none'");
}
}
// Warn if no reporting URI
if rua.is_none() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("DMARC without reporting for {domain}"),
format!(
"The DMARC record for {domain} does not include a reporting URI (rua=). \
Without reporting, you will not receive aggregate feedback about email \
authentication failures."
),
Severity::Info,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-778".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
Example: 'rua=mailto:dmarc-reports@example.com'."
.to_string(),
);
findings.push(finding);
}
}
None => {
email_data["dmarc_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DMARC record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing DMARC record for {domain}"),
format!(
"No DMARC record was found for {domain}. Without DMARC, there is no \
policy to prevent email spoofing and phishing attacks using this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No DMARC record found");
}
}
// ---- SPF check ----
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
match spf_record {
Some(record) => {
email_data["spf_record"] = json!(record);
if Self::spf_allows_all(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF allows all senders for {domain}"),
format!(
"The SPF record for {domain} uses '+all' which allows any server to \
send email on behalf of this domain, completely negating SPF protection."
),
Severity::Critical,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \
Only list authorized mail servers."
.to_string(),
);
findings.push(finding);
} else if Self::spf_uses_soft_fail(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF soft fail for {domain}"),
format!(
"The SPF record for {domain} uses '~all' (soft fail) instead of \
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
but does not reject them."
),
Severity::Low,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Consider changing '~all' to '-all' in your SPF record once you have \
confirmed all legitimate mail sources are listed."
.to_string(),
);
findings.push(finding);
}
}
None => {
email_data["spf_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No SPF record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing SPF record for {domain}"),
format!(
"No SPF record was found for {domain}. Without SPF, any server can claim \
to send email on behalf of this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create an SPF TXT record for your domain. Example: \
'v=spf1 include:_spf.google.com -all'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No SPF record found");
}
}
let count = findings.len();
info!(domain, findings = count, "DMARC/SPF check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} email security issues for {domain}.")
} else {
format!("Email security configuration looks good for {domain}.")
},
findings,
data: email_data,
})
})
}
}

View File

@@ -0,0 +1,389 @@
use std::collections::HashMap;
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tokio::net::lookup_host;
use tracing::{info, warn};
/// Tool that checks DNS configuration for security issues.
///
/// Resolves A, AAAA, MX, TXT, CNAME, NS records using the system resolver
/// via `tokio::net::lookup_host` and `std::net::ToSocketAddrs`. For TXT-based
/// records (SPF, DMARC, CAA, DNSSEC) it uses a simple TXT query via the
/// `tokio::process::Command` wrapper around `dig` where available.
pub struct DnsCheckerTool;
impl DnsCheckerTool {
pub fn new() -> Self {
Self
}
/// Run a `dig` query and return the answer lines.
async fn dig_query(domain: &str, record_type: &str) -> Result<Vec<String>, CoreError> {
let output = tokio::process::Command::new("dig")
.args(["+short", record_type, domain])
.output()
.await
.map_err(|e| CoreError::Dast(format!("dig command failed: {e}")))?;
let stdout = String::from_utf8_lossy(&output.stdout);
let lines: Vec<String> = stdout
.lines()
.map(|l| l.trim().to_string())
.filter(|l| !l.is_empty())
.collect();
Ok(lines)
}
/// Resolve A/AAAA records using tokio lookup.
async fn resolve_addresses(domain: &str) -> Result<(Vec<String>, Vec<String>), CoreError> {
let mut ipv4 = Vec::new();
let mut ipv6 = Vec::new();
let addr_str = format!("{domain}:443");
match lookup_host(&addr_str).await {
Ok(addrs) => {
for addr in addrs {
match addr {
std::net::SocketAddr::V4(v4) => ipv4.push(v4.ip().to_string()),
std::net::SocketAddr::V6(v6) => ipv6.push(v6.ip().to_string()),
}
}
}
Err(e) => {
return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}")));
}
}
Ok((ipv4, ipv6))
}
}
impl PentestTool for DnsCheckerTool {
fn name(&self) -> &str {
"dns_checker"
}
fn description(&self) -> &str {
"Checks DNS configuration for a domain. Resolves A, AAAA, MX, TXT, CNAME, NS records. \
Checks for DNSSEC, CAA records, and potential subdomain takeover via dangling CNAME/NS."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": "The domain to check (e.g., 'example.com')"
},
"subdomains": {
"type": "array",
"description": "Optional list of subdomains to also check (e.g., ['www', 'api', 'mail'])",
"items": { "type": "string" }
}
},
"required": ["domain"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let subdomains: Vec<String> = input
.get("subdomains")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let mut findings = Vec::new();
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// --- A / AAAA records ---
match Self::resolve_addresses(domain).await {
Ok((ipv4, ipv6)) => {
dns_data.insert("a_records".to_string(), json!(ipv4));
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
}
Err(e) => {
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
}
}
// --- MX records ---
match Self::dig_query(domain, "MX").await {
Ok(mx) => {
dns_data.insert("mx_records".to_string(), json!(mx));
}
Err(e) => {
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
}
}
// --- NS records ---
let ns_records = match Self::dig_query(domain, "NS").await {
Ok(ns) => {
dns_data.insert("ns_records".to_string(), json!(ns));
ns
}
Err(e) => {
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
Vec::new()
}
};
// --- TXT records ---
match Self::dig_query(domain, "TXT").await {
Ok(txt) => {
dns_data.insert("txt_records".to_string(), json!(txt));
}
Err(e) => {
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
}
}
// --- CNAME records (for subdomains) ---
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
let mut domains_to_check = vec![domain.to_string()];
for sub in &subdomains {
domains_to_check.push(format!("{sub}.{domain}"));
}
for fqdn in &domains_to_check {
match Self::dig_query(fqdn, "CNAME").await {
Ok(cnames) if !cnames.is_empty() => {
// Check for dangling CNAME
for cname in &cnames {
let cname_clean = cname.trim_end_matches('.');
let check_addr = format!("{cname_clean}:443");
let is_dangling = lookup_host(&check_addr).await.is_err();
if is_dangling {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: fqdn.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"CNAME {fqdn} -> {cname} (target does not resolve)"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("Dangling CNAME on {fqdn}"),
format!(
"The subdomain {fqdn} has a CNAME record pointing to {cname} which does not resolve. \
This may allow subdomain takeover if an attacker can claim the target hostname."
),
Severity::High,
fqdn.clone(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling CNAME records or ensure the target hostname is \
properly configured and resolvable."
.to_string(),
);
findings.push(finding);
warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover");
}
}
cname_data.insert(fqdn.clone(), cnames);
}
_ => {}
}
}
if !cname_data.is_empty() {
dns_data.insert("cname_records".to_string(), json!(cname_data));
}
// --- CAA records ---
match Self::dig_query(domain, "CAA").await {
Ok(caa) => {
if caa.is_empty() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No CAA records found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("Missing CAA records for {domain}"),
format!(
"No CAA (Certificate Authority Authorization) records are set for {domain}. \
Without CAA records, any certificate authority can issue certificates for this domain."
),
Severity::Low,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add CAA DNS records to restrict which certificate authorities can issue \
certificates for your domain. Example: '0 issue \"letsencrypt.org\"'."
.to_string(),
);
findings.push(finding);
}
dns_data.insert("caa_records".to_string(), json!(caa));
}
Err(e) => {
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
}
}
// --- DNSSEC check ---
let dnssec_output = tokio::process::Command::new("dig")
.args(["+dnssec", "+short", "DNSKEY", domain])
.output()
.await;
match dnssec_output {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let has_dnssec = !stdout.trim().is_empty();
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
if !has_dnssec {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("DNSSEC not enabled for {domain}"),
format!(
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
can be spoofed, allowing man-in-the-middle attacks."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-350".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
with your DNS provider and domain registrar."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
}
}
// --- Check NS records for dangling ---
for ns in &ns_records {
let ns_clean = ns.trim_end_matches('.');
let check_addr = format!("{ns_clean}:53");
if lookup_host(&check_addr).await.is_err() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"NS record {ns} does not resolve"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("Dangling NS record for {domain}"),
format!(
"The NS record {ns} for {domain} does not resolve. \
This could allow domain takeover if an attacker can claim the nameserver hostname."
),
Severity::Critical,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling NS records or ensure the nameserver hostname is properly \
configured. Dangling NS records can lead to full domain takeover."
.to_string(),
);
findings.push(finding);
warn!(domain, ns, "Dangling NS record detected - potential domain takeover");
}
}
let count = findings.len();
info!(domain, findings = count, "DNS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} DNS configuration issues for {domain}.")
} else {
format!("No DNS configuration issues found for {domain}.")
},
findings,
data: json!(dns_data),
})
})
}
}

View File

@@ -0,0 +1,141 @@
pub mod api_fuzzer;
pub mod auth_bypass;
pub mod console_log_detector;
pub mod cookie_analyzer;
pub mod cors_checker;
pub mod csp_analyzer;
pub mod dmarc_checker;
pub mod dns_checker;
pub mod openapi_parser;
pub mod rate_limit_tester;
pub mod recon;
pub mod security_headers;
pub mod sql_injection;
pub mod ssrf;
pub mod tls_analyzer;
pub mod xss;
use std::collections::HashMap;
use compliance_core::traits::pentest_tool::PentestTool;
/// A definition describing a tool for LLM tool_use registration.
#[derive(Debug, Clone)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub input_schema: serde_json::Value,
}
/// Registry that holds all available pentest tools and provides
/// look-up by name.
pub struct ToolRegistry {
tools: HashMap<String, Box<dyn PentestTool>>,
}
impl ToolRegistry {
/// Create a new registry with all built-in tools pre-registered.
pub fn new() -> Self {
let http = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.timeout(std::time::Duration::from_secs(30))
.redirect(reqwest::redirect::Policy::limited(5))
.build()
.expect("failed to build HTTP client");
let mut tools: HashMap<String, Box<dyn PentestTool>> = HashMap::new();
// Agent-wrapping tools
let register = |tools: &mut HashMap<String, Box<dyn PentestTool>>,
tool: Box<dyn PentestTool>| {
tools.insert(tool.name().to_string(), tool);
};
register(
&mut tools,
Box::new(sql_injection::SqlInjectionTool::new(http.clone())),
);
register(&mut tools, Box::new(xss::XssTool::new(http.clone())));
register(
&mut tools,
Box::new(auth_bypass::AuthBypassTool::new(http.clone())),
);
register(&mut tools, Box::new(ssrf::SsrfTool::new(http.clone())));
register(
&mut tools,
Box::new(api_fuzzer::ApiFuzzerTool::new(http.clone())),
);
// New infrastructure / analysis tools
register(
&mut tools,
Box::new(dns_checker::DnsCheckerTool::new()),
);
register(
&mut tools,
Box::new(dmarc_checker::DmarcCheckerTool::new()),
);
register(
&mut tools,
Box::new(tls_analyzer::TlsAnalyzerTool::new(http.clone())),
);
register(
&mut tools,
Box::new(security_headers::SecurityHeadersTool::new(http.clone())),
);
register(
&mut tools,
Box::new(cookie_analyzer::CookieAnalyzerTool::new(http.clone())),
);
register(
&mut tools,
Box::new(csp_analyzer::CspAnalyzerTool::new(http.clone())),
);
register(
&mut tools,
Box::new(rate_limit_tester::RateLimitTesterTool::new(http.clone())),
);
register(
&mut tools,
Box::new(console_log_detector::ConsoleLogDetectorTool::new(
http.clone(),
)),
);
register(
&mut tools,
Box::new(cors_checker::CorsCheckerTool::new(http.clone())),
);
register(
&mut tools,
Box::new(openapi_parser::OpenApiParserTool::new(http.clone())),
);
register(
&mut tools,
Box::new(recon::ReconTool::new(http)),
);
Self { tools }
}
/// Look up a tool by name.
pub fn get(&self, name: &str) -> Option<&dyn PentestTool> {
self.tools.get(name).map(|b| b.as_ref())
}
/// Return definitions for every registered tool.
pub fn all_definitions(&self) -> Vec<ToolDefinition> {
self.tools
.values()
.map(|t| ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
input_schema: t.input_schema(),
})
.collect()
}
/// Return the names of all registered tools.
pub fn list_names(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
}
}

View File

@@ -0,0 +1,422 @@
use compliance_core::error::CoreError;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
/// Tool that discovers and parses OpenAPI/Swagger specification files.
///
/// Returns structured endpoint definitions for the LLM to use when planning
/// further tests. This tool produces data rather than security findings.
pub struct OpenApiParserTool {
http: reqwest::Client,
}
/// A parsed endpoint from an OpenAPI spec.
#[derive(Debug, Clone)]
struct ParsedEndpoint {
path: String,
method: String,
operation_id: Option<String>,
summary: Option<String>,
parameters: Vec<ParsedParameter>,
request_body_content_type: Option<String>,
response_codes: Vec<String>,
security: Vec<String>,
tags: Vec<String>,
}
/// A parsed parameter from an OpenAPI spec.
#[derive(Debug, Clone)]
struct ParsedParameter {
name: String,
location: String,
required: bool,
param_type: Option<String>,
description: Option<String>,
}
impl OpenApiParserTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Common paths where OpenAPI/Swagger specs are typically served.
fn common_spec_paths() -> Vec<&'static str> {
vec![
"/openapi.json",
"/openapi.yaml",
"/swagger.json",
"/swagger.yaml",
"/api-docs",
"/api-docs.json",
"/v2/api-docs",
"/v3/api-docs",
"/docs/openapi.json",
"/api/swagger.json",
"/api/openapi.json",
"/api/v1/openapi.json",
"/api/v2/openapi.json",
"/.well-known/openapi.json",
]
}
/// Try to fetch a spec from a URL and return the JSON value if successful.
async fn try_fetch_spec(
http: &reqwest::Client,
url: &str,
) -> Option<(String, serde_json::Value)> {
let resp = http.get(url).send().await.ok()?;
if !resp.status().is_success() {
return None;
}
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let body = resp.text().await.ok()?;
// Try JSON first
if let Ok(val) = serde_json::from_str::<serde_json::Value>(&body) {
// Verify it looks like an OpenAPI / Swagger spec
if val.get("openapi").is_some()
|| val.get("swagger").is_some()
|| val.get("paths").is_some()
{
return Some((url.to_string(), val));
}
}
// If content type suggests YAML, we can't easily parse without a YAML dep,
// so just report the URL as found
if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") {
// Return a minimal JSON indicating YAML was found
return Some((
url.to_string(),
json!({
"_note": "YAML spec detected but not parsed. Fetch and convert to JSON.",
"_raw_url": url,
}),
));
}
None
}
/// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints.
fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec<ParsedEndpoint> {
let mut endpoints = Vec::new();
// Determine base path
let base_path = if let Some(servers) = spec.get("servers").and_then(|v| v.as_array()) {
servers
.first()
.and_then(|s| s.get("url"))
.and_then(|u| u.as_str())
.unwrap_or("")
.to_string()
} else if let Some(bp) = spec.get("basePath").and_then(|v| v.as_str()) {
bp.to_string()
} else {
String::new()
};
let paths = match spec.get("paths").and_then(|v| v.as_object()) {
Some(p) => p,
None => return endpoints,
};
for (path, path_item) in paths {
let path_obj = match path_item.as_object() {
Some(o) => o,
None => continue,
};
// Path-level parameters
let path_params = path_obj
.get("parameters")
.and_then(|v| v.as_array())
.cloned()
.unwrap_or_default();
for method in &["get", "post", "put", "patch", "delete", "head", "options"] {
let operation = match path_obj.get(*method).and_then(|v| v.as_object()) {
Some(o) => o,
None => continue,
};
let operation_id = operation
.get("operationId")
.and_then(|v| v.as_str())
.map(String::from);
let summary = operation
.get("summary")
.and_then(|v| v.as_str())
.map(String::from);
let tags: Vec<String> = operation
.get("tags")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|t| t.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
// Merge path-level and operation-level parameters
let mut parameters = Vec::new();
let op_params = operation
.get("parameters")
.and_then(|v| v.as_array())
.cloned()
.unwrap_or_default();
for param_val in path_params.iter().chain(op_params.iter()) {
let name = param_val
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let location = param_val
.get("in")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string();
let required = param_val
.get("required")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Type from schema or direct type field
let param_type = param_val
.get("schema")
.and_then(|s| s.get("type"))
.or_else(|| param_val.get("type"))
.and_then(|v| v.as_str())
.map(String::from);
let description = param_val
.get("description")
.and_then(|v| v.as_str())
.map(String::from);
parameters.push(ParsedParameter {
name,
location,
required,
param_type,
description,
});
}
// Request body (OpenAPI 3.x)
let request_body_content_type = operation
.get("requestBody")
.and_then(|rb| rb.get("content"))
.and_then(|c| c.as_object())
.and_then(|obj| obj.keys().next().cloned());
// Response codes
let response_codes: Vec<String> = operation
.get("responses")
.and_then(|r| r.as_object())
.map(|obj| obj.keys().cloned().collect())
.unwrap_or_default();
// Security requirements
let security: Vec<String> = operation
.get("security")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|s| s.as_object())
.flat_map(|obj| obj.keys().cloned())
.collect()
})
.unwrap_or_default();
endpoints.push(ParsedEndpoint {
path: format!("{}{}", base_path, path),
method: method.to_uppercase(),
operation_id,
summary,
parameters,
request_body_content_type,
response_codes,
security,
tags,
});
}
}
endpoints
}
}
impl PentestTool for OpenApiParserTool {
fn name(&self) -> &str {
"openapi_parser"
}
fn description(&self) -> &str {
"Discovers and parses OpenAPI/Swagger specifications. Tries common spec paths and \
returns structured endpoint definitions including parameters, methods, and security \
requirements. Use this to discover all API endpoints before testing."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"base_url": {
"type": "string",
"description": "Base URL of the API to discover specs from"
},
"spec_url": {
"type": "string",
"description": "Optional explicit URL of the OpenAPI/Swagger spec file"
}
},
"required": ["base_url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let base_url = input
.get("base_url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?;
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
let base_url_trimmed = base_url.trim_end_matches('/');
// If an explicit spec URL is provided, try it first
let mut spec_result: Option<(String, serde_json::Value)> = None;
if let Some(spec_url) = explicit_spec_url {
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
}
// If no explicit URL or it failed, try common paths
if spec_result.is_none() {
for path in Self::common_spec_paths() {
let url = format!("{base_url_trimmed}{path}");
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
spec_result = Some(result);
break;
}
}
}
match spec_result {
Some((spec_url, spec)) => {
let spec_version = spec
.get("openapi")
.or_else(|| spec.get("swagger"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let api_title = spec
.get("info")
.and_then(|i| i.get("title"))
.and_then(|t| t.as_str())
.unwrap_or("Unknown API");
let api_version = spec
.get("info")
.and_then(|i| i.get("version"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
let endpoint_data: Vec<serde_json::Value> = endpoints
.iter()
.map(|ep| {
let params: Vec<serde_json::Value> = ep
.parameters
.iter()
.map(|p| {
json!({
"name": p.name,
"in": p.location,
"required": p.required,
"type": p.param_type,
"description": p.description,
})
})
.collect();
json!({
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"parameters": params,
"request_body_content_type": ep.request_body_content_type,
"response_codes": ep.response_codes,
"security": ep.security,
"tags": ep.tags,
})
})
.collect();
let endpoint_count = endpoints.len();
info!(
spec_url = %spec_url,
spec_version,
api_title,
endpoints = endpoint_count,
"OpenAPI spec parsed"
);
Ok(PentestToolResult {
summary: format!(
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
API: {api_title} v{api_version}. \
Parsed {endpoint_count} endpoints."
),
findings: Vec::new(), // This tool produces data, not findings
data: json!({
"spec_url": spec_url,
"spec_version": spec_version,
"api_title": api_title,
"api_version": api_version,
"endpoint_count": endpoint_count,
"endpoints": endpoint_data,
"security_schemes": spec.get("components")
.and_then(|c| c.get("securitySchemes"))
.or_else(|| spec.get("securityDefinitions")),
}),
})
}
None => {
info!(base_url, "No OpenAPI spec found");
Ok(PentestToolResult {
summary: format!(
"No OpenAPI/Swagger specification found for {base_url}. \
Tried {} common paths.",
Self::common_spec_paths().len()
),
findings: Vec::new(),
data: json!({
"spec_found": false,
"paths_tried": Self::common_spec_paths(),
}),
})
}
}
})
}
}

View File

@@ -0,0 +1,285 @@
use std::time::Instant;
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::{info, warn};
/// Tool that tests whether a target enforces rate limiting.
pub struct RateLimitTesterTool {
http: reqwest::Client,
}
impl RateLimitTesterTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
}
impl PentestTool for RateLimitTesterTool {
fn name(&self) -> &str {
"rate_limit_tester"
}
fn description(&self) -> &str {
"Tests whether an endpoint enforces rate limiting by sending rapid sequential requests. \
Checks for 429 responses and measures response time degradation."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL of the endpoint to test for rate limiting"
},
"method": {
"type": "string",
"description": "HTTP method to use",
"enum": ["GET", "POST", "PUT", "PATCH", "DELETE"],
"default": "GET"
},
"request_count": {
"type": "integer",
"description": "Number of rapid requests to send (default: 50)",
"default": 50,
"minimum": 10,
"maximum": 200
},
"body": {
"type": "string",
"description": "Optional request body (for POST/PUT/PATCH)"
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let method = input
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let request_count = input
.get("request_count")
.and_then(|v| v.as_u64())
.unwrap_or(50)
.min(200) as usize;
let body = input.get("body").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Respect the context rate limit if set
let max_requests = if context.rate_limit > 0 {
request_count.min(context.rate_limit as usize * 5)
} else {
request_count
};
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
let mut got_429 = false;
let mut rate_limit_at_request: Option<usize> = None;
for i in 0..max_requests {
let start = Instant::now();
let request = match method {
"POST" => {
let mut req = self.http.post(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PUT" => {
let mut req = self.http.put(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PATCH" => {
let mut req = self.http.patch(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"DELETE" => self.http.delete(url),
_ => self.http.get(url),
};
match request.send().await {
Ok(resp) => {
let elapsed = start.elapsed().as_millis();
let status = resp.status().as_u16();
status_codes.push(status);
response_times.push(elapsed);
if status == 429 && !got_429 {
got_429 = true;
rate_limit_at_request = Some(i + 1);
info!(url, request_num = i + 1, "Rate limit triggered (429)");
}
// Check for rate limit headers even on 200
if !got_429 {
let headers = resp.headers();
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|| headers.contains_key("x-ratelimit-remaining")
|| headers.contains_key("ratelimit-limit")
|| headers.contains_key("ratelimit-remaining")
|| headers.contains_key("retry-after");
if has_rate_headers && rate_limit_at_request.is_none() {
// Server has rate limit headers but hasn't blocked yet
}
}
}
Err(e) => {
let elapsed = start.elapsed().as_millis();
status_codes.push(0);
response_times.push(elapsed);
}
}
}
let mut findings = Vec::new();
let total_sent = status_codes.len();
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count();
// Calculate response time statistics
let avg_time = if !response_times.is_empty() {
response_times.iter().sum::<u128>() / response_times.len() as u128
} else {
0
};
let first_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[..half].iter().sum::<u128>() / half as u128
} else {
avg_time
};
let second_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
} else {
avg_time
};
// Significant time degradation suggests possible (weak) rate limiting
let time_degradation = if first_half_avg > 0 {
(second_half_avg as f64 / first_half_avg as f64) - 1.0
} else {
0.0
};
let rate_data = json!({
"total_requests_sent": total_sent,
"status_429_count": count_429,
"success_count": count_success,
"rate_limit_at_request": rate_limit_at_request,
"avg_response_time_ms": avg_time,
"first_half_avg_ms": first_half_avg,
"second_half_avg_ms": second_half_avg,
"time_degradation_pct": (time_degradation * 100.0).round(),
});
if !got_429 && count_success == total_sent {
// No rate limiting detected at all
let evidence = DastEvidence {
request_method: method.to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: body.map(String::from),
response_status: 200,
response_headers: None,
response_snippet: Some(format!(
"Sent {total_sent} rapid requests. All returned success (2xx). \
No 429 responses received. Avg response time: {avg_time}ms."
)),
screenshot_path: None,
payload: None,
response_time_ms: Some(avg_time as u64),
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::RateLimitAbsent,
format!("No rate limiting on {} {}", method, url),
format!(
"The endpoint {} {} does not enforce rate limiting. \
{total_sent} rapid requests were all accepted with no 429 responses \
or noticeable degradation. This makes the endpoint vulnerable to \
brute force attacks and abuse.",
method, url
),
Severity::Medium,
url.to_string(),
method.to_string(),
);
finding.cwe = Some("CWE-770".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
algorithms. Return 429 Too Many Requests with a Retry-After header when the \
limit is exceeded."
.to_string(),
);
findings.push(finding);
warn!(url, method, total_sent, "No rate limiting detected");
} else if got_429 {
info!(
url,
method,
rate_limit_at = ?rate_limit_at_request,
"Rate limiting is enforced"
);
}
let count = findings.len();
Ok(PentestToolResult {
summary: if got_429 {
format!(
"Rate limiting is enforced on {method} {url}. \
429 response received after {} requests.",
rate_limit_at_request.unwrap_or(0)
)
} else if count > 0 {
format!(
"No rate limiting detected on {method} {url} after {total_sent} requests."
)
} else {
format!("Rate limit testing complete for {method} {url}.")
},
findings,
data: rate_data,
})
})
}
}

View File

@@ -0,0 +1,125 @@
use std::collections::HashMap;
use compliance_core::error::CoreError;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
use crate::recon::ReconAgent;
/// PentestTool wrapper around the existing ReconAgent.
///
/// Performs HTTP header fingerprinting and technology detection.
/// Returns structured recon data for the LLM to use when planning attacks.
pub struct ReconTool {
http: reqwest::Client,
agent: ReconAgent,
}
impl ReconTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = ReconAgent::new(http.clone());
Self { http, agent }
}
}
impl PentestTool for ReconTool {
fn name(&self) -> &str {
"recon"
}
fn description(&self) -> &str {
"Performs reconnaissance on a target URL. Fingerprints HTTP headers, detects server \
technologies and frameworks. Returns structured data about the target's technology stack."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Base URL to perform reconnaissance on"
},
"additional_paths": {
"type": "array",
"description": "Optional additional paths to probe for technology fingerprinting",
"items": { "type": "string" }
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_paths: Vec<String> = input
.get("additional_paths")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let result = self.agent.scan(url).await?;
// Scan additional paths for more technology signals
let mut extra_technologies: Vec<String> = Vec::new();
let mut extra_headers: HashMap<String, String> = HashMap::new();
let base_url = url.trim_end_matches('/');
for path in &additional_paths {
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
if let Ok(resp) = self.http.get(&probe_url).send().await {
for (key, value) in resp.headers() {
let k = key.to_string().to_lowercase();
let v = value.to_str().unwrap_or("").to_string();
// Look for technology indicators
if k == "x-powered-by" || k == "server" || k == "x-generator" {
if !result.technologies.contains(&v) && !extra_technologies.contains(&v) {
extra_technologies.push(v.clone());
}
}
extra_headers.insert(format!("{probe_url} -> {k}"), v);
}
}
}
let mut all_technologies = result.technologies.clone();
all_technologies.extend(extra_technologies);
all_technologies.dedup();
let tech_count = all_technologies.len();
info!(url, technologies = tech_count, "Recon complete");
Ok(PentestToolResult {
summary: format!(
"Recon complete for {url}. Detected {} technologies. Server: {}.",
tech_count,
result.server.as_deref().unwrap_or("unknown")
),
findings: Vec::new(), // Recon produces data, not findings
data: json!({
"base_url": url,
"server": result.server,
"technologies": all_technologies,
"interesting_headers": result.interesting_headers,
"extra_headers": extra_headers,
"open_ports": result.open_ports,
}),
})
})
}
}

View File

@@ -0,0 +1,300 @@
use std::collections::HashMap;
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tracing::info;
/// Tool that checks for the presence and correctness of security headers.
pub struct SecurityHeadersTool {
http: reqwest::Client,
}
/// A security header we expect to be present and its metadata.
struct ExpectedHeader {
name: &'static str,
description: &'static str,
severity: Severity,
cwe: &'static str,
remediation: &'static str,
/// If present, the value must contain one of these substrings to be considered valid.
valid_values: Option<Vec<&'static str>>,
}
impl SecurityHeadersTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
fn expected_headers() -> Vec<ExpectedHeader> {
vec![
ExpectedHeader {
name: "strict-transport-security",
description: "HTTP Strict Transport Security (HSTS) forces browsers to use HTTPS",
severity: Severity::Medium,
cwe: "CWE-319",
remediation: "Add 'Strict-Transport-Security: max-age=31536000; includeSubDomains' header.",
valid_values: None,
},
ExpectedHeader {
name: "x-content-type-options",
description: "Prevents MIME type sniffing",
severity: Severity::Low,
cwe: "CWE-16",
remediation: "Add 'X-Content-Type-Options: nosniff' header.",
valid_values: Some(vec!["nosniff"]),
},
ExpectedHeader {
name: "x-frame-options",
description: "Prevents clickjacking by controlling iframe embedding",
severity: Severity::Medium,
cwe: "CWE-1021",
remediation: "Add 'X-Frame-Options: DENY' or 'X-Frame-Options: SAMEORIGIN' header.",
valid_values: Some(vec!["deny", "sameorigin"]),
},
ExpectedHeader {
name: "x-xss-protection",
description: "Enables browser XSS filtering (legacy but still recommended)",
severity: Severity::Low,
cwe: "CWE-79",
remediation: "Add 'X-XSS-Protection: 1; mode=block' header.",
valid_values: None,
},
ExpectedHeader {
name: "referrer-policy",
description: "Controls how much referrer information is shared",
severity: Severity::Low,
cwe: "CWE-200",
remediation: "Add 'Referrer-Policy: strict-origin-when-cross-origin' or 'no-referrer' header.",
valid_values: None,
},
ExpectedHeader {
name: "permissions-policy",
description: "Controls browser feature access (camera, microphone, geolocation, etc.)",
severity: Severity::Low,
cwe: "CWE-16",
remediation: "Add a Permissions-Policy header to restrict browser feature access. \
Example: 'Permissions-Policy: camera=(), microphone=(), geolocation=()'.",
valid_values: None,
},
]
}
}
impl PentestTool for SecurityHeadersTool {
fn name(&self) -> &str {
"security_headers"
}
fn description(&self) -> &str {
"Checks a URL for the presence and correctness of security headers: HSTS, \
X-Content-Type-Options, X-Frame-Options, X-XSS-Protection, Referrer-Policy, \
and Permissions-Policy."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "URL to check security headers for"
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string()))
.collect();
let mut findings = Vec::new();
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
for expected in Self::expected_headers() {
let header_value = response_headers.get(expected.name);
match header_value {
Some(value) => {
let mut is_valid = true;
if let Some(ref valid) = expected.valid_values {
let lower = value.to_lowercase();
is_valid = valid.iter().any(|v| lower.contains(v));
}
header_results.insert(
expected.name.to_string(),
json!({
"present": true,
"value": value,
"valid": is_valid,
}),
);
if !is_valid {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{}: {}", expected.name, value)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Invalid {} header value", expected.name),
format!(
"The {} header is present but has an invalid or weak value: '{}'. \
{}",
expected.name, value, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
None => {
header_results.insert(
expected.name.to_string(),
json!({
"present": false,
"value": null,
"valid": false,
}),
);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{} header is missing", expected.name)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Missing {} header", expected.name),
format!(
"The {} header is not present in the response. {}",
expected.name, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
}
// Also check for information disclosure headers
let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"];
for h in &disclosure_headers {
if let Some(value) = response_headers.get(*h) {
header_results.insert(
format!("{h}_disclosure"),
json!({ "present": true, "value": value }),
);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{h}: {value}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Information disclosure via {h} header"),
format!(
"The {h} header exposes server technology information: '{value}'. \
This helps attackers fingerprint the server and find known vulnerabilities."
),
Severity::Info,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-200".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove or suppress the {h} header in your server configuration."
));
findings.push(finding);
}
}
let count = findings.len();
info!(url, findings = count, "Security headers check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} security header issues for {url}.")
} else {
format!("All checked security headers are present and valid for {url}.")
},
findings,
data: json!(header_results),
})
})
}
}

View File

@@ -0,0 +1,138 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use crate::agents::injection::SqlInjectionAgent;
/// PentestTool wrapper around the existing SqlInjectionAgent.
pub struct SqlInjectionTool {
http: reqwest::Client,
agent: SqlInjectionAgent,
}
impl SqlInjectionTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SqlInjectionAgent::new(http.clone());
Self { http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string();
let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from);
let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from);
parameters.push(EndpointParameter {
name,
location,
param_type,
example_value,
});
}
}
endpoints.push(DiscoveredEndpoint {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
});
}
}
endpoints
}
}
impl PentestTool for SqlInjectionTool {
fn name(&self) -> &str {
"sql_injection_scanner"
}
fn description(&self) -> &str {
"Tests endpoints for SQL injection vulnerabilities using error-based, boolean-based, \
time-based, and union-based techniques. Provide endpoints with their parameters to test."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"endpoints": {
"type": "array",
"description": "Endpoints to test for SQL injection",
"items": {
"type": "object",
"properties": {
"url": { "type": "string", "description": "Full URL of the endpoint" },
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"location": { "type": "string", "enum": ["query", "body", "header", "path", "cookie"] },
"param_type": { "type": "string" },
"example_value": { "type": "string" }
},
"required": ["name"]
}
}
},
"required": ["url", "method", "parameters"]
}
},
"custom_payloads": {
"type": "array",
"description": "Optional additional SQL injection payloads to test",
"items": { "type": "string" }
}
},
"required": ["endpoints"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SQL injection vulnerabilities.")
} else {
"No SQL injection vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -0,0 +1,134 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use crate::agents::ssrf::SsrfAgent;
/// PentestTool wrapper around the existing SsrfAgent.
pub struct SsrfTool {
http: reqwest::Client,
agent: SsrfAgent,
}
impl SsrfTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SsrfAgent::new(http.clone());
Self { http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
});
}
}
endpoints.push(DiscoveredEndpoint {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
});
}
}
endpoints
}
}
impl PentestTool for SsrfTool {
fn name(&self) -> &str {
"ssrf_scanner"
}
fn description(&self) -> &str {
"Tests endpoints for Server-Side Request Forgery (SSRF) vulnerabilities. Checks if \
parameters accepting URLs can be exploited to access internal resources."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"endpoints": {
"type": "array",
"description": "Endpoints to test for SSRF (focus on those accepting URL parameters)",
"items": {
"type": "object",
"properties": {
"url": { "type": "string" },
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"location": { "type": "string" },
"param_type": { "type": "string" },
"example_value": { "type": "string" }
},
"required": ["name"]
}
}
},
"required": ["url", "method", "parameters"]
}
},
"custom_payloads": {
"type": "array",
"description": "Optional additional SSRF payloads (internal URLs to try)",
"items": { "type": "string" }
}
},
"required": ["endpoints"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SSRF vulnerabilities.")
} else {
"No SSRF vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -0,0 +1,442 @@
use compliance_core::error::CoreError;
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
use compliance_core::models::Severity;
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use tokio::net::TcpStream;
use tracing::{info, warn};
/// Tool that analyzes TLS configuration of a target.
///
/// Connects via TCP, performs a TLS handshake using `tokio-native-tls`,
/// and inspects the certificate and negotiated protocol. Also checks
/// for common TLS misconfigurations.
pub struct TlsAnalyzerTool {
http: reqwest::Client,
}
impl TlsAnalyzerTool {
pub fn new(http: reqwest::Client) -> Self {
Self { http }
}
/// Extract the hostname from a URL.
fn extract_host(url: &str) -> Option<String> {
url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
}
/// Extract port from a URL (defaults to 443 for https).
fn extract_port(url: &str) -> u16 {
url::Url::parse(url)
.ok()
.and_then(|u| u.port())
.unwrap_or(443)
}
/// Check if the server accepts a connection on a given port with a weak
/// TLS client hello. We test SSLv3 / old protocol support by attempting
/// connection with the system's native-tls which typically negotiates the
/// best available, then inspect what was negotiated.
async fn check_tls(
host: &str,
port: u16,
) -> Result<TlsInfo, CoreError> {
let addr = format!("{host}:{port}");
let tcp = TcpStream::connect(&addr)
.await
.map_err(|e| CoreError::Dast(format!("TCP connection to {addr} failed: {e}")))?;
let connector = native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true)
.build()
.map_err(|e| CoreError::Dast(format!("TLS connector build failed: {e}")))?;
let connector = tokio_native_tls::TlsConnector::from(connector);
let tls_stream = connector
.connect(host, tcp)
.await
.map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?;
let peer_cert = tls_stream.get_ref().peer_certificate()
.map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?;
let mut tls_info = TlsInfo {
protocol_version: String::new(),
cert_subject: String::new(),
cert_issuer: String::new(),
cert_not_before: String::new(),
cert_not_after: String::new(),
cert_expired: false,
cert_self_signed: false,
alpn_protocol: None,
san_names: Vec::new(),
};
if let Some(cert) = peer_cert {
let der = cert.to_der()
.map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?;
// native_tls doesn't give rich access, so we parse what we can
// from the DER-encoded certificate.
tls_info.cert_subject = "see DER certificate".to_string();
// Attempt to parse with basic DER inspection for dates
tls_info = Self::parse_cert_der(&der, tls_info);
}
Ok(tls_info)
}
/// Best-effort parse of DER-encoded X.509 certificate for dates and subject.
/// This is a simplified parser; in production you would use a proper x509 crate.
fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo {
// We rely on the native_tls debug output stored in cert_subject
// and just mark fields as "see certificate details"
if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") {
info.cert_self_signed = true;
}
info
}
}
struct TlsInfo {
protocol_version: String,
cert_subject: String,
cert_issuer: String,
cert_not_before: String,
cert_not_after: String,
cert_expired: bool,
cert_self_signed: bool,
alpn_protocol: Option<String>,
san_names: Vec<String>,
}
impl PentestTool for TlsAnalyzerTool {
fn name(&self) -> &str {
"tls_analyzer"
}
fn description(&self) -> &str {
"Analyzes TLS/SSL configuration of a target. Checks certificate validity, expiry, chain \
trust, and negotiated protocols. Reports TLS misconfigurations."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Target URL or hostname to analyze TLS configuration"
},
"port": {
"type": "integer",
"description": "Port to connect to (default: 443)",
"default": 443
},
"check_protocols": {
"type": "boolean",
"description": "Whether to test for old/weak protocol versions",
"default": true
}
},
"required": ["url"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let host = Self::extract_host(url)
.unwrap_or_else(|| url.to_string());
let port = input
.get("port")
.and_then(|v| v.as_u64())
.map(|p| p as u16)
.unwrap_or_else(|| Self::extract_port(url));
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut tls_data = json!({});
// First check: does the server even support HTTPS?
let https_url = if url.starts_with("https://") {
url.to_string()
} else if url.starts_with("http://") {
url.replace("http://", "https://")
} else {
format!("https://{url}")
};
// Check if HTTP redirects to HTTPS
let http_url = if url.starts_with("http://") {
url.to_string()
} else if url.starts_with("https://") {
url.replace("https://", "http://")
} else {
format!("http://{url}")
};
match self.http.get(&http_url).send().await {
Ok(resp) => {
let final_url = resp.url().to_string();
let redirects_to_https = final_url.starts_with("https://");
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
if !redirects_to_https {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: http_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(format!("Final URL: {final_url}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("HTTP does not redirect to HTTPS for {host}"),
format!(
"HTTP requests to {host} are not redirected to HTTPS. \
Users accessing the site via HTTP will have their traffic \
transmitted in cleartext."
),
Severity::Medium,
http_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Configure the web server to redirect all HTTP requests to HTTPS \
using a 301 redirect."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
tls_data["http_check_error"] = json!("Could not connect via HTTP");
}
}
// Perform TLS analysis
match Self::check_tls(&host, port).await {
Ok(tls_info) => {
tls_data["host"] = json!(host);
tls_data["port"] = json!(port);
tls_data["cert_subject"] = json!(tls_info.cert_subject);
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
tls_data["san_names"] = json!(tls_info.san_names);
if tls_info.cert_expired {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"Certificate expired. Not After: {}",
tls_info.cert_not_after
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Expired TLS certificate for {host}"),
format!(
"The TLS certificate for {host} has expired. \
Browsers will show security warnings to users."
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Renew the TLS certificate. Consider using automated certificate \
management with Let's Encrypt or a similar CA."
.to_string(),
);
findings.push(finding);
warn!(host, "Expired TLS certificate");
}
if tls_info.cert_self_signed {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("Self-signed certificate detected".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Self-signed TLS certificate for {host}"),
format!(
"The TLS certificate for {host} is self-signed and not issued by a \
trusted certificate authority. Browsers will show security warnings."
),
Severity::Medium,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Replace the self-signed certificate with one issued by a trusted \
certificate authority."
.to_string(),
);
findings.push(finding);
warn!(host, "Self-signed certificate");
}
}
Err(e) => {
tls_data["tls_error"] = json!(e.to_string());
// TLS handshake failure itself is a finding
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!("TLS error: {e}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("TLS handshake failure for {host}"),
format!(
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Ensure TLS is properly configured on the server. Check that the \
certificate is valid and the server supports modern TLS versions."
.to_string(),
);
findings.push(finding);
}
}
// Check strict transport security via an HTTPS request
match self.http.get(&https_url).send().await {
Ok(resp) => {
let hsts = resp.headers().get("strict-transport-security");
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
if hsts.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: https_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(
"Strict-Transport-Security header not present".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Missing HSTS header for {host}"),
format!(
"The server at {host} does not send a Strict-Transport-Security header. \
Without HSTS, browsers may allow HTTP downgrade attacks."
),
Severity::Medium,
https_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the Strict-Transport-Security header with an appropriate max-age. \
Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
}
}
let count = findings.len();
info!(host = %host, findings = count, "TLS analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} TLS configuration issues for {host}.")
} else {
format!("TLS configuration looks good for {host}.")
},
findings,
data: tls_data,
})
})
}
}

View File

@@ -0,0 +1,134 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
use crate::agents::xss::XssAgent;
/// PentestTool wrapper around the existing XssAgent.
pub struct XssTool {
http: reqwest::Client,
agent: XssAgent,
}
impl XssTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = XssAgent::new(http.clone());
Self { http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
});
}
}
endpoints.push(DiscoveredEndpoint {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
});
}
}
endpoints
}
}
impl PentestTool for XssTool {
fn name(&self) -> &str {
"xss_scanner"
}
fn description(&self) -> &str {
"Tests endpoints for Cross-Site Scripting (XSS) vulnerabilities including reflected, \
stored, and DOM-based XSS. Provide endpoints with parameters to test."
}
fn input_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"endpoints": {
"type": "array",
"description": "Endpoints to test for XSS",
"items": {
"type": "object",
"properties": {
"url": { "type": "string" },
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
"parameters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": { "type": "string" },
"location": { "type": "string", "enum": ["query", "body", "header", "path", "cookie"] },
"param_type": { "type": "string" },
"example_value": { "type": "string" }
},
"required": ["name"]
}
}
},
"required": ["url", "method", "parameters"]
}
},
"custom_payloads": {
"type": "array",
"description": "Optional additional XSS payloads to test",
"items": { "type": "string" }
}
},
"required": ["endpoints"]
})
}
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} XSS vulnerabilities.")
} else {
"No XSS vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}