Compare commits
7 Commits
76260acc76
...
cc6ae7717c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc6ae7717c | ||
|
|
0428cba2b8 | ||
|
|
30301a12b5 | ||
|
|
af98e3e070 | ||
|
|
85ceef7e1f | ||
|
|
c0f9ba467c | ||
|
|
71d8741e10 |
89
Cargo.lock
generated
89
Cargo.lock
generated
@@ -680,12 +680,14 @@ dependencies = [
|
||||
"chrono",
|
||||
"compliance-core",
|
||||
"mongodb",
|
||||
"native-tls",
|
||||
"reqwest",
|
||||
"scraper",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
@@ -1994,6 +1996,21 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
|
||||
dependencies = [
|
||||
"foreign-types-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foreign-types-shared"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.2.2"
|
||||
@@ -2824,15 +2841,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.14.0"
|
||||
@@ -3399,6 +3407,23 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b"
|
||||
|
||||
[[package]]
|
||||
name = "native-tls"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"openssl",
|
||||
"openssl-probe 0.2.1",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
"security-framework-sys",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndk"
|
||||
version = "0.9.0"
|
||||
@@ -3578,6 +3603,32 @@ version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cfg-if",
|
||||
"foreign-types",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"openssl-macros",
|
||||
"openssl-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-macros"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.6"
|
||||
@@ -3949,7 +4000,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"itertools 0.12.1",
|
||||
"itertools",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
@@ -4899,7 +4950,7 @@ version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1c97747dbf44bb1ca44a561ece23508e99cb592e862f22222dcf42f51d1e451"
|
||||
dependencies = [
|
||||
"heck 0.4.1",
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
@@ -5116,7 +5167,7 @@ dependencies = [
|
||||
"fs4",
|
||||
"htmlescape",
|
||||
"hyperloglogplus",
|
||||
"itertools 0.14.0",
|
||||
"itertools",
|
||||
"levenshtein_automata",
|
||||
"log",
|
||||
"lru 0.12.5",
|
||||
@@ -5164,7 +5215,7 @@ checksum = "8b628488ae936c83e92b5c4056833054ca56f76c0e616aee8339e24ac89119cd"
|
||||
dependencies = [
|
||||
"downcast-rs",
|
||||
"fastdivide",
|
||||
"itertools 0.14.0",
|
||||
"itertools",
|
||||
"serde",
|
||||
"tantivy-bitpacker",
|
||||
"tantivy-common",
|
||||
@@ -5214,7 +5265,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8292095d1a8a2c2b36380ec455f910ab52dde516af36321af332c93f20ab7d5"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"itertools 0.14.0",
|
||||
"itertools",
|
||||
"tantivy-bitpacker",
|
||||
"tantivy-common",
|
||||
"tantivy-fst",
|
||||
@@ -5428,6 +5479,16 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-native-tls"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||
dependencies = [
|
||||
"native-tls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
pub mod chat;
|
||||
pub mod dast;
|
||||
pub mod graph;
|
||||
pub mod pentest;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -1108,7 +1109,7 @@ pub async fn list_scan_runs(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
|
||||
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
|
||||
mut cursor: mongodb::Cursor<T>,
|
||||
) -> Vec<T> {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
772
compliance-agent/src/api/handlers/pentest.rs
Normal file
772
compliance-agent/src/api/handlers/pentest.rs
Normal file
@@ -0,0 +1,772 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use futures_util::stream;
|
||||
use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::pentest::*;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::pentest::PentestOrchestrator;
|
||||
|
||||
use super::{collect_cursor_async, ApiResponse, PaginationParams};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct CreateSessionRequest {
|
||||
pub target_id: String,
|
||||
#[serde(default = "default_strategy")]
|
||||
pub strategy: String,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
fn default_strategy() -> String {
|
||||
"comprehensive".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SendMessageRequest {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// POST /api/v1/pentest/sessions — Create a new pentest session and start the orchestrator
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn create_session(
|
||||
Extension(agent): AgentExt,
|
||||
Json(req): Json<CreateSessionRequest>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&req.target_id).map_err(|_| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Invalid target_id format".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Look up the target
|
||||
let target = agent
|
||||
.db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Target not found".to_string()))?;
|
||||
|
||||
// Parse strategy
|
||||
let strategy = match req.strategy.as_str() {
|
||||
"quick" => PentestStrategy::Quick,
|
||||
"targeted" => PentestStrategy::Targeted,
|
||||
"aggressive" => PentestStrategy::Aggressive,
|
||||
"stealth" => PentestStrategy::Stealth,
|
||||
_ => PentestStrategy::Comprehensive,
|
||||
};
|
||||
|
||||
// Create session
|
||||
let mut session = PentestSession::new(req.target_id.clone(), strategy);
|
||||
session.repo_id = target.repo_id.clone();
|
||||
|
||||
let insert_result = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.insert_one(&session)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to create session: {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Set the generated ID back on the session so the orchestrator has it
|
||||
session.id = insert_result.inserted_id.as_object_id();
|
||||
|
||||
let initial_message = req.message.unwrap_or_else(|| {
|
||||
format!(
|
||||
"Begin a {} penetration test against {} ({}). \
|
||||
Identify vulnerabilities and provide evidence for each finding.",
|
||||
session.strategy, target.name, target.base_url,
|
||||
)
|
||||
});
|
||||
|
||||
// Spawn the orchestrator on a background task
|
||||
let llm = agent.llm.clone();
|
||||
let db = agent.db.clone();
|
||||
let session_clone = session.clone();
|
||||
let target_clone = target.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db);
|
||||
orchestrator
|
||||
.run_session_guarded(&session_clone, &target_clone, &initial_message)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: session,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions — List pentest sessions
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_sessions(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<PentestSession>>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.pentest_sessions()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let sessions = match db
|
||||
.pentest_sessions()
|
||||
.find(doc! {})
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch pentest sessions: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: sessions,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id — Get a single pentest session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_session(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: session,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/v1/pentest/sessions/:id/chat — Send a user message and trigger next orchestrator iteration
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn send_message(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<SendMessageRequest>,
|
||||
) -> Result<Json<ApiResponse<PentestMessage>>, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
|
||||
// Verify session exists and is running
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
|
||||
|
||||
if session.status != PentestStatus::Running && session.status != PentestStatus::Paused {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Session is {}, cannot send messages", session.status),
|
||||
));
|
||||
}
|
||||
|
||||
// Look up the target
|
||||
let target_oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Invalid target_id in session".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": target_oid })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"Target for session not found".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Store user message
|
||||
let session_id = id.clone();
|
||||
let user_msg = PentestMessage::user(session_id.clone(), req.message.clone());
|
||||
let _ = agent.db.pentest_messages().insert_one(&user_msg).await;
|
||||
|
||||
let response_msg = user_msg.clone();
|
||||
|
||||
// Spawn orchestrator to continue the session
|
||||
let llm = agent.llm.clone();
|
||||
let db = agent.db.clone();
|
||||
let message = req.message.clone();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = PentestOrchestrator::new(llm, db);
|
||||
orchestrator
|
||||
.run_session_guarded(&session, &target, &message)
|
||||
.await;
|
||||
});
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: response_msg,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
|
||||
///
|
||||
/// Returns recent messages as SSE events (polling approach).
|
||||
/// True real-time streaming with broadcast channels will be added in a future iteration.
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn session_stream(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
|
||||
{
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// Verify session exists
|
||||
let _session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Fetch recent messages for this session
|
||||
let messages: Vec<PentestMessage> = match agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch recent attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Build SSE events from stored data
|
||||
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
|
||||
|
||||
for msg in &messages {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"created_at": msg.created_at.to_rfc3339(),
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("message").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
for node in &nodes {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "tool_execution",
|
||||
"node_id": node.node_id,
|
||||
"tool_name": node.tool_name,
|
||||
"status": node.status,
|
||||
"findings_produced": node.findings_produced,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("tool").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add session status event
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
if let Some(s) = session {
|
||||
let status_data = serde_json::json!({
|
||||
"type": "status",
|
||||
"status": s.status,
|
||||
"findings_count": s.findings_count,
|
||||
"tool_invocations": s.tool_invocations,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&status_data) {
|
||||
events.push(Ok(Event::default().event("status").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Sse::new(stream::iter(events)))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/attack-chain — Get attack chain nodes for a session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_attack_chain(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
|
||||
// Verify the session ID is valid
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let nodes = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch attack chain nodes: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let total = nodes.len() as u64;
|
||||
Ok(Json(ApiResponse {
|
||||
data: nodes,
|
||||
total: Some(total),
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/messages — Get messages for a session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_messages(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.count_documents(doc! { "session_id": &id })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let messages = match agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch pentest messages: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: messages,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn pentest_stats(
|
||||
Extension(agent): AgentExt,
|
||||
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
|
||||
let running_sessions = db
|
||||
.pentest_sessions()
|
||||
.count_documents(doc! { "status": "running" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Count DAST findings from pentest sessions
|
||||
let total_vulnerabilities = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Aggregate tool invocations from all sessions
|
||||
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
|
||||
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
|
||||
let tool_success_rate = if total_tool_invocations == 0 {
|
||||
100.0
|
||||
} else {
|
||||
(total_successes as f64 / total_tool_invocations as f64) * 100.0
|
||||
};
|
||||
|
||||
// Severity distribution from pentest-related DAST findings
|
||||
let critical = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let high = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let medium = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let low = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let info = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: PentestStats {
|
||||
running_sessions,
|
||||
total_vulnerabilities,
|
||||
total_tool_invocations,
|
||||
tool_success_rate,
|
||||
severity_distribution: SeverityDistribution {
|
||||
critical,
|
||||
high,
|
||||
medium,
|
||||
low,
|
||||
info,
|
||||
},
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_session_findings(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
.db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": &id })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let findings = match agent
|
||||
.db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": -1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch pentest session findings: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: findings,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ExportParams {
|
||||
#[serde(default = "default_export_format")]
|
||||
pub format: String,
|
||||
}
|
||||
|
||||
fn default_export_format() -> String {
|
||||
"json".to_string()
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/export?format=json|markdown — Export a session report
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn export_session_report(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<ExportParams>,
|
||||
) -> Result<axum::response::Response, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
|
||||
// Fetch session
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
|
||||
|
||||
// Fetch messages
|
||||
let messages: Vec<PentestMessage> = match agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch DAST findings for this session
|
||||
let findings: Vec<DastFinding> = match agent
|
||||
.db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": -1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Compute severity counts
|
||||
let critical = findings.iter().filter(|f| f.severity.to_string() == "critical").count();
|
||||
let high = findings.iter().filter(|f| f.severity.to_string() == "high").count();
|
||||
let medium = findings.iter().filter(|f| f.severity.to_string() == "medium").count();
|
||||
let low = findings.iter().filter(|f| f.severity.to_string() == "low").count();
|
||||
let info = findings.iter().filter(|f| f.severity.to_string() == "info").count();
|
||||
|
||||
match params.format.as_str() {
|
||||
"markdown" => {
|
||||
let mut md = String::new();
|
||||
md.push_str("# Penetration Test Report\n\n");
|
||||
|
||||
// Executive summary
|
||||
md.push_str("## Executive Summary\n\n");
|
||||
md.push_str(&format!("| Field | Value |\n"));
|
||||
md.push_str("| --- | --- |\n");
|
||||
md.push_str(&format!("| **Session ID** | {} |\n", id));
|
||||
md.push_str(&format!("| **Status** | {} |\n", session.status));
|
||||
md.push_str(&format!("| **Strategy** | {} |\n", session.strategy));
|
||||
md.push_str(&format!("| **Target ID** | {} |\n", session.target_id));
|
||||
md.push_str(&format!(
|
||||
"| **Started** | {} |\n",
|
||||
session.started_at.to_rfc3339()
|
||||
));
|
||||
if let Some(ref completed) = session.completed_at {
|
||||
md.push_str(&format!(
|
||||
"| **Completed** | {} |\n",
|
||||
completed.to_rfc3339()
|
||||
));
|
||||
}
|
||||
md.push_str(&format!(
|
||||
"| **Tool Invocations** | {} |\n",
|
||||
session.tool_invocations
|
||||
));
|
||||
md.push_str(&format!(
|
||||
"| **Success Rate** | {:.1}% |\n",
|
||||
session.success_rate()
|
||||
));
|
||||
md.push('\n');
|
||||
|
||||
// Findings by severity
|
||||
md.push_str("## Findings Summary\n\n");
|
||||
md.push_str(&format!(
|
||||
"| Severity | Count |\n| --- | --- |\n| Critical | {} |\n| High | {} |\n| Medium | {} |\n| Low | {} |\n| Info | {} |\n| **Total** | **{}** |\n\n",
|
||||
critical, high, medium, low, info, findings.len()
|
||||
));
|
||||
|
||||
// Findings table
|
||||
if !findings.is_empty() {
|
||||
md.push_str("## Findings Detail\n\n");
|
||||
md.push_str("| # | Severity | Title | Endpoint | Exploitable |\n");
|
||||
md.push_str("| --- | --- | --- | --- | --- |\n");
|
||||
for (i, f) in findings.iter().enumerate() {
|
||||
md.push_str(&format!(
|
||||
"| {} | {} | {} | {} {} | {} |\n",
|
||||
i + 1,
|
||||
f.severity,
|
||||
f.title,
|
||||
f.method,
|
||||
f.endpoint,
|
||||
if f.exploitable { "Yes" } else { "No" },
|
||||
));
|
||||
}
|
||||
md.push('\n');
|
||||
}
|
||||
|
||||
// Attack chain timeline
|
||||
if !nodes.is_empty() {
|
||||
md.push_str("## Attack Chain Timeline\n\n");
|
||||
md.push_str("| # | Tool | Status | Findings | Reasoning |\n");
|
||||
md.push_str("| --- | --- | --- | --- | --- |\n");
|
||||
for (i, node) in nodes.iter().enumerate() {
|
||||
let reasoning_short = if node.llm_reasoning.len() > 80 {
|
||||
format!("{}...", &node.llm_reasoning[..80])
|
||||
} else {
|
||||
node.llm_reasoning.clone()
|
||||
};
|
||||
md.push_str(&format!(
|
||||
"| {} | {} | {} | {} | {} |\n",
|
||||
i + 1,
|
||||
node.tool_name,
|
||||
format!("{:?}", node.status).to_lowercase(),
|
||||
node.findings_produced.len(),
|
||||
reasoning_short,
|
||||
));
|
||||
}
|
||||
md.push('\n');
|
||||
}
|
||||
|
||||
// Statistics
|
||||
md.push_str("## Statistics\n\n");
|
||||
md.push_str(&format!("- **Total Findings:** {}\n", findings.len()));
|
||||
md.push_str(&format!("- **Exploitable Findings:** {}\n", session.exploitable_count));
|
||||
md.push_str(&format!("- **Attack Chain Steps:** {}\n", nodes.len()));
|
||||
md.push_str(&format!("- **Messages Exchanged:** {}\n", messages.len()));
|
||||
md.push_str(&format!("- **Tool Invocations:** {}\n", session.tool_invocations));
|
||||
md.push_str(&format!("- **Tool Success Rate:** {:.1}%\n", session.success_rate()));
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
[
|
||||
(axum::http::header::CONTENT_TYPE, "text/markdown; charset=utf-8"),
|
||||
],
|
||||
md,
|
||||
)
|
||||
.into_response())
|
||||
}
|
||||
_ => {
|
||||
// JSON format
|
||||
let report = serde_json::json!({
|
||||
"session": {
|
||||
"id": id,
|
||||
"target_id": session.target_id,
|
||||
"repo_id": session.repo_id,
|
||||
"status": session.status,
|
||||
"strategy": session.strategy,
|
||||
"started_at": session.started_at.to_rfc3339(),
|
||||
"completed_at": session.completed_at.map(|d| d.to_rfc3339()),
|
||||
"tool_invocations": session.tool_invocations,
|
||||
"tool_successes": session.tool_successes,
|
||||
"success_rate": session.success_rate(),
|
||||
"findings_count": session.findings_count,
|
||||
"exploitable_count": session.exploitable_count,
|
||||
},
|
||||
"findings": findings,
|
||||
"attack_chain": nodes,
|
||||
"messages": messages,
|
||||
"summary": {
|
||||
"total_findings": findings.len(),
|
||||
"severity_distribution": {
|
||||
"critical": critical,
|
||||
"high": high,
|
||||
"medium": medium,
|
||||
"low": low,
|
||||
"info": info,
|
||||
},
|
||||
"attack_chain_steps": nodes.len(),
|
||||
"messages_exchanged": messages.len(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(Json(report).into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -99,6 +99,40 @@ pub fn build_router() -> Router {
|
||||
"/api/v1/chat/{repo_id}/status",
|
||||
get(handlers::chat::embedding_status),
|
||||
)
|
||||
// Pentest API endpoints
|
||||
.route(
|
||||
"/api/v1/pentest/sessions",
|
||||
get(handlers::pentest::list_sessions).post(handlers::pentest::create_session),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}",
|
||||
get(handlers::pentest::get_session),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/chat",
|
||||
post(handlers::pentest::send_message),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/stream",
|
||||
get(handlers::pentest::session_stream),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/attack-chain",
|
||||
get(handlers::pentest::get_attack_chain),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/messages",
|
||||
get(handlers::pentest::get_messages),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/findings",
|
||||
get(handlers::pentest::get_session_findings),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/pentest/sessions/{id}/export",
|
||||
get(handlers::pentest::export_session_report),
|
||||
)
|
||||
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
|
||||
// Webhook endpoints (proxied through dashboard)
|
||||
.route(
|
||||
"/webhook/github/{repo_id}",
|
||||
|
||||
@@ -166,6 +166,38 @@ impl Database {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// pentest_sessions: compound (target_id, started_at DESC)
|
||||
self.pentest_sessions()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "target_id": 1, "started_at": -1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// pentest_sessions: status index
|
||||
self.pentest_sessions()
|
||||
.create_index(IndexModel::builder().keys(doc! { "status": 1 }).build())
|
||||
.await?;
|
||||
|
||||
// attack_chain_nodes: compound (session_id, node_id)
|
||||
self.attack_chain_nodes()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "session_id": 1, "node_id": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// pentest_messages: compound (session_id, created_at)
|
||||
self.pentest_messages()
|
||||
.create_index(
|
||||
IndexModel::builder()
|
||||
.keys(doc! { "session_id": 1, "created_at": 1 })
|
||||
.build(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tracing::info!("Database indexes ensured");
|
||||
Ok(())
|
||||
}
|
||||
@@ -235,6 +267,19 @@ impl Database {
|
||||
self.inner.collection("embedding_builds")
|
||||
}
|
||||
|
||||
// Pentest collections
|
||||
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
|
||||
self.inner.collection("pentest_sessions")
|
||||
}
|
||||
|
||||
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
|
||||
self.inner.collection("attack_chain_nodes")
|
||||
}
|
||||
|
||||
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
|
||||
self.inner.collection("pentest_messages")
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn raw_collection(&self, name: &str) -> Collection<mongodb::bson::Document> {
|
||||
self.inner.collection(name)
|
||||
|
||||
@@ -12,10 +12,16 @@ pub struct LlmClient {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
// ── Request types ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallRequest>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -26,8 +32,25 @@ struct ChatCompletionRequest {
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ToolDefinitionPayload>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolDefinitionPayload {
|
||||
r#type: String,
|
||||
function: ToolFunctionPayload,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolFunctionPayload {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Response types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<ChatChoice>,
|
||||
@@ -40,29 +63,84 @@ struct ChatChoice {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatResponseMessage {
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<ToolCallResponse>>,
|
||||
}
|
||||
|
||||
/// Request body for the embeddings API
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallResponse {
|
||||
id: String,
|
||||
function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallFunction {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
// ── Public types for tool calling ──────────────────────────────
|
||||
|
||||
/// Definition of a tool that the LLM can invoke
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call request from the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequest {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallRequestFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequestFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Response from the LLM — either content or tool calls
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LlmResponse {
|
||||
Content(String),
|
||||
ToolCalls(Vec<LlmToolCall>),
|
||||
}
|
||||
|
||||
// ── Embedding types ────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
/// Response from the embeddings API
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
/// A single embedding result
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
// ── Implementation ─────────────────────────────────────────────
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
@@ -83,98 +161,142 @@ impl LlmClient {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
fn chat_url(&self) -> String {
|
||||
format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
|
||||
fn auth_header(&self) -> Option<String> {
|
||||
let key = self.api_key.expose_secret();
|
||||
if key.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(format!("Bearer {key}"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple chat: system + user prompt → text response
|
||||
pub async fn chat(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
user_prompt: &str,
|
||||
temperature: Option<f64>,
|
||||
) -> Result<String, AgentError> {
|
||||
let url = format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
);
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(user_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
let request_body = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: system_prompt.to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: user_prompt.to_string(),
|
||||
},
|
||||
],
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens: Some(4096),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
let key = self.api_key.expose_secret();
|
||||
if !key.is_empty() {
|
||||
req = req.header("Authorization", format!("Bearer {key}"));
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("LiteLLM request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"LiteLLM returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: ChatCompletionResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
|
||||
|
||||
body.choices
|
||||
.first()
|
||||
.map(|c| c.message.content.clone())
|
||||
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
|
||||
self.send_chat_request(&request_body).await.map(|resp| {
|
||||
match resp {
|
||||
LlmResponse::Content(c) => c,
|
||||
LlmResponse::ToolCalls(_) => String::new(), // shouldn't happen without tools
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Chat with a list of (role, content) messages → text response
|
||||
#[allow(dead_code)]
|
||||
pub async fn chat_with_messages(
|
||||
&self,
|
||||
messages: Vec<(String, String)>,
|
||||
temperature: Option<f64>,
|
||||
) -> Result<String, AgentError> {
|
||||
let url = format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
);
|
||||
let messages = messages
|
||||
.into_iter()
|
||||
.map(|(role, content)| ChatMessage {
|
||||
role,
|
||||
content: Some(content),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request_body = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages: messages
|
||||
.into_iter()
|
||||
.map(|(role, content)| ChatMessage { role, content })
|
||||
.collect(),
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens: Some(4096),
|
||||
tools: None,
|
||||
};
|
||||
|
||||
self.send_chat_request(&request_body).await.map(|resp| {
|
||||
match resp {
|
||||
LlmResponse::Content(c) => c,
|
||||
LlmResponse::ToolCalls(_) => String::new(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Chat with tool definitions — returns either content or tool calls.
|
||||
/// Use this for the AI pentest orchestrator loop.
|
||||
pub async fn chat_with_tools(
|
||||
&self,
|
||||
messages: Vec<ChatMessage>,
|
||||
tools: &[ToolDefinition],
|
||||
temperature: Option<f64>,
|
||||
max_tokens: Option<u32>,
|
||||
) -> Result<LlmResponse, AgentError> {
|
||||
let tool_payloads: Vec<ToolDefinitionPayload> = tools
|
||||
.iter()
|
||||
.map(|t| ToolDefinitionPayload {
|
||||
r#type: "function".to_string(),
|
||||
function: ToolFunctionPayload {
|
||||
name: t.name.clone(),
|
||||
description: t.description.clone(),
|
||||
parameters: t.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request_body = ChatCompletionRequest {
|
||||
model: self.model.clone(),
|
||||
messages,
|
||||
temperature,
|
||||
max_tokens: Some(max_tokens.unwrap_or(8192)),
|
||||
tools: if tool_payloads.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_payloads)
|
||||
},
|
||||
};
|
||||
|
||||
self.send_chat_request(&request_body).await
|
||||
}
|
||||
|
||||
/// Internal method to send a chat completion request and parse the response
|
||||
async fn send_chat_request(
|
||||
&self,
|
||||
request_body: &ChatCompletionRequest,
|
||||
) -> Result<LlmResponse, AgentError> {
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.post(&self.chat_url())
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
.json(request_body);
|
||||
|
||||
let key = self.api_key.expose_secret();
|
||||
if !key.is_empty() {
|
||||
req = req.header("Authorization", format!("Bearer {key}"));
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
@@ -195,10 +317,37 @@ impl LlmClient {
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse LiteLLM response: {e}")))?;
|
||||
|
||||
body.choices
|
||||
let choice = body
|
||||
.choices
|
||||
.first()
|
||||
.map(|c| c.message.content.clone())
|
||||
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))
|
||||
.ok_or_else(|| AgentError::Other("Empty response from LiteLLM".to_string()))?;
|
||||
|
||||
// Check for tool calls first
|
||||
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||
if !tool_calls.is_empty() {
|
||||
let calls: Vec<LlmToolCall> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let arguments = serde_json::from_str(&tc.function.arguments)
|
||||
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
|
||||
LlmToolCall {
|
||||
id: tc.id.clone(),
|
||||
name: tc.function.name.clone(),
|
||||
arguments,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
return Ok(LlmResponse::ToolCalls(calls));
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise return content
|
||||
let content = choice
|
||||
.message
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default();
|
||||
Ok(LlmResponse::Content(content))
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
@@ -216,9 +365,8 @@ impl LlmClient {
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
let key = self.api_key.expose_secret();
|
||||
if !key.is_empty() {
|
||||
req = req.header("Authorization", format!("Bearer {key}"));
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
@@ -239,7 +387,6 @@ impl LlmClient {
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
// Sort by index to maintain input order
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ mod config;
|
||||
mod database;
|
||||
mod error;
|
||||
mod llm;
|
||||
mod pentest;
|
||||
mod pipeline;
|
||||
mod rag;
|
||||
mod scheduler;
|
||||
|
||||
3
compliance-agent/src/pentest/mod.rs
Normal file
3
compliance-agent/src/pentest/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod orchestrator;
|
||||
|
||||
pub use orchestrator::PentestOrchestrator;
|
||||
736
compliance-agent/src/pentest/orchestrator.rs
Normal file
736
compliance-agent/src/pentest/orchestrator.rs
Normal file
@@ -0,0 +1,736 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use mongodb::bson::doc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use compliance_core::models::dast::DastTarget;
|
||||
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::models::sbom::SbomEntry;
|
||||
use compliance_core::traits::pentest_tool::PentestToolContext;
|
||||
use compliance_dast::ToolRegistry;
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::llm::client::{
|
||||
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
|
||||
};
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// Maximum duration for a single pentest session before timeout
|
||||
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
|
||||
|
||||
pub struct PentestOrchestrator {
|
||||
tool_registry: ToolRegistry,
|
||||
llm: Arc<LlmClient>,
|
||||
db: Database,
|
||||
event_tx: broadcast::Sender<PentestEvent>,
|
||||
}
|
||||
|
||||
impl PentestOrchestrator {
|
||||
pub fn new(llm: Arc<LlmClient>, db: Database) -> Self {
|
||||
let (event_tx, _) = broadcast::channel(256);
|
||||
Self {
|
||||
tool_registry: ToolRegistry::new(),
|
||||
llm,
|
||||
db,
|
||||
event_tx,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
|
||||
self.event_tx.subscribe()
|
||||
}
|
||||
|
||||
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
|
||||
self.event_tx.clone()
|
||||
}
|
||||
|
||||
/// Run a pentest session with timeout and automatic failure marking on errors.
|
||||
pub async fn run_session_guarded(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
initial_message: &str,
|
||||
) {
|
||||
let session_id = session.id;
|
||||
|
||||
match tokio::time::timeout(
|
||||
SESSION_TIMEOUT,
|
||||
self.run_session(session, target, initial_message),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(())) => {
|
||||
tracing::info!(?session_id, "Pentest session completed successfully");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::error!(?session_id, error = %e, "Pentest session failed");
|
||||
self.mark_session_failed(session_id, &format!("Error: {e}"))
|
||||
.await;
|
||||
let _ = self.event_tx.send(PentestEvent::Error {
|
||||
message: format!("Session failed: {e}"),
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(?session_id, "Pentest session timed out after 30 minutes");
|
||||
self.mark_session_failed(session_id, "Session timed out after 30 minutes")
|
||||
.await;
|
||||
let _ = self.event_tx.send(PentestEvent::Error {
|
||||
message: "Session timed out after 30 minutes".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn mark_session_failed(
|
||||
&self,
|
||||
session_id: Option<mongodb::bson::oid::ObjectId>,
|
||||
reason: &str,
|
||||
) {
|
||||
if let Some(sid) = session_id {
|
||||
let _ = self
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": sid },
|
||||
doc! { "$set": {
|
||||
"status": "failed",
|
||||
"completed_at": mongodb::bson::DateTime::now(),
|
||||
"error_message": reason,
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_session(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
initial_message: &str,
|
||||
) -> Result<(), crate::error::AgentError> {
|
||||
let session_id = session
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Gather code-awareness context from linked repo
|
||||
let (sast_findings, sbom_entries, code_context) =
|
||||
self.gather_repo_context(target).await;
|
||||
|
||||
// Build system prompt with code context
|
||||
let system_prompt = self
|
||||
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
|
||||
.await;
|
||||
|
||||
// Build tool definitions for LLM
|
||||
let tool_defs: Vec<ToolDefinition> = self
|
||||
.tool_registry
|
||||
.all_definitions()
|
||||
.into_iter()
|
||||
.map(|td| ToolDefinition {
|
||||
name: td.name,
|
||||
description: td.description,
|
||||
parameters: td.input_schema,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initialize messages
|
||||
let mut messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system_prompt),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(initial_message.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Store user message
|
||||
let user_msg = PentestMessage::user(session_id.clone(), initial_message.to_string());
|
||||
let _ = self.db.pentest_messages().insert_one(&user_msg).await;
|
||||
|
||||
// Build tool context with real data
|
||||
let tool_context = PentestToolContext {
|
||||
target: target.clone(),
|
||||
session_id: session_id.clone(),
|
||||
sast_findings,
|
||||
sbom_entries,
|
||||
code_context,
|
||||
rate_limit: target.rate_limit,
|
||||
allow_destructive: target.allow_destructive,
|
||||
};
|
||||
|
||||
let max_iterations = 50;
|
||||
let mut total_findings = 0u32;
|
||||
let mut total_tool_calls = 0u32;
|
||||
let mut total_successes = 0u32;
|
||||
let mut prev_node_ids: Vec<String> = Vec::new();
|
||||
|
||||
for _iteration in 0..max_iterations {
|
||||
let response = self
|
||||
.llm
|
||||
.chat_with_tools(messages.clone(), &tool_defs, Some(0.2), Some(8192))
|
||||
.await?;
|
||||
|
||||
match response {
|
||||
LlmResponse::Content(content) => {
|
||||
let msg =
|
||||
PentestMessage::assistant(session_id.clone(), content.clone());
|
||||
let _ = self.db.pentest_messages().insert_one(&msg).await;
|
||||
let _ = self.event_tx.send(PentestEvent::Message {
|
||||
content: content.clone(),
|
||||
});
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: Some(content.clone()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
let done_indicators = [
|
||||
"pentest complete",
|
||||
"testing complete",
|
||||
"scan complete",
|
||||
"analysis complete",
|
||||
"finished",
|
||||
"that concludes",
|
||||
];
|
||||
let content_lower = content.to_lowercase();
|
||||
if done_indicators
|
||||
.iter()
|
||||
.any(|ind| content_lower.contains(ind))
|
||||
{
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
LlmResponse::ToolCalls(tool_calls) => {
|
||||
let tc_requests: Vec<ToolCallRequest> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| ToolCallRequest {
|
||||
id: tc.id.clone(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: tc.name.clone(),
|
||||
arguments: serde_json::to_string(&tc.arguments)
|
||||
.unwrap_or_default(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(tc_requests),
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
let mut current_batch_node_ids: Vec<String> = Vec::new();
|
||||
|
||||
for tc in &tool_calls {
|
||||
total_tool_calls += 1;
|
||||
let node_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let mut node = AttackChainNode::new(
|
||||
session_id.clone(),
|
||||
node_id.clone(),
|
||||
tc.name.clone(),
|
||||
tc.arguments.clone(),
|
||||
String::new(),
|
||||
);
|
||||
// Link to previous iteration's nodes
|
||||
node.parent_node_ids = prev_node_ids.clone();
|
||||
node.status = AttackNodeStatus::Running;
|
||||
node.started_at = Some(chrono::Utc::now());
|
||||
let _ = self.db.attack_chain_nodes().insert_one(&node).await;
|
||||
current_batch_node_ids.push(node_id.clone());
|
||||
|
||||
let _ = self.event_tx.send(PentestEvent::ToolStart {
|
||||
node_id: node_id.clone(),
|
||||
tool_name: tc.name.clone(),
|
||||
input: tc.arguments.clone(),
|
||||
});
|
||||
|
||||
let result = if let Some(tool) = self.tool_registry.get(&tc.name) {
|
||||
match tool.execute(tc.arguments.clone(), &tool_context).await {
|
||||
Ok(result) => {
|
||||
total_successes += 1;
|
||||
let findings_count = result.findings.len() as u32;
|
||||
total_findings += findings_count;
|
||||
|
||||
for mut finding in result.findings {
|
||||
finding.scan_run_id = session_id.clone();
|
||||
finding.session_id = Some(session_id.clone());
|
||||
let _ =
|
||||
self.db.dast_findings().insert_one(&finding).await;
|
||||
let _ =
|
||||
self.event_tx.send(PentestEvent::Finding {
|
||||
finding_id: finding
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default(),
|
||||
title: finding.title.clone(),
|
||||
severity: finding.severity.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let _ = self.event_tx.send(PentestEvent::ToolComplete {
|
||||
node_id: node_id.clone(),
|
||||
summary: result.summary.clone(),
|
||||
findings_count,
|
||||
});
|
||||
|
||||
let _ = self
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.update_one(
|
||||
doc! {
|
||||
"session_id": &session_id,
|
||||
"node_id": &node_id,
|
||||
},
|
||||
doc! { "$set": {
|
||||
"status": "completed",
|
||||
"tool_output": mongodb::bson::to_bson(&result.data)
|
||||
.unwrap_or(mongodb::bson::Bson::Null),
|
||||
"completed_at": mongodb::bson::DateTime::now(),
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
|
||||
serde_json::json!({
|
||||
"summary": result.summary,
|
||||
"findings_count": findings_count,
|
||||
"data": result.data,
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = self
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.update_one(
|
||||
doc! {
|
||||
"session_id": &session_id,
|
||||
"node_id": &node_id,
|
||||
},
|
||||
doc! { "$set": {
|
||||
"status": "failed",
|
||||
"completed_at": mongodb::bson::DateTime::now(),
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
format!("Tool execution failed: {e}")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
format!("Unknown tool: {}", tc.name)
|
||||
};
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(result),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tc.id.clone()),
|
||||
});
|
||||
}
|
||||
|
||||
// Advance parent links so next iteration's nodes connect to this batch
|
||||
prev_node_ids = current_batch_node_ids;
|
||||
|
||||
if let Some(sid) = session.id {
|
||||
let _ = self
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": sid },
|
||||
doc! { "$set": {
|
||||
"tool_invocations": total_tool_calls as i64,
|
||||
"tool_successes": total_successes as i64,
|
||||
"findings_count": total_findings as i64,
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sid) = session.id {
|
||||
let _ = self
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.update_one(
|
||||
doc! { "_id": sid },
|
||||
doc! { "$set": {
|
||||
"status": "completed",
|
||||
"completed_at": mongodb::bson::DateTime::now(),
|
||||
"tool_invocations": total_tool_calls as i64,
|
||||
"tool_successes": total_successes as i64,
|
||||
"findings_count": total_findings as i64,
|
||||
}},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let _ = self.event_tx.send(PentestEvent::Complete {
|
||||
summary: format!(
|
||||
"Pentest complete. {} findings from {} tool invocations.",
|
||||
total_findings, total_tool_calls
|
||||
),
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Code-Awareness: Gather context from linked repo ─────────
|
||||
|
||||
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
|
||||
/// for the repo linked to this DAST target.
|
||||
async fn gather_repo_context(
|
||||
&self,
|
||||
target: &DastTarget,
|
||||
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
|
||||
let Some(repo_id) = &target.repo_id else {
|
||||
return (Vec::new(), Vec::new(), Vec::new());
|
||||
};
|
||||
|
||||
let sast_findings = self.fetch_sast_findings(repo_id).await;
|
||||
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
|
||||
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
|
||||
|
||||
tracing::info!(
|
||||
repo_id,
|
||||
sast_findings = sast_findings.len(),
|
||||
vulnerable_deps = sbom_entries.len(),
|
||||
code_hints = code_context.len(),
|
||||
"Gathered code-awareness context for pentest"
|
||||
);
|
||||
|
||||
(sast_findings, sbom_entries, code_context)
|
||||
}
|
||||
|
||||
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
|
||||
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
|
||||
let cursor = self
|
||||
.db
|
||||
.findings()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"status": { "$in": ["open", "triaged"] },
|
||||
})
|
||||
.sort(doc! { "severity": -1 })
|
||||
.limit(100)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(f)) = c.next().await {
|
||||
results.push(f);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch SBOM entries that have known vulnerabilities
|
||||
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
|
||||
let cursor = self
|
||||
.db
|
||||
.sbom_entries()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"known_vulnerabilities": { "$exists": true, "$ne": [] },
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(e)) = c.next().await {
|
||||
results.push(e);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build CodeContextHint objects from the code knowledge graph.
|
||||
/// Maps entry points to their source files and links SAST findings.
|
||||
async fn fetch_code_context(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
sast_findings: &[Finding],
|
||||
) -> Vec<CodeContextHint> {
|
||||
// Get entry point nodes from the code graph
|
||||
let cursor = self
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"is_entry_point": true,
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
let nodes = match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(n)) = c.next().await {
|
||||
results.push(n);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
// Build hints by matching graph nodes to SAST findings by file path
|
||||
nodes
|
||||
.into_iter()
|
||||
.map(|node| {
|
||||
// Find SAST findings in the same file
|
||||
let linked_vulns: Vec<String> = sast_findings
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.file_path.as_deref() == Some(&node.file_path)
|
||||
})
|
||||
.map(|f| {
|
||||
format!(
|
||||
"[{}] {}: {} (line {})",
|
||||
f.severity,
|
||||
f.scanner,
|
||||
f.title,
|
||||
f.line_number.unwrap_or(0)
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
CodeContextHint {
|
||||
endpoint_pattern: node.qualified_name.clone(),
|
||||
handler_function: node.name.clone(),
|
||||
file_path: node.file_path.clone(),
|
||||
code_snippet: String::new(), // Could fetch from embeddings
|
||||
known_vulnerabilities: linked_vulns,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ── System Prompt Builder ───────────────────────────────────
|
||||
|
||||
async fn build_system_prompt(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
sast_findings: &[Finding],
|
||||
sbom_entries: &[SbomEntry],
|
||||
code_context: &[CodeContextHint],
|
||||
) -> String {
|
||||
let tool_names = self.tool_registry.list_names().join(", ");
|
||||
let strategy_guidance = match session.strategy {
|
||||
PentestStrategy::Quick => {
|
||||
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
|
||||
}
|
||||
PentestStrategy::Comprehensive => {
|
||||
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
|
||||
}
|
||||
PentestStrategy::Targeted => {
|
||||
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
|
||||
}
|
||||
PentestStrategy::Aggressive => {
|
||||
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
|
||||
}
|
||||
PentestStrategy::Stealth => {
|
||||
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
|
||||
}
|
||||
};
|
||||
|
||||
// Build SAST findings section
|
||||
let sast_section = if sast_findings.is_empty() {
|
||||
String::from("No SAST findings available for this target.")
|
||||
} else {
|
||||
let critical = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::Critical)
|
||||
.count();
|
||||
let high = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::High)
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} open findings ({} critical, {} high):\n",
|
||||
sast_findings.len(),
|
||||
critical,
|
||||
high
|
||||
);
|
||||
|
||||
// List the most important findings (critical/high first, up to 20)
|
||||
for f in sast_findings.iter().take(20) {
|
||||
let file_info = f
|
||||
.file_path
|
||||
.as_ref()
|
||||
.map(|p| {
|
||||
format!(
|
||||
" in {}:{}",
|
||||
p,
|
||||
f.line_number.unwrap_or(0)
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let status_note = match f.status {
|
||||
FindingStatus::Triaged => " [TRIAGED]",
|
||||
_ => "",
|
||||
};
|
||||
section.push_str(&format!(
|
||||
"- [{sev}] {title}{file}{status}\n",
|
||||
sev = f.severity,
|
||||
title = f.title,
|
||||
file = file_info,
|
||||
status = status_note,
|
||||
));
|
||||
if let Some(cwe) = &f.cwe {
|
||||
section.push_str(&format!(" CWE: {cwe}\n"));
|
||||
}
|
||||
}
|
||||
if sast_findings.len() > 20 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more findings\n",
|
||||
sast_findings.len() - 20
|
||||
));
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
// Build SBOM/CVE section
|
||||
let sbom_section = if sbom_entries.is_empty() {
|
||||
String::from("No vulnerable dependencies identified.")
|
||||
} else {
|
||||
let mut section = format!(
|
||||
"{} dependencies with known vulnerabilities:\n",
|
||||
sbom_entries.len()
|
||||
);
|
||||
for entry in sbom_entries.iter().take(15) {
|
||||
let cve_ids: Vec<&str> = entry
|
||||
.known_vulnerabilities
|
||||
.iter()
|
||||
.map(|v| v.id.as_str())
|
||||
.collect();
|
||||
section.push_str(&format!(
|
||||
"- {} {} ({}): {}\n",
|
||||
entry.name,
|
||||
entry.version,
|
||||
entry.package_manager,
|
||||
cve_ids.join(", ")
|
||||
));
|
||||
}
|
||||
if sbom_entries.len() > 15 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more vulnerable dependencies\n",
|
||||
sbom_entries.len() - 15
|
||||
));
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
// Build code context section
|
||||
let code_section = if code_context.is_empty() {
|
||||
String::from("No code knowledge graph available for this target.")
|
||||
} else {
|
||||
let with_vulns = code_context
|
||||
.iter()
|
||||
.filter(|c| !c.known_vulnerabilities.is_empty())
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} entry points identified ({} with linked SAST findings):\n",
|
||||
code_context.len(),
|
||||
with_vulns
|
||||
);
|
||||
|
||||
for hint in code_context.iter().take(20) {
|
||||
section.push_str(&format!(
|
||||
"- {} ({})\n",
|
||||
hint.endpoint_pattern, hint.file_path
|
||||
));
|
||||
for vuln in &hint.known_vulnerabilities {
|
||||
section.push_str(&format!(" SAST: {vuln}\n"));
|
||||
}
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"You are an expert penetration tester conducting an authorized security assessment.
|
||||
|
||||
## Target
|
||||
- **Name**: {target_name}
|
||||
- **URL**: {base_url}
|
||||
- **Type**: {target_type}
|
||||
- **Rate Limit**: {rate_limit} req/s
|
||||
- **Destructive Tests Allowed**: {allow_destructive}
|
||||
- **Linked Repository**: {repo_linked}
|
||||
|
||||
## Strategy
|
||||
{strategy_guidance}
|
||||
|
||||
## SAST Findings (Static Analysis)
|
||||
{sast_section}
|
||||
|
||||
## Vulnerable Dependencies (SBOM)
|
||||
{sbom_section}
|
||||
|
||||
## Code Entry Points (Knowledge Graph)
|
||||
{code_section}
|
||||
|
||||
## Available Tools
|
||||
{tool_names}
|
||||
|
||||
## Instructions
|
||||
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
|
||||
2. Run the OpenAPI parser to discover API endpoints from specs.
|
||||
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
|
||||
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
|
||||
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
|
||||
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
|
||||
7. Test rate limiting on critical endpoints (login, API).
|
||||
8. Check for console.log leakage in frontend JavaScript.
|
||||
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
|
||||
10. When testing is complete, provide a structured summary with severity and remediation.
|
||||
11. Always explain your reasoning before invoking each tool.
|
||||
12. When done, say "Testing complete" followed by a final summary.
|
||||
|
||||
## Important
|
||||
- This is an authorized penetration test. All testing is permitted within the target scope.
|
||||
- Respect the rate limit of {rate_limit} requests per second.
|
||||
- Only use destructive tests if explicitly allowed ({allow_destructive}).
|
||||
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
|
||||
- Use SBOM data to understand what technologies and versions the target runs.
|
||||
"#,
|
||||
target_name = target.name,
|
||||
base_url = target.base_url,
|
||||
target_type = target.target_type,
|
||||
rate_limit = target.rate_limit,
|
||||
allow_destructive = target.allow_destructive,
|
||||
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -176,6 +176,16 @@ pub enum DastVulnType {
|
||||
InformationDisclosure,
|
||||
SecurityMisconfiguration,
|
||||
BrokenAuth,
|
||||
DnsMisconfiguration,
|
||||
EmailSecurity,
|
||||
TlsMisconfiguration,
|
||||
CookieSecurity,
|
||||
CspIssue,
|
||||
CorsMisconfiguration,
|
||||
RateLimitAbsent,
|
||||
ConsoleLogLeakage,
|
||||
SecurityHeaderMissing,
|
||||
KnownCveExploit,
|
||||
Other,
|
||||
}
|
||||
|
||||
@@ -192,6 +202,16 @@ impl std::fmt::Display for DastVulnType {
|
||||
Self::InformationDisclosure => write!(f, "information_disclosure"),
|
||||
Self::SecurityMisconfiguration => write!(f, "security_misconfiguration"),
|
||||
Self::BrokenAuth => write!(f, "broken_auth"),
|
||||
Self::DnsMisconfiguration => write!(f, "dns_misconfiguration"),
|
||||
Self::EmailSecurity => write!(f, "email_security"),
|
||||
Self::TlsMisconfiguration => write!(f, "tls_misconfiguration"),
|
||||
Self::CookieSecurity => write!(f, "cookie_security"),
|
||||
Self::CspIssue => write!(f, "csp_issue"),
|
||||
Self::CorsMisconfiguration => write!(f, "cors_misconfiguration"),
|
||||
Self::RateLimitAbsent => write!(f, "rate_limit_absent"),
|
||||
Self::ConsoleLogLeakage => write!(f, "console_log_leakage"),
|
||||
Self::SecurityHeaderMissing => write!(f, "security_header_missing"),
|
||||
Self::KnownCveExploit => write!(f, "known_cve_exploit"),
|
||||
Self::Other => write!(f, "other"),
|
||||
}
|
||||
}
|
||||
@@ -244,6 +264,8 @@ pub struct DastFinding {
|
||||
pub remediation: Option<String>,
|
||||
/// Linked SAST finding ID (if correlated)
|
||||
pub linked_sast_finding_id: Option<String>,
|
||||
/// Pentest session that produced this finding (if AI-driven)
|
||||
pub session_id: Option<String>,
|
||||
#[serde(with = "super::serde_helpers::bson_datetime")]
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
@@ -276,6 +298,7 @@ impl DastFinding {
|
||||
evidence: Vec::new(),
|
||||
remediation: None,
|
||||
linked_sast_finding_id: None,
|
||||
session_id: None,
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod finding;
|
||||
pub mod graph;
|
||||
pub mod issue;
|
||||
pub mod mcp;
|
||||
pub mod pentest;
|
||||
pub mod repository;
|
||||
pub mod sbom;
|
||||
pub mod scan;
|
||||
@@ -26,6 +27,11 @@ pub use graph::{
|
||||
};
|
||||
pub use issue::{IssueStatus, TrackerIssue, TrackerType};
|
||||
pub use mcp::{McpServerConfig, McpServerStatus, McpTransport};
|
||||
pub use pentest::{
|
||||
AttackChainNode, AttackNodeStatus, CodeContextHint, PentestEvent, PentestMessage,
|
||||
PentestSession, PentestStats, PentestStatus, PentestStrategy, SeverityDistribution,
|
||||
ToolCallRecord,
|
||||
};
|
||||
pub use repository::{ScanTrigger, TrackedRepository};
|
||||
pub use sbom::{SbomEntry, VulnRef};
|
||||
pub use scan::{ScanPhase, ScanRun, ScanRunStatus, ScanType};
|
||||
|
||||
294
compliance-core/src/models/pentest.rs
Normal file
294
compliance-core/src/models/pentest.rs
Normal file
@@ -0,0 +1,294 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Status of a pentest session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PentestStatus {
|
||||
Running,
|
||||
Paused,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PentestStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Paused => write!(f, "paused"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy for the AI pentest orchestrator
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PentestStrategy {
|
||||
/// Quick scan focusing on common vulnerabilities
|
||||
Quick,
|
||||
/// Standard comprehensive scan
|
||||
Comprehensive,
|
||||
/// Focus on specific vulnerability types guided by SAST/SBOM
|
||||
Targeted,
|
||||
/// Aggressive testing with more payloads and deeper exploitation
|
||||
Aggressive,
|
||||
/// Stealth mode with slower rate and fewer noisy payloads
|
||||
Stealth,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PentestStrategy {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Quick => write!(f, "quick"),
|
||||
Self::Comprehensive => write!(f, "comprehensive"),
|
||||
Self::Targeted => write!(f, "targeted"),
|
||||
Self::Aggressive => write!(f, "aggressive"),
|
||||
Self::Stealth => write!(f, "stealth"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A pentest session initiated via the chat interface
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PentestSession {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<bson::oid::ObjectId>,
|
||||
pub target_id: String,
|
||||
/// Linked repository for code-aware testing
|
||||
pub repo_id: Option<String>,
|
||||
pub status: PentestStatus,
|
||||
pub strategy: PentestStrategy,
|
||||
pub created_by: Option<String>,
|
||||
/// Total number of tool invocations in this session
|
||||
pub tool_invocations: u32,
|
||||
/// Total successful tool invocations
|
||||
pub tool_successes: u32,
|
||||
/// Number of findings discovered
|
||||
pub findings_count: u32,
|
||||
/// Number of confirmed exploitable findings
|
||||
pub exploitable_count: u32,
|
||||
#[serde(with = "super::serde_helpers::bson_datetime")]
|
||||
pub started_at: DateTime<Utc>,
|
||||
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl PentestSession {
|
||||
pub fn new(target_id: String, strategy: PentestStrategy) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
target_id,
|
||||
repo_id: None,
|
||||
status: PentestStatus::Running,
|
||||
strategy,
|
||||
created_by: None,
|
||||
tool_invocations: 0,
|
||||
tool_successes: 0,
|
||||
findings_count: 0,
|
||||
exploitable_count: 0,
|
||||
started_at: Utc::now(),
|
||||
completed_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn success_rate(&self) -> f64 {
|
||||
if self.tool_invocations == 0 {
|
||||
return 100.0;
|
||||
}
|
||||
(self.tool_successes as f64 / self.tool_invocations as f64) * 100.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Status of a node in the attack chain
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AttackNodeStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Skipped,
|
||||
}
|
||||
|
||||
/// A single step in the LLM-driven attack chain DAG
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AttackChainNode {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<bson::oid::ObjectId>,
|
||||
pub session_id: String,
|
||||
/// Unique ID for DAG references
|
||||
pub node_id: String,
|
||||
/// Parent node IDs (multiple for merge nodes)
|
||||
pub parent_node_ids: Vec<String>,
|
||||
/// Tool that was invoked
|
||||
pub tool_name: String,
|
||||
/// Input parameters passed to the tool
|
||||
pub tool_input: serde_json::Value,
|
||||
/// Output from the tool
|
||||
pub tool_output: Option<serde_json::Value>,
|
||||
pub status: AttackNodeStatus,
|
||||
/// LLM's reasoning for choosing this action
|
||||
pub llm_reasoning: String,
|
||||
/// IDs of DastFindings produced by this step
|
||||
pub findings_produced: Vec<String>,
|
||||
/// Risk score (0-100) assigned by the LLM
|
||||
pub risk_score: Option<u8>,
|
||||
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
#[serde(default, with = "super::serde_helpers::opt_bson_datetime")]
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl AttackChainNode {
|
||||
pub fn new(
|
||||
session_id: String,
|
||||
node_id: String,
|
||||
tool_name: String,
|
||||
tool_input: serde_json::Value,
|
||||
llm_reasoning: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
session_id,
|
||||
node_id,
|
||||
parent_node_ids: Vec::new(),
|
||||
tool_name,
|
||||
tool_input,
|
||||
tool_output: None,
|
||||
status: AttackNodeStatus::Pending,
|
||||
llm_reasoning,
|
||||
findings_produced: Vec::new(),
|
||||
risk_score: None,
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat message within a pentest session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PentestMessage {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<bson::oid::ObjectId>,
|
||||
pub session_id: String,
|
||||
/// "user", "assistant", "tool_result", "system"
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
/// Tool calls made by the assistant in this message
|
||||
pub tool_calls: Option<Vec<ToolCallRecord>>,
|
||||
/// Link to the attack chain node (for tool_result messages)
|
||||
pub attack_node_id: Option<String>,
|
||||
#[serde(with = "super::serde_helpers::bson_datetime")]
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl PentestMessage {
|
||||
pub fn user(session_id: String, content: String) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
session_id,
|
||||
role: "user".to_string(),
|
||||
content,
|
||||
tool_calls: None,
|
||||
attack_node_id: None,
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(session_id: String, content: String) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
session_id,
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_calls: None,
|
||||
attack_node_id: None,
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_result(session_id: String, content: String, node_id: String) -> Self {
|
||||
Self {
|
||||
id: None,
|
||||
session_id,
|
||||
role: "tool_result".to_string(),
|
||||
content,
|
||||
tool_calls: None,
|
||||
attack_node_id: Some(node_id),
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record of a tool call made by the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRecord {
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
pub result: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// SSE event types for real-time pentest streaming
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum PentestEvent {
|
||||
/// LLM is thinking/reasoning
|
||||
Thinking { reasoning: String },
|
||||
/// A tool execution has started
|
||||
ToolStart {
|
||||
node_id: String,
|
||||
tool_name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
/// A tool execution completed
|
||||
ToolComplete {
|
||||
node_id: String,
|
||||
summary: String,
|
||||
findings_count: u32,
|
||||
},
|
||||
/// A new finding was discovered
|
||||
Finding { finding_id: String, title: String, severity: String },
|
||||
/// Assistant message (streaming text)
|
||||
Message { content: String },
|
||||
/// Session completed
|
||||
Complete { summary: String },
|
||||
/// Error occurred
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// Aggregated stats for the pentest dashboard
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PentestStats {
|
||||
pub running_sessions: u32,
|
||||
pub total_vulnerabilities: u32,
|
||||
pub total_tool_invocations: u32,
|
||||
pub tool_success_rate: f64,
|
||||
pub severity_distribution: SeverityDistribution,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SeverityDistribution {
|
||||
pub critical: u32,
|
||||
pub high: u32,
|
||||
pub medium: u32,
|
||||
pub low: u32,
|
||||
pub info: u32,
|
||||
}
|
||||
|
||||
/// Code context hint linking a discovered endpoint to source code
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CodeContextHint {
|
||||
/// HTTP route pattern (e.g., "GET /api/users/:id")
|
||||
pub endpoint_pattern: String,
|
||||
/// Handler function name
|
||||
pub handler_function: String,
|
||||
/// Source file path
|
||||
pub file_path: String,
|
||||
/// Relevant code snippet
|
||||
pub code_snippet: String,
|
||||
/// SAST findings associated with this code
|
||||
pub known_vulnerabilities: Vec<String>,
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
pub mod dast_agent;
|
||||
pub mod graph_builder;
|
||||
pub mod issue_tracker;
|
||||
pub mod pentest_tool;
|
||||
pub mod scanner;
|
||||
|
||||
pub use dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
pub use graph_builder::{LanguageParser, ParseOutput};
|
||||
pub use issue_tracker::IssueTracker;
|
||||
pub use pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
pub use scanner::{ScanOutput, Scanner};
|
||||
|
||||
63
compliance-core/src/traits/pentest_tool.rs
Normal file
63
compliance-core/src/traits/pentest_tool.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::error::CoreError;
|
||||
use crate::models::dast::{DastFinding, DastTarget};
|
||||
use crate::models::finding::Finding;
|
||||
use crate::models::pentest::CodeContextHint;
|
||||
use crate::models::sbom::SbomEntry;
|
||||
|
||||
/// Context passed to pentest tools during execution.
|
||||
///
|
||||
/// The HTTP client is not included here because `compliance-core` does not
|
||||
/// depend on `reqwest`. Tools that need HTTP should hold their own client
|
||||
/// or receive one via the `compliance-dast` orchestrator.
|
||||
pub struct PentestToolContext {
|
||||
/// The DAST target being tested
|
||||
pub target: DastTarget,
|
||||
/// Session ID for this pentest run
|
||||
pub session_id: String,
|
||||
/// SAST findings for the linked repo (if any)
|
||||
pub sast_findings: Vec<Finding>,
|
||||
/// SBOM entries with known CVEs (if any)
|
||||
pub sbom_entries: Vec<SbomEntry>,
|
||||
/// Code knowledge graph hints mapping endpoints to source code
|
||||
pub code_context: Vec<CodeContextHint>,
|
||||
/// Rate limit (requests per second)
|
||||
pub rate_limit: u32,
|
||||
/// Whether destructive operations are allowed
|
||||
pub allow_destructive: bool,
|
||||
}
|
||||
|
||||
/// Result from a pentest tool execution
|
||||
pub struct PentestToolResult {
|
||||
/// Human-readable summary of what the tool found
|
||||
pub summary: String,
|
||||
/// DAST findings produced by this tool
|
||||
pub findings: Vec<DastFinding>,
|
||||
/// Tool-specific structured output data
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool that the LLM pentest orchestrator can invoke.
|
||||
///
|
||||
/// Each tool represents a specific security testing capability
|
||||
/// (e.g., SQL injection scanner, DNS checker, TLS analyzer).
|
||||
/// Uses boxed futures for dyn-compatibility.
|
||||
pub trait PentestTool: Send + Sync {
|
||||
/// Tool name for LLM tool_use (e.g., "sql_injection_scanner")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Human-readable description for the LLM system prompt
|
||||
fn description(&self) -> &str;
|
||||
|
||||
/// JSON Schema for the tool's input parameters
|
||||
fn input_schema(&self) -> serde_json::Value;
|
||||
|
||||
/// Execute the tool with the given input
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> Pin<Box<dyn Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>>;
|
||||
}
|
||||
234
compliance-dashboard/assets/attack-chain-viz.js
Normal file
234
compliance-dashboard/assets/attack-chain-viz.js
Normal file
@@ -0,0 +1,234 @@
|
||||
// ═══════════════════════════════════════════════════════════════
|
||||
// Attack Chain DAG Visualization — vis-network wrapper
|
||||
// Obsidian Control theme
|
||||
// ═══════════════════════════════════════════════════════════════
|
||||
|
||||
(function () {
|
||||
"use strict";
|
||||
|
||||
// Status color palette matching Obsidian Control
|
||||
var STATUS_COLORS = {
|
||||
completed: { bg: "#16a34a", border: "#12873c", font: "#060a13" },
|
||||
running: { bg: "#d97706", border: "#b56205", font: "#060a13" },
|
||||
failed: { bg: "#dc2626", border: "#b91c1c", font: "#ffffff" },
|
||||
pending: { bg: "#5e7291", border: "#3d506b", font: "#e4eaf4" },
|
||||
skipped: { bg: "#374151", border: "#1f2937", font: "#e4eaf4" },
|
||||
};
|
||||
|
||||
var EDGE_COLOR = "rgba(94, 114, 145, 0.5)";
|
||||
|
||||
var network = null;
|
||||
var nodesDataset = null;
|
||||
var edgesDataset = null;
|
||||
var rawNodesMap = {};
|
||||
|
||||
function getStatusColor(status) {
|
||||
return STATUS_COLORS[status] || STATUS_COLORS.pending;
|
||||
}
|
||||
|
||||
function truncate(str, maxLen) {
|
||||
if (!str) return "";
|
||||
return str.length > maxLen ? str.substring(0, maxLen) + "…" : str;
|
||||
}
|
||||
|
||||
function buildTooltip(node) {
|
||||
var lines = [];
|
||||
lines.push("Tool: " + (node.tool_name || "unknown"));
|
||||
lines.push("Status: " + (node.status || "pending"));
|
||||
if (node.llm_reasoning) {
|
||||
lines.push("Reasoning: " + truncate(node.llm_reasoning, 200));
|
||||
}
|
||||
var findingsCount = node.findings_produced ? node.findings_produced.length : 0;
|
||||
lines.push("Findings: " + findingsCount);
|
||||
lines.push("Risk: " + (node.risk_score != null ? node.risk_score : "N/A"));
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
function toVisNode(node) {
|
||||
var color = getStatusColor(node.status);
|
||||
// Scale node size by risk_score: min 12, max 40
|
||||
var risk = typeof node.risk_score === "number" ? node.risk_score : 0;
|
||||
var size = Math.max(12, Math.min(40, 12 + (risk / 100) * 28));
|
||||
|
||||
return {
|
||||
id: node.node_id,
|
||||
label: node.tool_name || "unknown",
|
||||
title: buildTooltip(node),
|
||||
size: size,
|
||||
color: {
|
||||
background: color.bg,
|
||||
border: color.border,
|
||||
highlight: { background: color.bg, border: "#ffffff" },
|
||||
hover: { background: color.bg, border: "#ffffff" },
|
||||
},
|
||||
font: {
|
||||
color: color.font,
|
||||
size: 11,
|
||||
face: "'JetBrains Mono', monospace",
|
||||
strokeWidth: 2,
|
||||
strokeColor: "#060a13",
|
||||
},
|
||||
borderWidth: 1,
|
||||
borderWidthSelected: 3,
|
||||
shape: "dot",
|
||||
_raw: node,
|
||||
};
|
||||
}
|
||||
|
||||
function buildEdges(nodes) {
|
||||
var edges = [];
|
||||
var seen = {};
|
||||
nodes.forEach(function (node) {
|
||||
if (!node.parent_node_ids) return;
|
||||
node.parent_node_ids.forEach(function (parentId) {
|
||||
var key = parentId + "|" + node.node_id;
|
||||
if (seen[key]) return;
|
||||
seen[key] = true;
|
||||
edges.push({
|
||||
from: parentId,
|
||||
to: node.node_id,
|
||||
color: {
|
||||
color: EDGE_COLOR,
|
||||
highlight: "#ffffff",
|
||||
hover: EDGE_COLOR,
|
||||
},
|
||||
width: 2,
|
||||
arrows: {
|
||||
to: { enabled: true, scaleFactor: 0.5 },
|
||||
},
|
||||
smooth: {
|
||||
enabled: true,
|
||||
type: "cubicBezier",
|
||||
roundness: 0.5,
|
||||
forceDirection: "vertical",
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
return edges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load and render an attack chain DAG.
|
||||
* Called from Rust via eval().
|
||||
* @param {Array} nodes - Array of AttackChainNode objects
|
||||
*/
|
||||
window.__loadAttackChain = function (nodes) {
|
||||
var container = document.getElementById("attack-chain-canvas");
|
||||
if (!container) {
|
||||
console.error("[attack-chain-viz] #attack-chain-canvas not found");
|
||||
return;
|
||||
}
|
||||
|
||||
// Build lookup map
|
||||
rawNodesMap = {};
|
||||
nodes.forEach(function (n) {
|
||||
rawNodesMap[n.node_id] = n;
|
||||
});
|
||||
|
||||
var visNodes = nodes.map(toVisNode);
|
||||
var visEdges = buildEdges(nodes);
|
||||
|
||||
nodesDataset = new vis.DataSet(visNodes);
|
||||
edgesDataset = new vis.DataSet(visEdges);
|
||||
|
||||
var options = {
|
||||
nodes: {
|
||||
font: { color: "#e4eaf4", size: 11 },
|
||||
scaling: { min: 12, max: 40 },
|
||||
},
|
||||
edges: {
|
||||
font: { color: "#5e7291", size: 9, strokeWidth: 0 },
|
||||
selectionWidth: 3,
|
||||
},
|
||||
physics: {
|
||||
enabled: false,
|
||||
},
|
||||
layout: {
|
||||
hierarchical: {
|
||||
enabled: true,
|
||||
direction: "UD",
|
||||
sortMethod: "directed",
|
||||
levelSeparation: 120,
|
||||
nodeSpacing: 160,
|
||||
treeSpacing: 200,
|
||||
blockShifting: true,
|
||||
edgeMinimization: true,
|
||||
parentCentralization: true,
|
||||
},
|
||||
},
|
||||
interaction: {
|
||||
hover: true,
|
||||
tooltipDelay: 200,
|
||||
hideEdgesOnDrag: false,
|
||||
hideEdgesOnZoom: false,
|
||||
multiselect: false,
|
||||
navigationButtons: false,
|
||||
keyboard: { enabled: true },
|
||||
},
|
||||
};
|
||||
|
||||
// Destroy previous instance
|
||||
if (network) {
|
||||
network.destroy();
|
||||
}
|
||||
|
||||
network = new vis.Network(
|
||||
container,
|
||||
{ nodes: nodesDataset, edges: edgesDataset },
|
||||
options
|
||||
);
|
||||
|
||||
// Click handler — sends data to Rust
|
||||
network.on("click", function (params) {
|
||||
if (params.nodes.length > 0) {
|
||||
var nodeId = params.nodes[0];
|
||||
var visNode = nodesDataset.get(nodeId);
|
||||
if (visNode && visNode._raw && window.__onAttackNodeClick) {
|
||||
window.__onAttackNodeClick(JSON.stringify(visNode._raw));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
console.log(
|
||||
"[attack-chain-viz] Loaded " + nodes.length + " nodes, " + visEdges.length + " edges"
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Callback placeholder for Rust to set.
|
||||
* Called with JSON string of the clicked node's data.
|
||||
*/
|
||||
window.__onAttackNodeClick = null;
|
||||
|
||||
/**
|
||||
* Fit entire attack chain DAG in view.
|
||||
*/
|
||||
window.__fitAttackChain = function () {
|
||||
if (!network) return;
|
||||
network.fit({
|
||||
animation: { duration: 400, easingFunction: "easeInOutQuad" },
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Select and focus on a specific node by node_id.
|
||||
*/
|
||||
window.__highlightAttackNode = function (nodeId) {
|
||||
if (!network || !nodesDataset) return;
|
||||
|
||||
var node = nodesDataset.get(nodeId);
|
||||
if (!node) return;
|
||||
|
||||
network.selectNodes([nodeId]);
|
||||
network.focus(nodeId, {
|
||||
scale: 1.5,
|
||||
animation: { duration: 500, easingFunction: "easeInOutQuad" },
|
||||
});
|
||||
|
||||
// Trigger click callback too
|
||||
if (node._raw && window.__onAttackNodeClick) {
|
||||
window.__onAttackNodeClick(JSON.stringify(node._raw));
|
||||
}
|
||||
};
|
||||
})();
|
||||
@@ -38,6 +38,10 @@ pub enum Route {
|
||||
DastFindingsPage {},
|
||||
#[route("/dast/findings/:id")]
|
||||
DastFindingDetailPage { id: String },
|
||||
#[route("/pentest")]
|
||||
PentestDashboardPage {},
|
||||
#[route("/pentest/:session_id")]
|
||||
PentestSessionPage { session_id: String },
|
||||
#[route("/mcp-servers")]
|
||||
McpServersPage {},
|
||||
#[route("/settings")]
|
||||
@@ -49,6 +53,7 @@ const MAIN_CSS: Asset = asset!("/assets/main.css");
|
||||
const TAILWIND_CSS: Asset = asset!("/assets/tailwind.css");
|
||||
const VIS_NETWORK_JS: Asset = asset!("/assets/vis-network.min.js");
|
||||
const GRAPH_VIZ_JS: Asset = asset!("/assets/graph-viz.js");
|
||||
const ATTACK_CHAIN_VIZ_JS: Asset = asset!("/assets/attack-chain-viz.js");
|
||||
|
||||
#[component]
|
||||
pub fn App() -> Element {
|
||||
@@ -58,6 +63,7 @@ pub fn App() -> Element {
|
||||
document::Link { rel: "stylesheet", href: MAIN_CSS }
|
||||
document::Script { src: VIS_NETWORK_JS }
|
||||
document::Script { src: GRAPH_VIZ_JS }
|
||||
document::Script { src: ATTACK_CHAIN_VIZ_JS }
|
||||
Router::<Route> {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,11 @@ pub fn Sidebar() -> Element {
|
||||
route: Route::DastOverviewPage {},
|
||||
icon: rsx! { Icon { icon: BsBug, width: 18, height: 18 } },
|
||||
},
|
||||
NavItem {
|
||||
label: "Pentest",
|
||||
route: Route::PentestDashboardPage {},
|
||||
icon: rsx! { Icon { icon: BsLightningCharge, width: 18, height: 18 } },
|
||||
},
|
||||
NavItem {
|
||||
label: "Settings",
|
||||
route: Route::SettingsPage {},
|
||||
@@ -78,6 +83,7 @@ pub fn Sidebar() -> Element {
|
||||
(Route::DastTargetsPage {}, Route::DastOverviewPage {}) => true,
|
||||
(Route::DastFindingsPage {}, Route::DastOverviewPage {}) => true,
|
||||
(Route::DastFindingDetailPage { .. }, Route::DastOverviewPage {}) => true,
|
||||
(Route::PentestSessionPage { .. }, Route::PentestDashboardPage {}) => true,
|
||||
(a, b) => a == b,
|
||||
};
|
||||
let class = if is_active { "nav-item active" } else { "nav-item" };
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod findings;
|
||||
pub mod graph;
|
||||
pub mod issues;
|
||||
pub mod mcp;
|
||||
pub mod pentest;
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub mod repositories;
|
||||
pub mod sbom;
|
||||
|
||||
270
compliance-dashboard/src/infrastructure/pentest.rs
Normal file
270
compliance-dashboard/src/infrastructure/pentest.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use dioxus::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::dast::DastFindingsResponse;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PentestSessionsResponse {
|
||||
pub data: Vec<serde_json::Value>,
|
||||
pub total: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PentestSessionResponse {
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PentestMessagesResponse {
|
||||
pub data: Vec<serde_json::Value>,
|
||||
pub total: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PentestStatsResponse {
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct AttackChainResponse {
|
||||
pub data: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_pentest_sessions() -> Result<PentestSessionsResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
|
||||
// Fetch sessions
|
||||
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let mut body: PentestSessionsResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
|
||||
// Fetch DAST targets to resolve target names
|
||||
let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
|
||||
if let Ok(tresp) = reqwest::get(&targets_url).await {
|
||||
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
|
||||
let targets = tbody.get("data").and_then(|v| v.as_array());
|
||||
if let Some(targets) = targets {
|
||||
// Build target_id -> name lookup
|
||||
let target_map: std::collections::HashMap<String, String> = targets
|
||||
.iter()
|
||||
.filter_map(|t| {
|
||||
let id = t.get("_id")?.get("$oid")?.as_str()?.to_string();
|
||||
let name = t.get("name")?.as_str()?.to_string();
|
||||
Some((id, name))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Enrich sessions with target_name
|
||||
for session in body.data.iter_mut() {
|
||||
if let Some(tid) = session.get("target_id").and_then(|v| v.as_str()) {
|
||||
if let Some(name) = target_map.get(tid) {
|
||||
session.as_object_mut().map(|obj| {
|
||||
obj.insert(
|
||||
"target_name".to_string(),
|
||||
serde_json::Value::String(name.clone()),
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!("{}/api/v1/pentest/sessions/{id}", state.agent_api_url);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let mut body: PentestSessionResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
|
||||
// Resolve target name from targets list
|
||||
if let Some(tid) = body.data.get("target_id").and_then(|v| v.as_str()) {
|
||||
let targets_url = format!("{}/api/v1/dast/targets", state.agent_api_url);
|
||||
if let Ok(tresp) = reqwest::get(&targets_url).await {
|
||||
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
|
||||
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
|
||||
for t in targets {
|
||||
let t_id = t.get("_id").and_then(|v| v.get("$oid")).and_then(|v| v.as_str()).unwrap_or("");
|
||||
if t_id == tid {
|
||||
if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
|
||||
body.data.as_object_mut().map(|obj| {
|
||||
obj.insert("target_name".to_string(), serde_json::Value::String(name.to_string()))
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_pentest_messages(
|
||||
session_id: String,
|
||||
) -> Result<PentestMessagesResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
"{}/api/v1/pentest/sessions/{session_id}/messages",
|
||||
state.agent_api_url
|
||||
);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: PentestMessagesResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!("{}/api/v1/pentest/stats", state.agent_api_url);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: PentestStatsResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_attack_chain(
|
||||
session_id: String,
|
||||
) -> Result<AttackChainResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
"{}/api/v1/pentest/sessions/{session_id}/attack-chain",
|
||||
state.agent_api_url
|
||||
);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: AttackChainResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn create_pentest_session(
|
||||
target_id: String,
|
||||
strategy: String,
|
||||
message: String,
|
||||
) -> Result<PentestSessionResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!("{}/api/v1/pentest/sessions", state.agent_api_url);
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"target_id": target_id,
|
||||
"strategy": strategy,
|
||||
"message": message,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: PentestSessionResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn send_pentest_message(
|
||||
session_id: String,
|
||||
message: String,
|
||||
) -> Result<PentestMessagesResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
"{}/api/v1/pentest/sessions/{session_id}/chat",
|
||||
state.agent_api_url
|
||||
);
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"message": message,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: PentestMessagesResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_pentest_findings(
|
||||
session_id: String,
|
||||
) -> Result<DastFindingsResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
"{}/api/v1/pentest/sessions/{session_id}/findings",
|
||||
state.agent_api_url
|
||||
);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body: DastFindingsResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn export_pentest_report(
|
||||
session_id: String,
|
||||
format: String,
|
||||
) -> Result<String, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
"{}/api/v1/pentest/sessions/{session_id}/export?format={format}",
|
||||
state.agent_api_url
|
||||
);
|
||||
let resp = reqwest::get(&url)
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ServerFnError::new(e.to_string()))?;
|
||||
Ok(body)
|
||||
}
|
||||
@@ -12,6 +12,8 @@ pub mod impact_analysis;
|
||||
pub mod issues;
|
||||
pub mod mcp_servers;
|
||||
pub mod overview;
|
||||
pub mod pentest_dashboard;
|
||||
pub mod pentest_session;
|
||||
pub mod repositories;
|
||||
pub mod sbom;
|
||||
pub mod settings;
|
||||
@@ -30,6 +32,8 @@ pub use impact_analysis::ImpactAnalysisPage;
|
||||
pub use issues::IssuesPage;
|
||||
pub use mcp_servers::McpServersPage;
|
||||
pub use overview::OverviewPage;
|
||||
pub use pentest_dashboard::PentestDashboardPage;
|
||||
pub use pentest_session::PentestSessionPage;
|
||||
pub use repositories::RepositoriesPage;
|
||||
pub use sbom::SbomPage;
|
||||
pub use settings::SettingsPage;
|
||||
|
||||
403
compliance-dashboard/src/pages/pentest_dashboard.rs
Normal file
403
compliance-dashboard/src/pages/pentest_dashboard.rs
Normal file
@@ -0,0 +1,403 @@
|
||||
use dioxus::prelude::*;
|
||||
use dioxus_free_icons::icons::bs_icons::*;
|
||||
use dioxus_free_icons::Icon;
|
||||
|
||||
use crate::app::Route;
|
||||
use crate::components::page_header::PageHeader;
|
||||
use crate::infrastructure::dast::fetch_dast_targets;
|
||||
use crate::infrastructure::pentest::{
|
||||
create_pentest_session, fetch_pentest_sessions, fetch_pentest_stats,
|
||||
};
|
||||
|
||||
#[component]
|
||||
pub fn PentestDashboardPage() -> Element {
|
||||
let mut sessions = use_resource(|| async { fetch_pentest_sessions().await.ok() });
|
||||
let stats = use_resource(|| async { fetch_pentest_stats().await.ok() });
|
||||
let targets = use_resource(|| async { fetch_dast_targets().await.ok() });
|
||||
|
||||
let mut show_modal = use_signal(|| false);
|
||||
let mut new_target_id = use_signal(String::new);
|
||||
let mut new_strategy = use_signal(|| "comprehensive".to_string());
|
||||
let mut new_message = use_signal(String::new);
|
||||
let mut creating = use_signal(|| false);
|
||||
|
||||
let on_create = move |_| {
|
||||
let tid = new_target_id.read().clone();
|
||||
let strat = new_strategy.read().clone();
|
||||
let msg = new_message.read().clone();
|
||||
if tid.is_empty() || msg.is_empty() {
|
||||
return;
|
||||
}
|
||||
creating.set(true);
|
||||
spawn(async move {
|
||||
match create_pentest_session(tid, strat, msg).await {
|
||||
Ok(resp) => {
|
||||
let session_id = resp
|
||||
.data
|
||||
.get("_id")
|
||||
.and_then(|v| v.get("$oid"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
creating.set(false);
|
||||
show_modal.set(false);
|
||||
new_target_id.set(String::new());
|
||||
new_message.set(String::new());
|
||||
if !session_id.is_empty() {
|
||||
navigator().push(Route::PentestSessionPage {
|
||||
session_id: session_id.clone(),
|
||||
});
|
||||
} else {
|
||||
sessions.restart();
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
creating.set(false);
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Extract stats values
|
||||
let running_sessions = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("running_sessions")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let total_vulns = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("total_vulnerabilities")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let tool_invocations = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("tool_invocations")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let success_rate = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("success_rate")
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0),
|
||||
_ => 0.0,
|
||||
}
|
||||
};
|
||||
|
||||
// Severity counts from stats
|
||||
let severity_critical = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("severity_critical")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let severity_high = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("severity_high")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let severity_medium = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("severity_medium")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
let severity_low = {
|
||||
let s = stats.read();
|
||||
match &*s {
|
||||
Some(Some(data)) => data
|
||||
.data
|
||||
.get("severity_low")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
|
||||
rsx! {
|
||||
PageHeader {
|
||||
title: "Pentest Dashboard",
|
||||
description: "AI-powered penetration testing sessions — autonomous security assessment",
|
||||
}
|
||||
|
||||
// Stat cards
|
||||
div { class: "stat-cards", style: "margin-bottom: 24px;",
|
||||
div { class: "stat-card-item",
|
||||
div { class: "stat-card-value", "{running_sessions}" }
|
||||
div { class: "stat-card-label",
|
||||
Icon { icon: BsPlayCircle, width: 14, height: 14 }
|
||||
" Running Sessions"
|
||||
}
|
||||
}
|
||||
div { class: "stat-card-item",
|
||||
div { class: "stat-card-value", "{total_vulns}" }
|
||||
div { class: "stat-card-label",
|
||||
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
|
||||
" Total Vulnerabilities"
|
||||
}
|
||||
}
|
||||
div { class: "stat-card-item",
|
||||
div { class: "stat-card-value", "{tool_invocations}" }
|
||||
div { class: "stat-card-label",
|
||||
Icon { icon: BsWrench, width: 14, height: 14 }
|
||||
" Tool Invocations"
|
||||
}
|
||||
}
|
||||
div { class: "stat-card-item",
|
||||
div { class: "stat-card-value", "{success_rate:.0}%" }
|
||||
div { class: "stat-card-label",
|
||||
Icon { icon: BsCheckCircle, width: 14, height: 14 }
|
||||
" Success Rate"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Severity distribution
|
||||
div { class: "card", style: "margin-bottom: 24px; padding: 16px;",
|
||||
div { style: "display: flex; align-items: center; gap: 16px; flex-wrap: wrap;",
|
||||
span { style: "font-weight: 600; color: var(--text-secondary); font-size: 0.85rem;", "Severity Distribution" }
|
||||
span {
|
||||
class: "badge",
|
||||
style: "background: #dc2626; color: #fff;",
|
||||
"Critical: {severity_critical}"
|
||||
}
|
||||
span {
|
||||
class: "badge",
|
||||
style: "background: #ea580c; color: #fff;",
|
||||
"High: {severity_high}"
|
||||
}
|
||||
span {
|
||||
class: "badge",
|
||||
style: "background: #d97706; color: #fff;",
|
||||
"Medium: {severity_medium}"
|
||||
}
|
||||
span {
|
||||
class: "badge",
|
||||
style: "background: #2563eb; color: #fff;",
|
||||
"Low: {severity_low}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Actions row
|
||||
div { style: "display: flex; gap: 12px; margin-bottom: 24px;",
|
||||
button {
|
||||
class: "btn btn-primary",
|
||||
onclick: move |_| show_modal.set(true),
|
||||
Icon { icon: BsPlusCircle, width: 14, height: 14 }
|
||||
" New Pentest"
|
||||
}
|
||||
}
|
||||
|
||||
// Sessions list
|
||||
div { class: "card",
|
||||
div { class: "card-header", "Recent Pentest Sessions" }
|
||||
match &*sessions.read() {
|
||||
Some(Some(data)) => {
|
||||
let sess_list = &data.data;
|
||||
if sess_list.is_empty() {
|
||||
rsx! {
|
||||
div { style: "padding: 32px; text-align: center; color: var(--text-secondary);",
|
||||
p { "No pentest sessions yet. Start one to begin autonomous security testing." }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {
|
||||
div { style: "display: grid; gap: 12px; padding: 16px;",
|
||||
for session in sess_list {
|
||||
{
|
||||
let id = session.get("_id")
|
||||
.and_then(|v| v.get("$oid"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-").to_string();
|
||||
let target_name = session.get("target_name").and_then(|v| v.as_str()).unwrap_or("Unknown Target").to_string();
|
||||
let status = session.get("status").and_then(|v| v.as_str()).unwrap_or("unknown").to_string();
|
||||
let strategy = session.get("strategy").and_then(|v| v.as_str()).unwrap_or("-").to_string();
|
||||
let findings_count = session.get("findings_count").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let tool_count = session.get("tool_invocations").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let created_at = session.get("created_at").and_then(|v| v.as_str()).unwrap_or("-").to_string();
|
||||
let status_style = match status.as_str() {
|
||||
"running" => "background: #16a34a; color: #fff;",
|
||||
"completed" => "background: #2563eb; color: #fff;",
|
||||
"failed" => "background: #dc2626; color: #fff;",
|
||||
"paused" => "background: #d97706; color: #fff;",
|
||||
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
|
||||
};
|
||||
rsx! {
|
||||
Link {
|
||||
to: Route::PentestSessionPage { session_id: id.clone() },
|
||||
class: "card",
|
||||
style: "padding: 16px; text-decoration: none; cursor: pointer; transition: border-color 0.15s;",
|
||||
div { style: "display: flex; justify-content: space-between; align-items: flex-start;",
|
||||
div {
|
||||
div { style: "font-weight: 600; font-size: 1rem; margin-bottom: 4px; color: var(--text-primary);",
|
||||
"{target_name}"
|
||||
}
|
||||
div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;",
|
||||
span {
|
||||
class: "badge",
|
||||
style: "{status_style}",
|
||||
"{status}"
|
||||
}
|
||||
span {
|
||||
class: "badge",
|
||||
style: "background: var(--bg-tertiary); color: var(--text-secondary);",
|
||||
"{strategy}"
|
||||
}
|
||||
}
|
||||
}
|
||||
div { style: "text-align: right; font-size: 0.85rem; color: var(--text-secondary);",
|
||||
div { style: "margin-bottom: 4px;",
|
||||
Icon { icon: BsShieldExclamation, width: 12, height: 12 }
|
||||
" {findings_count} findings"
|
||||
}
|
||||
div { style: "margin-bottom: 4px;",
|
||||
Icon { icon: BsWrench, width: 12, height: 12 }
|
||||
" {tool_count} tools"
|
||||
}
|
||||
div { "{created_at}" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(None) => rsx! { p { style: "padding: 16px;", "Failed to load sessions." } },
|
||||
None => rsx! { p { style: "padding: 16px;", "Loading..." } },
|
||||
}
|
||||
}
|
||||
|
||||
// New Pentest Modal
|
||||
if *show_modal.read() {
|
||||
div {
|
||||
style: "position: fixed; inset: 0; background: rgba(0,0,0,0.6); display: flex; align-items: center; justify-content: center; z-index: 1000;",
|
||||
onclick: move |_| show_modal.set(false),
|
||||
div {
|
||||
style: "background: var(--bg-secondary); border: 1px solid var(--border-color); border-radius: 12px; padding: 24px; width: 480px; max-width: 90vw;",
|
||||
onclick: move |e| e.stop_propagation(),
|
||||
h3 { style: "margin: 0 0 16px 0;", "New Pentest Session" }
|
||||
|
||||
// Target selection
|
||||
div { style: "margin-bottom: 12px;",
|
||||
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
|
||||
"Target"
|
||||
}
|
||||
select {
|
||||
class: "chat-input",
|
||||
style: "width: 100%; padding: 8px; resize: none; height: auto;",
|
||||
value: "{new_target_id}",
|
||||
onchange: move |e| new_target_id.set(e.value()),
|
||||
option { value: "", "Select a target..." }
|
||||
match &*targets.read() {
|
||||
Some(Some(data)) => {
|
||||
rsx! {
|
||||
for target in &data.data {
|
||||
{
|
||||
let tid = target.get("_id")
|
||||
.and_then(|v| v.get("$oid"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("").to_string();
|
||||
let tname = target.get("name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
|
||||
let turl = target.get("base_url").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
rsx! {
|
||||
option { value: "{tid}", "{tname} ({turl})" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => rsx! {},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Strategy selection
|
||||
div { style: "margin-bottom: 12px;",
|
||||
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
|
||||
"Strategy"
|
||||
}
|
||||
select {
|
||||
class: "chat-input",
|
||||
style: "width: 100%; padding: 8px; resize: none; height: auto;",
|
||||
value: "{new_strategy}",
|
||||
onchange: move |e| new_strategy.set(e.value()),
|
||||
option { value: "comprehensive", "Comprehensive" }
|
||||
option { value: "quick", "Quick Scan" }
|
||||
option { value: "owasp_top_10", "OWASP Top 10" }
|
||||
option { value: "api_focused", "API Focused" }
|
||||
option { value: "authentication", "Authentication" }
|
||||
}
|
||||
}
|
||||
|
||||
// Initial message
|
||||
div { style: "margin-bottom: 16px;",
|
||||
label { style: "display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px;",
|
||||
"Initial Instructions"
|
||||
}
|
||||
textarea {
|
||||
class: "chat-input",
|
||||
style: "width: 100%; min-height: 80px;",
|
||||
placeholder: "Describe the scope and goals of this pentest...",
|
||||
value: "{new_message}",
|
||||
oninput: move |e| new_message.set(e.value()),
|
||||
}
|
||||
}
|
||||
|
||||
div { style: "display: flex; justify-content: flex-end; gap: 8px;",
|
||||
button {
|
||||
class: "btn btn-ghost",
|
||||
onclick: move |_| show_modal.set(false),
|
||||
"Cancel"
|
||||
}
|
||||
button {
|
||||
class: "btn btn-primary",
|
||||
disabled: *creating.read() || new_target_id.read().is_empty() || new_message.read().is_empty(),
|
||||
onclick: on_create,
|
||||
if *creating.read() { "Creating..." } else { "Start Pentest" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
796
compliance-dashboard/src/pages/pentest_session.rs
Normal file
796
compliance-dashboard/src/pages/pentest_session.rs
Normal file
@@ -0,0 +1,796 @@
|
||||
use dioxus::prelude::*;
|
||||
use dioxus_free_icons::icons::bs_icons::*;
|
||||
use dioxus_free_icons::Icon;
|
||||
|
||||
use crate::app::Route;
|
||||
use crate::infrastructure::pentest::{
|
||||
export_pentest_report, fetch_attack_chain, fetch_pentest_findings, fetch_pentest_messages,
|
||||
fetch_pentest_session, send_pentest_message,
|
||||
};
|
||||
|
||||
/// Simple markdown-to-HTML converter for assistant messages.
|
||||
/// Handles headers, bold, italic, code blocks, inline code, and lists.
|
||||
fn markdown_to_html(input: &str) -> String {
|
||||
let mut html = String::new();
|
||||
let mut in_code_block = false;
|
||||
let mut in_list = false;
|
||||
|
||||
for line in input.lines() {
|
||||
if line.starts_with("```") {
|
||||
if in_code_block {
|
||||
html.push_str("</code></pre>");
|
||||
in_code_block = false;
|
||||
} else {
|
||||
if in_list {
|
||||
html.push_str("</ul>");
|
||||
in_list = false;
|
||||
}
|
||||
html.push_str("<pre style=\"background:var(--bg-primary);padding:10px;border-radius:6px;overflow-x:auto;font-size:0.8rem;margin:6px 0;\"><code>");
|
||||
in_code_block = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_code_block {
|
||||
// Escape HTML inside code blocks
|
||||
let escaped = line
|
||||
.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">");
|
||||
html.push_str(&escaped);
|
||||
html.push('\n');
|
||||
continue;
|
||||
}
|
||||
|
||||
let trimmed = line.trim();
|
||||
|
||||
// Blank line — close list if open
|
||||
if trimmed.is_empty() {
|
||||
if in_list {
|
||||
html.push_str("</ul>");
|
||||
in_list = false;
|
||||
}
|
||||
html.push_str("<br/>");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Lists
|
||||
if trimmed.starts_with("- ") || trimmed.starts_with("* ") {
|
||||
if !in_list {
|
||||
html.push_str("<ul style=\"margin:4px 0;padding-left:20px;\">");
|
||||
in_list = true;
|
||||
}
|
||||
let content = inline_format(&trimmed[2..]);
|
||||
html.push_str(&format!("<li>{content}</li>"));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Numbered lists
|
||||
if trimmed.len() > 2 {
|
||||
let mut chars = trimmed.chars();
|
||||
let first = chars.next();
|
||||
let second = chars.next();
|
||||
if first.map(|c| c.is_ascii_digit()).unwrap_or(false)
|
||||
&& (second == Some('.') || second == Some(')'))
|
||||
{
|
||||
let rest = &trimmed[2..].trim_start();
|
||||
if !in_list {
|
||||
html.push_str("<ul style=\"margin:4px 0;padding-left:20px;\">");
|
||||
in_list = true;
|
||||
}
|
||||
let content = inline_format(rest);
|
||||
html.push_str(&format!("<li>{content}</li>"));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Close list if we're no longer in one
|
||||
if in_list {
|
||||
html.push_str("</ul>");
|
||||
in_list = false;
|
||||
}
|
||||
|
||||
// Headers
|
||||
if trimmed.starts_with("### ") {
|
||||
let content = inline_format(&trimmed[4..]);
|
||||
html.push_str(&format!(
|
||||
"<h4 style=\"margin:8px 0 4px;font-size:0.9rem;\">{content}</h4>"
|
||||
));
|
||||
} else if trimmed.starts_with("## ") {
|
||||
let content = inline_format(&trimmed[3..]);
|
||||
html.push_str(&format!(
|
||||
"<h3 style=\"margin:10px 0 4px;font-size:0.95rem;\">{content}</h3>"
|
||||
));
|
||||
} else if trimmed.starts_with("# ") {
|
||||
let content = inline_format(&trimmed[2..]);
|
||||
html.push_str(&format!(
|
||||
"<h2 style=\"margin:12px 0 6px;font-size:1rem;\">{content}</h2>"
|
||||
));
|
||||
} else {
|
||||
let content = inline_format(trimmed);
|
||||
html.push_str(&format!("<p style=\"margin:2px 0;\">{content}</p>"));
|
||||
}
|
||||
}
|
||||
|
||||
if in_list {
|
||||
html.push_str("</ul>");
|
||||
}
|
||||
if in_code_block {
|
||||
html.push_str("</code></pre>");
|
||||
}
|
||||
|
||||
html
|
||||
}
|
||||
|
||||
/// Handle inline formatting: bold, italic, inline code
|
||||
fn inline_format(text: &str) -> String {
|
||||
let mut result = text
|
||||
.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">");
|
||||
|
||||
// Inline code (backticks)
|
||||
while let Some(start) = result.find('`') {
|
||||
if let Some(end) = result[start + 1..].find('`') {
|
||||
let code_content = &result[start + 1..start + 1 + end].to_string();
|
||||
let replacement = format!(
|
||||
"<code style=\"background:var(--bg-primary);padding:1px 4px;border-radius:3px;font-size:0.85em;\">{code_content}</code>"
|
||||
);
|
||||
result = format!("{}{}{}", &result[..start], replacement, &result[start + 2 + end..]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Bold (**text**)
|
||||
while let Some(start) = result.find("**") {
|
||||
if let Some(end) = result[start + 2..].find("**") {
|
||||
let bold_content = &result[start + 2..start + 2 + end].to_string();
|
||||
result = format!(
|
||||
"{}<strong>{bold_content}</strong>{}",
|
||||
&result[..start],
|
||||
&result[start + 4 + end..]
|
||||
);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[component]
|
||||
pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
let sid = session_id.clone();
|
||||
let sid_for_session = session_id.clone();
|
||||
let sid_for_findings = session_id.clone();
|
||||
let sid_for_chain = session_id.clone();
|
||||
|
||||
let mut session = use_resource(move || {
|
||||
let id = sid_for_session.clone();
|
||||
async move { fetch_pentest_session(id).await.ok() }
|
||||
});
|
||||
let mut messages_res = use_resource(move || {
|
||||
let id = sid.clone();
|
||||
async move { fetch_pentest_messages(id).await.ok() }
|
||||
});
|
||||
let mut findings = use_resource(move || {
|
||||
let id = sid_for_findings.clone();
|
||||
async move { fetch_pentest_findings(id).await.ok() }
|
||||
});
|
||||
let mut attack_chain = use_resource(move || {
|
||||
let id = sid_for_chain.clone();
|
||||
async move { fetch_attack_chain(id).await.ok() }
|
||||
});
|
||||
|
||||
let mut input_text = use_signal(String::new);
|
||||
let mut sending = use_signal(|| false);
|
||||
let mut right_tab = use_signal(|| "findings".to_string());
|
||||
let mut chain_view = use_signal(|| "list".to_string());
|
||||
let mut exporting = use_signal(|| false);
|
||||
let mut poll_gen = use_signal(|| 0u32); // incremented to trigger re-poll
|
||||
|
||||
// Extract session status
|
||||
let session_status = {
|
||||
let s = session.read();
|
||||
match &*s {
|
||||
Some(Some(resp)) => resp
|
||||
.data
|
||||
.get("status")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string(),
|
||||
_ => "unknown".to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
let is_running = session_status == "running";
|
||||
|
||||
// Continuous polling: re-fetch all data every 3s while running
|
||||
use_effect(move || {
|
||||
let _gen = *poll_gen.read(); // subscribe to changes
|
||||
if is_running {
|
||||
spawn(async move {
|
||||
#[cfg(feature = "web")]
|
||||
gloo_timers::future::TimeoutFuture::new(3_000).await;
|
||||
#[cfg(not(feature = "web"))]
|
||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||
messages_res.restart();
|
||||
findings.restart();
|
||||
attack_chain.restart();
|
||||
session.restart();
|
||||
// Bump generation to trigger the next poll cycle
|
||||
let next = poll_gen.peek().wrapping_add(1);
|
||||
poll_gen.set(next);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Load attack chain into vis-network when graph tab is active and data is available
|
||||
// Use a separate effect that reads the reactive resources directly
|
||||
use_effect(move || {
|
||||
let tab = right_tab.read().clone();
|
||||
let view = chain_view.read().clone();
|
||||
let chain = attack_chain.read().clone();
|
||||
|
||||
if tab == "chain" && view == "graph" {
|
||||
if let Some(Some(data)) = &chain {
|
||||
if !data.data.is_empty() {
|
||||
let nodes_json =
|
||||
serde_json::to_string(&data.data).unwrap_or_else(|_| "[]".to_string());
|
||||
// Small delay to ensure the DOM container exists
|
||||
spawn(async move {
|
||||
#[cfg(feature = "web")]
|
||||
gloo_timers::future::TimeoutFuture::new(100).await;
|
||||
#[cfg(not(feature = "web"))]
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
let js = format!(
|
||||
r#"if (window.__loadAttackChain) {{ window.__loadAttackChain({nodes_json}); }}"#
|
||||
);
|
||||
document::eval(&js);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Send message handler
|
||||
let sid_for_send = session_id.clone();
|
||||
let mut do_send = move || {
|
||||
let text = input_text.read().trim().to_string();
|
||||
if text.is_empty() || *sending.read() {
|
||||
return;
|
||||
}
|
||||
let sid = sid_for_send.clone();
|
||||
input_text.set(String::new());
|
||||
sending.set(true);
|
||||
spawn(async move {
|
||||
let _ = send_pentest_message(sid, text).await;
|
||||
sending.set(false);
|
||||
messages_res.restart();
|
||||
});
|
||||
};
|
||||
|
||||
let mut do_send_click = do_send.clone();
|
||||
|
||||
// Export handlers
|
||||
let sid_for_export = session_id.clone();
|
||||
let do_export_md = move |_| {
|
||||
let sid = sid_for_export.clone();
|
||||
exporting.set(true);
|
||||
spawn(async move {
|
||||
match export_pentest_report(sid.clone(), "markdown".to_string()).await {
|
||||
Ok(content) => {
|
||||
let escaped = content
|
||||
.replace('\\', "\\\\")
|
||||
.replace('`', "\\`")
|
||||
.replace("${", "\\${");
|
||||
let js = format!(
|
||||
r#"
|
||||
var blob = new Blob([`{escaped}`], {{ type: 'text/markdown' }});
|
||||
var url = URL.createObjectURL(blob);
|
||||
var a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'pentest-report-{sid}.md';
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
"#
|
||||
);
|
||||
document::eval(&js);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Export failed: {e}");
|
||||
}
|
||||
}
|
||||
exporting.set(false);
|
||||
});
|
||||
};
|
||||
|
||||
let sid_for_export_json = session_id.clone();
|
||||
let do_export_json = move |_| {
|
||||
let sid = sid_for_export_json.clone();
|
||||
exporting.set(true);
|
||||
spawn(async move {
|
||||
match export_pentest_report(sid.clone(), "json".to_string()).await {
|
||||
Ok(content) => {
|
||||
let escaped = content
|
||||
.replace('\\', "\\\\")
|
||||
.replace('`', "\\`")
|
||||
.replace("${", "\\${");
|
||||
let js = format!(
|
||||
r#"
|
||||
var blob = new Blob([`{escaped}`], {{ type: 'application/json' }});
|
||||
var url = URL.createObjectURL(blob);
|
||||
var a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'pentest-report-{sid}.json';
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
"#
|
||||
);
|
||||
document::eval(&js);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Export failed: {e}");
|
||||
}
|
||||
}
|
||||
exporting.set(false);
|
||||
});
|
||||
};
|
||||
|
||||
// Session header info
|
||||
let target_name = {
|
||||
let s = session.read();
|
||||
match &*s {
|
||||
Some(Some(resp)) => resp
|
||||
.data
|
||||
.get("target_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("Pentest Session")
|
||||
.to_string(),
|
||||
_ => "Pentest Session".to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
let strategy = {
|
||||
let s = session.read();
|
||||
match &*s {
|
||||
Some(Some(resp)) => resp
|
||||
.data
|
||||
.get("strategy")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("-")
|
||||
.to_string(),
|
||||
_ => "-".to_string(),
|
||||
}
|
||||
};
|
||||
|
||||
let header_tool_count = {
|
||||
let s = session.read();
|
||||
match &*s {
|
||||
Some(Some(resp)) => resp
|
||||
.data
|
||||
.get("tool_invocations")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
|
||||
let header_findings_count = {
|
||||
let f = findings.read();
|
||||
match &*f {
|
||||
Some(Some(data)) => data.total.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
};
|
||||
|
||||
let status_style = match session_status.as_str() {
|
||||
"running" => "background: #16a34a; color: #fff;",
|
||||
"completed" => "background: #2563eb; color: #fff;",
|
||||
"failed" => "background: #dc2626; color: #fff;",
|
||||
"paused" => "background: #d97706; color: #fff;",
|
||||
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
|
||||
};
|
||||
|
||||
rsx! {
|
||||
div { class: "back-nav",
|
||||
Link {
|
||||
to: Route::PentestDashboardPage {},
|
||||
class: "btn btn-ghost btn-back",
|
||||
Icon { icon: BsArrowLeft, width: 16, height: 16 }
|
||||
"Back to Pentest Dashboard"
|
||||
}
|
||||
}
|
||||
|
||||
// Session header
|
||||
div { style: "display: flex; align-items: center; justify-content: space-between; margin-bottom: 16px; flex-wrap: wrap; gap: 8px;",
|
||||
div {
|
||||
h2 { style: "margin: 0 0 4px 0;", "{target_name}" }
|
||||
div { style: "display: flex; gap: 8px; align-items: center; flex-wrap: wrap;",
|
||||
span { class: "badge", style: "{status_style}", "{session_status}" }
|
||||
span { class: "badge", style: "background: var(--bg-tertiary); color: var(--text-secondary);",
|
||||
"{strategy}"
|
||||
}
|
||||
}
|
||||
}
|
||||
div { style: "display: flex; gap: 8px; align-items: center;",
|
||||
div { style: "display: flex; gap: 4px;",
|
||||
button {
|
||||
class: "btn btn-ghost",
|
||||
style: "font-size: 0.8rem; padding: 4px 10px;",
|
||||
disabled: *exporting.read(),
|
||||
onclick: do_export_md,
|
||||
Icon { icon: BsFileEarmarkText, width: 12, height: 12 }
|
||||
" Export MD"
|
||||
}
|
||||
button {
|
||||
class: "btn btn-ghost",
|
||||
style: "font-size: 0.8rem; padding: 4px 10px;",
|
||||
disabled: *exporting.read(),
|
||||
onclick: do_export_json,
|
||||
Icon { icon: BsFiletypeJson, width: 12, height: 12 }
|
||||
" Export JSON"
|
||||
}
|
||||
}
|
||||
div { style: "display: flex; gap: 16px; font-size: 0.85rem; color: var(--text-secondary);",
|
||||
span {
|
||||
Icon { icon: BsWrench, width: 14, height: 14 }
|
||||
" {header_tool_count} tools"
|
||||
}
|
||||
span {
|
||||
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
|
||||
" {header_findings_count} findings"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split layout: chat left, findings/chain right
|
||||
div { style: "display: grid; grid-template-columns: 1fr 420px; gap: 16px; height: calc(100vh - 220px); min-height: 400px;",
|
||||
|
||||
// Left: Chat area
|
||||
div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;",
|
||||
div { class: "card-header", style: "flex-shrink: 0;", "Chat" }
|
||||
|
||||
// Messages
|
||||
div {
|
||||
style: "flex: 1; overflow-y: auto; padding: 16px; display: flex; flex-direction: column; gap: 12px;",
|
||||
match &*messages_res.read() {
|
||||
Some(Some(data)) => {
|
||||
let msgs = &data.data;
|
||||
if msgs.is_empty() {
|
||||
rsx! {
|
||||
div { style: "text-align: center; color: var(--text-secondary); padding: 32px;",
|
||||
h3 { style: "margin-bottom: 8px;", "Start the conversation" }
|
||||
p { "Send a message to guide the pentest agent." }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {
|
||||
for (i, msg) in msgs.iter().enumerate() {
|
||||
{
|
||||
let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("assistant").to_string();
|
||||
let content = msg.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let msg_type = msg.get("type").and_then(|v| v.as_str()).unwrap_or("text").to_string();
|
||||
let tool_name = msg.get("tool_name").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let tool_status = msg.get("tool_status").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
|
||||
if msg_type == "tool_call" || msg_type == "tool_result" || role == "tool_result" {
|
||||
let tool_icon_style = match tool_status.as_str() {
|
||||
"success" => "color: #16a34a;",
|
||||
"error" => "color: #dc2626;",
|
||||
"running" => "color: #d97706;",
|
||||
_ => "color: var(--text-secondary);",
|
||||
};
|
||||
rsx! {
|
||||
div {
|
||||
key: "{i}",
|
||||
style: "display: flex; align-items: center; gap: 8px; padding: 6px 12px; background: var(--bg-tertiary); border-radius: 6px; font-size: 0.8rem; color: var(--text-secondary);",
|
||||
span { style: "{tool_icon_style}",
|
||||
Icon { icon: BsWrench, width: 12, height: 12 }
|
||||
}
|
||||
span { style: "font-family: monospace;", "{tool_name}" }
|
||||
if !tool_status.is_empty() {
|
||||
span { class: "badge", style: "font-size: 0.7rem;", "{tool_status}" }
|
||||
}
|
||||
if !content.is_empty() {
|
||||
details { style: "margin-left: auto; cursor: pointer;",
|
||||
summary { style: "font-size: 0.75rem;", "details" }
|
||||
pre { style: "margin-top: 4px; padding: 8px; background: var(--bg-primary); border-radius: 4px; font-size: 0.75rem; overflow-x: auto; max-height: 200px; white-space: pre-wrap;",
|
||||
"{content}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if role == "user" {
|
||||
rsx! {
|
||||
div {
|
||||
key: "{i}",
|
||||
style: "display: flex; justify-content: flex-end;",
|
||||
div {
|
||||
style: "max-width: 80%; padding: 10px 14px; background: #2563eb; color: #fff; border-radius: 12px 12px 2px 12px; font-size: 0.9rem; line-height: 1.5; white-space: pre-wrap;",
|
||||
"{content}"
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Assistant message — render markdown
|
||||
let rendered_html = markdown_to_html(&content);
|
||||
rsx! {
|
||||
div {
|
||||
key: "{i}",
|
||||
style: "display: flex; gap: 8px; align-items: flex-start;",
|
||||
div {
|
||||
style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;",
|
||||
Icon { icon: BsCpu, width: 14, height: 14 }
|
||||
}
|
||||
div {
|
||||
style: "max-width: 80%; padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; line-height: 1.5;",
|
||||
dangerous_inner_html: "{rendered_html}",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(None) => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Failed to load messages." } },
|
||||
None => rsx! { p { style: "padding: 16px; color: var(--text-secondary);", "Loading messages..." } },
|
||||
}
|
||||
|
||||
if *sending.read() {
|
||||
div { style: "display: flex; gap: 8px; align-items: flex-start;",
|
||||
div {
|
||||
style: "flex-shrink: 0; width: 28px; height: 28px; border-radius: 50%; background: var(--bg-tertiary); display: flex; align-items: center; justify-content: center;",
|
||||
Icon { icon: BsCpu, width: 14, height: 14 }
|
||||
}
|
||||
div {
|
||||
style: "padding: 10px 14px; background: var(--bg-tertiary); border-radius: 12px 12px 12px 2px; font-size: 0.9rem; color: var(--text-secondary);",
|
||||
"Thinking..."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Input area — disabled while pentest is running (messages have no effect)
|
||||
if is_running {
|
||||
div { style: "flex-shrink: 0; padding: 12px; border-top: 1px solid var(--border-color); text-align: center; color: var(--text-secondary); font-size: 0.85rem;",
|
||||
Icon { icon: BsPlayCircle, width: 14, height: 14 }
|
||||
" Pentest is running — input disabled until complete"
|
||||
}
|
||||
} else {
|
||||
div { style: "flex-shrink: 0; padding: 12px; border-top: 1px solid var(--border-color); display: flex; gap: 8px;",
|
||||
textarea {
|
||||
class: "chat-input",
|
||||
style: "flex: 1;",
|
||||
placeholder: "Guide the pentest agent...",
|
||||
value: "{input_text}",
|
||||
oninput: move |e| input_text.set(e.value()),
|
||||
onkeydown: move |e: Event<KeyboardData>| {
|
||||
if e.key() == Key::Enter && !e.modifiers().shift() {
|
||||
e.prevent_default();
|
||||
do_send();
|
||||
}
|
||||
},
|
||||
}
|
||||
button {
|
||||
class: "btn btn-primary",
|
||||
style: "align-self: flex-end;",
|
||||
disabled: *sending.read(),
|
||||
onclick: move |_| do_send_click(),
|
||||
"Send"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Right: Findings / Attack Chain tabs
|
||||
div { class: "card", style: "display: flex; flex-direction: column; overflow: hidden;",
|
||||
// Tab bar
|
||||
div { style: "display: flex; border-bottom: 1px solid var(--border-color); flex-shrink: 0;",
|
||||
button {
|
||||
style: if *right_tab.read() == "findings" {
|
||||
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;"
|
||||
} else {
|
||||
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;"
|
||||
},
|
||||
onclick: move |_| right_tab.set("findings".to_string()),
|
||||
Icon { icon: BsShieldExclamation, width: 14, height: 14 }
|
||||
" Findings ({header_findings_count})"
|
||||
}
|
||||
button {
|
||||
style: if *right_tab.read() == "chain" {
|
||||
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid #2563eb; color: var(--text-primary); cursor: pointer; font-weight: 600;"
|
||||
} else {
|
||||
"flex: 1; padding: 10px; background: none; border: none; border-bottom: 2px solid transparent; color: var(--text-secondary); cursor: pointer;"
|
||||
},
|
||||
onclick: move |_| right_tab.set("chain".to_string()),
|
||||
Icon { icon: BsDiagram3, width: 14, height: 14 }
|
||||
" Attack Chain"
|
||||
}
|
||||
}
|
||||
|
||||
// Tab content
|
||||
div { style: "flex: 1; overflow-y: auto; display: flex; flex-direction: column;",
|
||||
if *right_tab.read() == "findings" {
|
||||
// Findings tab
|
||||
div { style: "padding: 12px; flex: 1; overflow-y: auto;",
|
||||
match &*findings.read() {
|
||||
Some(Some(data)) => {
|
||||
let finding_list = &data.data;
|
||||
if finding_list.is_empty() {
|
||||
rsx! {
|
||||
div { style: "text-align: center; color: var(--text-secondary); padding: 24px;",
|
||||
p { "No findings yet." }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {
|
||||
div { style: "display: flex; flex-direction: column; gap: 8px;",
|
||||
for finding in finding_list {
|
||||
{
|
||||
let title = finding.get("title").and_then(|v| v.as_str()).unwrap_or("Untitled").to_string();
|
||||
let severity = finding.get("severity").and_then(|v| v.as_str()).unwrap_or("info").to_string();
|
||||
let vuln_type = finding.get("vuln_type").and_then(|v| v.as_str()).unwrap_or("-").to_string();
|
||||
let endpoint = finding.get("endpoint").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let exploitable = finding.get("exploitable").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
let sev_style = match severity.as_str() {
|
||||
"critical" => "background: #dc2626; color: #fff;",
|
||||
"high" => "background: #ea580c; color: #fff;",
|
||||
"medium" => "background: #d97706; color: #fff;",
|
||||
"low" => "background: #2563eb; color: #fff;",
|
||||
_ => "background: var(--bg-tertiary); color: var(--text-secondary);",
|
||||
};
|
||||
rsx! {
|
||||
div { style: "padding: 10px; background: var(--bg-tertiary); border-radius: 8px;",
|
||||
div { style: "display: flex; justify-content: space-between; align-items: center; margin-bottom: 4px;",
|
||||
span { style: "font-weight: 600; font-size: 0.85rem;", "{title}" }
|
||||
div { style: "display: flex; gap: 4px;",
|
||||
if exploitable {
|
||||
span { class: "badge", style: "background: #dc2626; color: #fff; font-size: 0.7rem;", "Exploitable" }
|
||||
}
|
||||
span { class: "badge", style: "{sev_style}", "{severity}" }
|
||||
}
|
||||
}
|
||||
div { style: "font-size: 0.8rem; color: var(--text-secondary);", "{vuln_type}" }
|
||||
if !endpoint.is_empty() {
|
||||
div { style: "font-size: 0.75rem; color: var(--text-secondary); font-family: monospace; margin-top: 2px;",
|
||||
"{endpoint}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(None) => rsx! { p { style: "color: var(--text-secondary); padding: 12px;", "Failed to load findings." } },
|
||||
None => rsx! { p { style: "color: var(--text-secondary); padding: 12px;", "Loading..." } },
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Attack chain tab — graph/list toggle
|
||||
div { style: "display: flex; gap: 4px; padding: 8px 12px; flex-shrink: 0;",
|
||||
button {
|
||||
class: if *chain_view.read() == "graph" { "btn btn-primary" } else { "btn btn-ghost" },
|
||||
style: "font-size: 0.75rem; padding: 3px 8px;",
|
||||
onclick: move |_| chain_view.set("graph".to_string()),
|
||||
Icon { icon: BsDiagram3, width: 12, height: 12 }
|
||||
" Graph"
|
||||
}
|
||||
button {
|
||||
class: if *chain_view.read() == "list" { "btn btn-primary" } else { "btn btn-ghost" },
|
||||
style: "font-size: 0.75rem; padding: 3px 8px;",
|
||||
onclick: move |_| chain_view.set("list".to_string()),
|
||||
Icon { icon: BsListOl, width: 12, height: 12 }
|
||||
" List"
|
||||
}
|
||||
}
|
||||
|
||||
if *chain_view.read() == "graph" {
|
||||
// Interactive DAG visualization
|
||||
div { style: "flex: 1; position: relative; min-height: 300px;",
|
||||
div {
|
||||
id: "attack-chain-canvas",
|
||||
style: "width: 100%; height: 100%; position: absolute; inset: 0;",
|
||||
}
|
||||
match &*attack_chain.read() {
|
||||
Some(Some(data)) if data.data.is_empty() => rsx! {
|
||||
div { style: "position: absolute; inset: 0; display: flex; align-items: center; justify-content: center; color: var(--text-secondary);",
|
||||
p { "No attack chain steps yet." }
|
||||
}
|
||||
},
|
||||
_ => rsx! {},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// List view
|
||||
div { style: "flex: 1; overflow-y: auto; padding: 0 12px 12px;",
|
||||
match &*attack_chain.read() {
|
||||
Some(Some(data)) => {
|
||||
let steps = &data.data;
|
||||
if steps.is_empty() {
|
||||
rsx! {
|
||||
div { style: "text-align: center; color: var(--text-secondary); padding: 24px;",
|
||||
p { "No attack chain steps yet." }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {
|
||||
div { style: "display: flex; flex-direction: column; gap: 4px;",
|
||||
for (i, step) in steps.iter().enumerate() {
|
||||
{
|
||||
let tool_name = step.get("tool_name").and_then(|v| v.as_str()).unwrap_or("Step").to_string();
|
||||
let step_status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
|
||||
let reasoning = step.get("llm_reasoning").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let findings_count = step.get("findings_produced").and_then(|v| v.as_array()).map(|a| a.len()).unwrap_or(0);
|
||||
let risk_score = step.get("risk_score").and_then(|v| v.as_u64());
|
||||
let step_num = i + 1;
|
||||
let dot_color = match step_status.as_str() {
|
||||
"completed" => "#16a34a",
|
||||
"running" => "#d97706",
|
||||
"failed" => "#dc2626",
|
||||
"skipped" => "#374151",
|
||||
_ => "var(--text-secondary)",
|
||||
};
|
||||
rsx! {
|
||||
div { style: "display: flex; gap: 10px; padding: 8px 0;",
|
||||
div { style: "display: flex; flex-direction: column; align-items: center;",
|
||||
div { style: "width: 10px; height: 10px; border-radius: 50%; background: {dot_color}; flex-shrink: 0;" }
|
||||
if i < steps.len() - 1 {
|
||||
div { style: "width: 2px; flex: 1; background: var(--border-color); margin-top: 4px;" }
|
||||
}
|
||||
}
|
||||
div { style: "flex: 1; min-width: 0;",
|
||||
div { style: "display: flex; justify-content: space-between; align-items: center;",
|
||||
span { style: "font-size: 0.85rem; font-weight: 600;",
|
||||
"{step_num}. {tool_name}"
|
||||
}
|
||||
div { style: "display: flex; gap: 4px;",
|
||||
if findings_count > 0 {
|
||||
span { class: "badge", style: "font-size: 0.65rem; background: #dc2626; color: #fff;",
|
||||
"{findings_count} findings"
|
||||
}
|
||||
}
|
||||
if let Some(score) = risk_score {
|
||||
span { class: "badge", style: "font-size: 0.65rem;",
|
||||
"risk: {score}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !reasoning.is_empty() {
|
||||
div { style: "font-size: 0.8rem; color: var(--text-secondary); margin-top: 2px; overflow: hidden; text-overflow: ellipsis; display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical;",
|
||||
"{reasoning}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(None) => rsx! { p { style: "color: var(--text-secondary); padding: 12px;", "Failed to load attack chain." } },
|
||||
None => rsx! { p { style: "color: var(--text-secondary); padding: 12px;", "Loading..." } },
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,10 @@ chromiumoxide = { version = "0.7", features = ["tokio-runtime"], default-feature
|
||||
# Docker sandboxing
|
||||
bollard = "0.18"
|
||||
|
||||
# TLS analysis
|
||||
native-tls = "0.2"
|
||||
tokio-native-tls = "0.3"
|
||||
|
||||
# Serialization
|
||||
bson = { version = "2", features = ["chrono-0_4"] }
|
||||
url = "2"
|
||||
|
||||
@@ -2,5 +2,7 @@ pub mod agents;
|
||||
pub mod crawler;
|
||||
pub mod orchestrator;
|
||||
pub mod recon;
|
||||
pub mod tools;
|
||||
|
||||
pub use orchestrator::DastOrchestrator;
|
||||
pub use tools::ToolRegistry;
|
||||
|
||||
146
compliance-dast/src/tools/api_fuzzer.rs
Normal file
146
compliance-dast/src/tools/api_fuzzer.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::api_fuzzer::ApiFuzzerAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing ApiFuzzerAgent.
|
||||
pub struct ApiFuzzerTool {
|
||||
http: reqwest::Client,
|
||||
agent: ApiFuzzerAgent,
|
||||
}
|
||||
|
||||
impl ApiFuzzerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = ApiFuzzerAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for ApiFuzzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"api_fuzzer"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Fuzzes API endpoints to discover misconfigurations, information disclosure, and hidden \
|
||||
endpoints. Probes common sensitive paths and tests for verbose error messages."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"endpoints": {
|
||||
"type": "array",
|
||||
"description": "Known endpoints to fuzz",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"location": { "type": "string" },
|
||||
"param_type": { "type": "string" },
|
||||
"example_value": { "type": "string" }
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
},
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"description": "Base URL to probe for common sensitive paths (used if no endpoints provided)"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let mut endpoints = Self::parse_endpoints(&input);
|
||||
|
||||
// If a base_url is provided but no endpoints, create a default endpoint
|
||||
if endpoints.is_empty() {
|
||||
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url: base.to_string(),
|
||||
method: "GET".to_string(),
|
||||
parameters: Vec::new(),
|
||||
content_type: None,
|
||||
requires_auth: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints or base_url provided to fuzz.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} API misconfigurations or information disclosures.")
|
||||
} else {
|
||||
"No API misconfigurations detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
130
compliance-dast/src/tools/auth_bypass.rs
Normal file
130
compliance-dast/src/tools/auth_bypass.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::auth_bypass::AuthBypassAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing AuthBypassAgent.
|
||||
pub struct AuthBypassTool {
|
||||
http: reqwest::Client,
|
||||
agent: AuthBypassAgent,
|
||||
}
|
||||
|
||||
impl AuthBypassTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = AuthBypassAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for AuthBypassTool {
|
||||
fn name(&self) -> &str {
|
||||
"auth_bypass_scanner"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Tests endpoints for authentication bypass vulnerabilities. Tries accessing protected \
|
||||
endpoints without credentials, with manipulated tokens, and with common default credentials."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"endpoints": {
|
||||
"type": "array",
|
||||
"description": "Endpoints to test for authentication bypass",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"location": { "type": "string" },
|
||||
"param_type": { "type": "string" },
|
||||
"example_value": { "type": "string" }
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
},
|
||||
"requires_auth": { "type": "boolean", "description": "Whether this endpoint requires authentication" }
|
||||
},
|
||||
"required": ["url", "method"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["endpoints"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} authentication bypass vulnerabilities.")
|
||||
} else {
|
||||
"No authentication bypass vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
326
compliance-dast/src/tools/console_log_detector.rs
Normal file
326
compliance-dast/src/tools/console_log_detector.rs
Normal file
@@ -0,0 +1,326 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
/// Tool that detects console.log and similar debug statements in frontend JavaScript.
|
||||
pub struct ConsoleLogDetectorTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
/// A detected console statement with its context.
|
||||
#[derive(Debug)]
|
||||
struct ConsoleMatch {
|
||||
pattern: String,
|
||||
file_url: String,
|
||||
line_snippet: String,
|
||||
line_number: Option<usize>,
|
||||
}
|
||||
|
||||
impl ConsoleLogDetectorTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Patterns that indicate debug/logging statements left in production code.
|
||||
fn patterns() -> Vec<&'static str> {
|
||||
vec![
|
||||
"console.log(",
|
||||
"console.debug(",
|
||||
"console.error(",
|
||||
"console.warn(",
|
||||
"console.info(",
|
||||
"console.trace(",
|
||||
"console.dir(",
|
||||
"console.table(",
|
||||
"debugger;",
|
||||
"alert(",
|
||||
]
|
||||
}
|
||||
|
||||
/// Extract JavaScript file URLs from an HTML page body.
|
||||
fn extract_js_urls(html: &str, base_url: &str) -> Vec<String> {
|
||||
let mut urls = Vec::new();
|
||||
let base = url::Url::parse(base_url).ok();
|
||||
|
||||
// Simple regex-free extraction of <script src="...">
|
||||
let mut search_from = 0;
|
||||
while let Some(start) = html[search_from..].find("src=") {
|
||||
let abs_start = search_from + start + 4;
|
||||
if abs_start >= html.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let quote = html.as_bytes().get(abs_start).copied();
|
||||
let (open, close) = match quote {
|
||||
Some(b'"') => ('"', '"'),
|
||||
Some(b'\'') => ('\'', '\''),
|
||||
_ => {
|
||||
search_from = abs_start + 1;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let val_start = abs_start + 1;
|
||||
if let Some(end) = html[val_start..].find(close) {
|
||||
let src = &html[val_start..val_start + end];
|
||||
if src.ends_with(".js") || src.contains(".js?") || src.contains("/js/") {
|
||||
let full_url = if src.starts_with("http://") || src.starts_with("https://") {
|
||||
src.to_string()
|
||||
} else if src.starts_with("//") {
|
||||
format!("https:{src}")
|
||||
} else if let Some(ref base) = base {
|
||||
base.join(src).map(|u| u.to_string()).unwrap_or_default()
|
||||
} else {
|
||||
format!("{base_url}/{}", src.trim_start_matches('/'))
|
||||
};
|
||||
if !full_url.is_empty() {
|
||||
urls.push(full_url);
|
||||
}
|
||||
}
|
||||
search_from = val_start + end + 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
urls
|
||||
}
|
||||
|
||||
/// Search a JS file's contents for console/debug patterns.
|
||||
fn scan_js_content(content: &str, file_url: &str) -> Vec<ConsoleMatch> {
|
||||
let mut matches = Vec::new();
|
||||
|
||||
for (line_num, line) in content.lines().enumerate() {
|
||||
let trimmed = line.trim();
|
||||
// Skip comments (basic heuristic)
|
||||
if trimmed.starts_with("//") || trimmed.starts_with('*') || trimmed.starts_with("/*") {
|
||||
continue;
|
||||
}
|
||||
|
||||
for pattern in Self::patterns() {
|
||||
if line.contains(pattern) {
|
||||
let snippet = if line.len() > 200 {
|
||||
format!("{}...", &line[..200])
|
||||
} else {
|
||||
line.to_string()
|
||||
};
|
||||
matches.push(ConsoleMatch {
|
||||
pattern: pattern.trim_end_matches('(').to_string(),
|
||||
file_url: file_url.to_string(),
|
||||
line_snippet: snippet.trim().to_string(),
|
||||
line_number: Some(line_num + 1),
|
||||
});
|
||||
break; // One match per line is enough
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matches
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for ConsoleLogDetectorTool {
|
||||
fn name(&self) -> &str {
|
||||
"console_log_detector"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Detects console.log, console.debug, console.error, debugger, and similar debug \
|
||||
statements left in production JavaScript. Fetches the HTML page and referenced JS files."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the page to check for console.log leakage"
|
||||
},
|
||||
"additional_js_urls": {
|
||||
"type": "array",
|
||||
"description": "Optional additional JavaScript file URLs to scan",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_js: Vec<String> = input
|
||||
.get("additional_js_urls")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Fetch the main page
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let html = response.text().await.unwrap_or_default();
|
||||
|
||||
// Scan inline scripts in the HTML
|
||||
let mut all_matches = Vec::new();
|
||||
let inline_matches = Self::scan_js_content(&html, url);
|
||||
all_matches.extend(inline_matches);
|
||||
|
||||
// Extract JS file URLs from the HTML
|
||||
let mut js_urls = Self::extract_js_urls(&html, url);
|
||||
js_urls.extend(additional_js);
|
||||
js_urls.dedup();
|
||||
|
||||
// Fetch and scan each JS file
|
||||
for js_url in &js_urls {
|
||||
match self.http.get(js_url).send().await {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
let js_content = resp.text().await.unwrap_or_default();
|
||||
// Only scan non-minified-looking files or files where we can still
|
||||
// find patterns (minifiers typically strip console calls, but not always)
|
||||
let file_matches = Self::scan_js_content(&js_content, js_url);
|
||||
all_matches.extend(file_matches);
|
||||
}
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let match_data: Vec<serde_json::Value> = all_matches
|
||||
.iter()
|
||||
.map(|m| {
|
||||
json!({
|
||||
"pattern": m.pattern,
|
||||
"file": m.file_url,
|
||||
"line": m.line_number,
|
||||
"snippet": m.line_snippet,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !all_matches.is_empty() {
|
||||
// Group by file for the finding
|
||||
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
|
||||
std::collections::HashMap::new();
|
||||
for m in &all_matches {
|
||||
by_file.entry(&m.file_url).or_default().push(m);
|
||||
}
|
||||
|
||||
for (file_url, matches) in &by_file {
|
||||
let pattern_summary: Vec<String> = matches
|
||||
.iter()
|
||||
.take(5)
|
||||
.map(|m| {
|
||||
format!(
|
||||
" Line {}: {} - {}",
|
||||
m.line_number.unwrap_or(0),
|
||||
m.pattern,
|
||||
if m.line_snippet.len() > 80 {
|
||||
format!("{}...", &m.line_snippet[..80])
|
||||
} else {
|
||||
m.line_snippet.clone()
|
||||
}
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: file_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(pattern_summary.join("\n")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let total = matches.len();
|
||||
let extra = if total > 5 {
|
||||
format!(" (and {} more)", total - 5)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::ConsoleLogLeakage,
|
||||
format!("Console/debug statements in {}", file_url),
|
||||
format!(
|
||||
"Found {total} console/debug statements in {file_url}{extra}. \
|
||||
These can leak sensitive information such as API responses, user data, \
|
||||
or internal state to anyone with browser developer tools open."
|
||||
),
|
||||
Severity::Low,
|
||||
file_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-532".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove console.log/debug/error statements from production code. \
|
||||
Use a build step (e.g., babel plugin, terser) to strip console calls \
|
||||
during the production build."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
let total_matches = all_matches.len();
|
||||
let count = findings.len();
|
||||
info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if total_matches > 0 {
|
||||
format!(
|
||||
"Found {total_matches} console/debug statements across {} files.",
|
||||
count
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"No console/debug statements found in HTML or {} JS files.",
|
||||
js_urls.len()
|
||||
)
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"total_matches": total_matches,
|
||||
"js_files_scanned": js_urls.len(),
|
||||
"matches": match_data,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
401
compliance-dast/src/tools/cookie_analyzer.rs
Normal file
401
compliance-dast/src/tools/cookie_analyzer.rs
Normal file
@@ -0,0 +1,401 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
/// Tool that inspects cookies set by a target for security issues.
|
||||
pub struct CookieAnalyzerTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
/// Parsed attributes from a Set-Cookie header.
|
||||
#[derive(Debug)]
|
||||
struct ParsedCookie {
|
||||
name: String,
|
||||
value: String,
|
||||
secure: bool,
|
||||
http_only: bool,
|
||||
same_site: Option<String>,
|
||||
domain: Option<String>,
|
||||
path: Option<String>,
|
||||
raw: String,
|
||||
}
|
||||
|
||||
impl CookieAnalyzerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Parse a Set-Cookie header value into a structured representation.
|
||||
fn parse_set_cookie(header: &str) -> ParsedCookie {
|
||||
let raw = header.to_string();
|
||||
let parts: Vec<&str> = header.split(';').collect();
|
||||
|
||||
let (name, value) = if let Some(kv) = parts.first() {
|
||||
let mut kv_split = kv.splitn(2, '=');
|
||||
let k = kv_split.next().unwrap_or("").trim().to_string();
|
||||
let v = kv_split.next().unwrap_or("").trim().to_string();
|
||||
(k, v)
|
||||
} else {
|
||||
(String::new(), String::new())
|
||||
};
|
||||
|
||||
let mut secure = false;
|
||||
let mut http_only = false;
|
||||
let mut same_site = None;
|
||||
let mut domain = None;
|
||||
let mut path = None;
|
||||
|
||||
for part in parts.iter().skip(1) {
|
||||
let trimmed = part.trim().to_lowercase();
|
||||
if trimmed == "secure" {
|
||||
secure = true;
|
||||
} else if trimmed == "httponly" {
|
||||
http_only = true;
|
||||
} else if let Some(ss) = trimmed.strip_prefix("samesite=") {
|
||||
same_site = Some(ss.trim().to_string());
|
||||
} else if let Some(d) = trimmed.strip_prefix("domain=") {
|
||||
domain = Some(d.trim().to_string());
|
||||
} else if let Some(p) = trimmed.strip_prefix("path=") {
|
||||
path = Some(p.trim().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
ParsedCookie {
|
||||
name,
|
||||
value,
|
||||
secure,
|
||||
http_only,
|
||||
same_site,
|
||||
domain,
|
||||
path,
|
||||
raw,
|
||||
}
|
||||
}
|
||||
|
||||
/// Heuristic: does this cookie name suggest it's a session / auth cookie?
|
||||
fn is_sensitive_cookie(name: &str) -> bool {
|
||||
let lower = name.to_lowercase();
|
||||
lower.contains("session")
|
||||
|| lower.contains("sess")
|
||||
|| lower.contains("token")
|
||||
|| lower.contains("auth")
|
||||
|| lower.contains("jwt")
|
||||
|| lower.contains("csrf")
|
||||
|| lower.contains("sid")
|
||||
|| lower == "connect.sid"
|
||||
|| lower == "phpsessid"
|
||||
|| lower == "jsessionid"
|
||||
|| lower == "asp.net_sessionid"
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CookieAnalyzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"cookie_analyzer"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Analyzes cookies set by a target URL. Checks for Secure, HttpOnly, SameSite attributes \
|
||||
and overly broad Domain/Path settings. Focuses on session and authentication cookies."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to fetch and analyze cookies from"
|
||||
},
|
||||
"login_url": {
|
||||
"type": "string",
|
||||
"description": "Optional login URL to also check (may set auth cookies)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let login_url = input.get("login_url").and_then(|v| v.as_str());
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut cookie_data = Vec::new();
|
||||
|
||||
// Collect Set-Cookie headers from the main URL and optional login URL
|
||||
let urls_to_check: Vec<&str> = std::iter::once(url)
|
||||
.chain(login_url.into_iter())
|
||||
.collect();
|
||||
|
||||
for check_url in &urls_to_check {
|
||||
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
|
||||
let no_redirect_client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.timeout(std::time::Duration::from_secs(15))
|
||||
.build()
|
||||
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
|
||||
|
||||
let response = match no_redirect_client.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// Try with the main client that follows redirects
|
||||
match self.http.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let set_cookie_headers: Vec<String> = response
|
||||
.headers()
|
||||
.get_all("set-cookie")
|
||||
.iter()
|
||||
.filter_map(|v| v.to_str().ok().map(String::from))
|
||||
.collect();
|
||||
|
||||
for raw_cookie in &set_cookie_headers {
|
||||
let cookie = Self::parse_set_cookie(raw_cookie);
|
||||
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
|
||||
let is_https = check_url.starts_with("https://");
|
||||
|
||||
let cookie_info = json!({
|
||||
"name": cookie.name,
|
||||
"secure": cookie.secure,
|
||||
"http_only": cookie.http_only,
|
||||
"same_site": cookie.same_site,
|
||||
"domain": cookie.domain,
|
||||
"path": cookie.path,
|
||||
"is_sensitive": is_sensitive,
|
||||
"url": check_url,
|
||||
});
|
||||
cookie_data.push(cookie_info);
|
||||
|
||||
// Check: missing Secure flag
|
||||
if !cookie.secure && (is_https || is_sensitive) {
|
||||
let severity = if is_sensitive {
|
||||
Severity::High
|
||||
} else {
|
||||
Severity::Medium
|
||||
};
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' missing Secure flag", cookie.name),
|
||||
format!(
|
||||
"The cookie '{}' does not have the Secure attribute set. \
|
||||
Without this flag, the cookie can be transmitted over unencrypted HTTP connections.",
|
||||
cookie.name
|
||||
),
|
||||
severity,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-614".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
|
||||
cookie is only sent over HTTPS connections."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check: missing HttpOnly flag on sensitive cookies
|
||||
if !cookie.http_only && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' missing HttpOnly flag", cookie.name),
|
||||
format!(
|
||||
"The session/auth cookie '{}' does not have the HttpOnly attribute. \
|
||||
This makes it accessible to JavaScript, increasing the impact of XSS attacks.",
|
||||
cookie.name
|
||||
),
|
||||
Severity::High,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
|
||||
JavaScript access to the cookie."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check: missing or weak SameSite
|
||||
if is_sensitive {
|
||||
let weak_same_site = match &cookie.same_site {
|
||||
None => true,
|
||||
Some(ss) => ss == "none",
|
||||
};
|
||||
|
||||
if weak_same_site {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let desc = if cookie.same_site.is_none() {
|
||||
format!(
|
||||
"The session/auth cookie '{}' does not have a SameSite attribute. \
|
||||
This may allow cross-site request forgery (CSRF) attacks.",
|
||||
cookie.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"The session/auth cookie '{}' has SameSite=None, which allows it \
|
||||
to be sent in cross-site requests, enabling CSRF attacks.",
|
||||
cookie.name
|
||||
)
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' missing or weak SameSite", cookie.name),
|
||||
desc,
|
||||
Severity::Medium,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1275".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
|
||||
to prevent cross-site request inclusion."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
// Check: overly broad domain
|
||||
if let Some(ref domain) = cookie.domain {
|
||||
// A domain starting with a dot applies to all subdomains
|
||||
let dot_domain = domain.starts_with('.');
|
||||
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
|
||||
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
|
||||
if dot_domain && parts.len() <= 2 && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' has overly broad domain", cookie.name),
|
||||
format!(
|
||||
"The cookie '{}' is scoped to domain '{}' which includes all \
|
||||
subdomains. If any subdomain is compromised, the attacker can \
|
||||
access this cookie.",
|
||||
cookie.name, domain
|
||||
),
|
||||
Severity::Low,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Restrict the cookie domain to the specific subdomain that needs it \
|
||||
rather than the entire parent domain."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Cookie analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} cookie security issues.")
|
||||
} else if cookie_data.is_empty() {
|
||||
"No cookies were set by the target.".to_string()
|
||||
} else {
|
||||
"All cookies have proper security attributes.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"cookies": cookie_data,
|
||||
"total_cookies": cookie_data.len(),
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
410
compliance-dast/src/tools/cors_checker.rs
Normal file
410
compliance-dast/src/tools/cors_checker.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Tool that checks CORS configuration for security issues.
|
||||
pub struct CorsCheckerTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl CorsCheckerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Origins to test against the target.
|
||||
fn test_origins(target_host: &str) -> Vec<(&'static str, String)> {
|
||||
vec![
|
||||
("null_origin", "null".to_string()),
|
||||
("evil_domain", "https://evil.com".to_string()),
|
||||
(
|
||||
"subdomain_spoof",
|
||||
format!("https://{target_host}.evil.com"),
|
||||
),
|
||||
(
|
||||
"prefix_spoof",
|
||||
format!("https://evil-{target_host}"),
|
||||
),
|
||||
("http_downgrade", format!("http://{target_host}")),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CorsCheckerTool {
|
||||
fn name(&self) -> &str {
|
||||
"cors_checker"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Checks CORS configuration by sending requests with various Origin headers. Tests for \
|
||||
wildcard origins, reflected origins, null origin acceptance, and dangerous \
|
||||
Access-Control-Allow-Credentials combinations."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to test CORS configuration on"
|
||||
},
|
||||
"additional_origins": {
|
||||
"type": "array",
|
||||
"description": "Optional additional origin values to test",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_origins: Vec<String> = input
|
||||
.get("additional_origins")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let target_host = url::Url::parse(url)
|
||||
.ok()
|
||||
.and_then(|u| u.host_str().map(String::from))
|
||||
.unwrap_or_else(|| url.to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut cors_data: Vec<serde_json::Value> = Vec::new();
|
||||
|
||||
// First, send a request without Origin to get baseline
|
||||
let baseline = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let baseline_acao = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"origin": null,
|
||||
"acao": baseline_acao,
|
||||
}));
|
||||
|
||||
// Check for wildcard + credentials (dangerous combo)
|
||||
if let Some(ref acao) = baseline_acao {
|
||||
if acao == "*" {
|
||||
let acac = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if acac.to_lowercase() == "true" {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: baseline.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
"CORS wildcard with credentials".to_string(),
|
||||
format!(
|
||||
"The endpoint {url} returns Access-Control-Allow-Origin: * with \
|
||||
Access-Control-Allow-Credentials: true. While browsers should block this \
|
||||
combination, it indicates a serious CORS misconfiguration."
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Never combine Access-Control-Allow-Origin: * with \
|
||||
Access-Control-Allow-Credentials: true. Specify explicit allowed origins."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with various Origin headers
|
||||
let mut test_origins = Self::test_origins(&target_host);
|
||||
for origin in &additional_origins {
|
||||
test_origins.push(("custom", origin.clone()));
|
||||
}
|
||||
|
||||
for (test_name, origin) in &test_origins {
|
||||
let resp = match self
|
||||
.http
|
||||
.get(url)
|
||||
.header("Origin", origin.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
let acao = resp
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acac = resp
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"test": test_name,
|
||||
"origin": origin,
|
||||
"acao": acao,
|
||||
"acac": acac,
|
||||
"acam": acam,
|
||||
"status": status,
|
||||
}));
|
||||
|
||||
// Check if the origin was reflected back
|
||||
if let Some(ref acao_val) = acao {
|
||||
let origin_reflected = acao_val == origin;
|
||||
let credentials_allowed = acac
|
||||
.as_ref()
|
||||
.map(|v| v.to_lowercase() == "true")
|
||||
.unwrap_or(false);
|
||||
|
||||
if origin_reflected && *test_name != "http_downgrade" {
|
||||
let severity = if credentials_allowed {
|
||||
Severity::Critical
|
||||
} else {
|
||||
Severity::High
|
||||
};
|
||||
|
||||
let resp_headers: HashMap<String, String> = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(resp_headers),
|
||||
response_snippet: Some(format!(
|
||||
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
|
||||
Access-Control-Allow-Credentials: {}",
|
||||
acac.as_deref().unwrap_or("not set")
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let title = match *test_name {
|
||||
"null_origin" => {
|
||||
"CORS accepts null origin".to_string()
|
||||
}
|
||||
"evil_domain" => {
|
||||
"CORS reflects arbitrary origin".to_string()
|
||||
}
|
||||
"subdomain_spoof" => {
|
||||
"CORS vulnerable to subdomain spoofing".to_string()
|
||||
}
|
||||
"prefix_spoof" => {
|
||||
"CORS vulnerable to prefix spoofing".to_string()
|
||||
}
|
||||
_ => format!("CORS reflects untrusted origin ({test_name})"),
|
||||
};
|
||||
|
||||
let cred_note = if credentials_allowed {
|
||||
" Combined with Access-Control-Allow-Credentials: true, this allows \
|
||||
the attacker to steal authenticated data."
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
title,
|
||||
format!(
|
||||
"The endpoint {url} reflects the Origin header '{origin}' back in \
|
||||
Access-Control-Allow-Origin, allowing cross-origin requests from \
|
||||
untrusted domains.{cred_note}"
|
||||
),
|
||||
severity,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.exploitable = credentials_allowed;
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Validate the Origin header against a whitelist of trusted origins. \
|
||||
Do not reflect the Origin header value directly. Use specific allowed \
|
||||
origins instead of wildcards."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
|
||||
warn!(
|
||||
url,
|
||||
test_name,
|
||||
origin,
|
||||
credentials = credentials_allowed,
|
||||
"CORS misconfiguration detected"
|
||||
);
|
||||
}
|
||||
|
||||
// Special case: HTTP downgrade
|
||||
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
"CORS allows HTTP origin with credentials".to_string(),
|
||||
format!(
|
||||
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
|
||||
credentials. An attacker performing a man-in-the-middle attack on \
|
||||
the HTTP version could steal authenticated data."
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
|
||||
origin validation enforces the https:// scheme."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also send a preflight OPTIONS request
|
||||
if let Ok(resp) = self
|
||||
.http
|
||||
.request(reqwest::Method::OPTIONS, url)
|
||||
.header("Origin", "https://evil.com")
|
||||
.header("Access-Control-Request-Method", "POST")
|
||||
.header("Access-Control-Request-Headers", "Authorization, Content-Type")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acah = resp
|
||||
.headers()
|
||||
.get("access-control-allow-headers")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"test": "preflight",
|
||||
"status": resp.status().as_u16(),
|
||||
"allow_methods": acam,
|
||||
"allow_headers": acah,
|
||||
}));
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CORS check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CORS misconfiguration issues for {url}.")
|
||||
} else {
|
||||
format!("CORS configuration appears secure for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"tests": cors_data,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
447
compliance-dast/src/tools/csp_analyzer.rs
Normal file
447
compliance-dast/src/tools/csp_analyzer.rs
Normal file
@@ -0,0 +1,447 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
/// Tool that analyzes Content-Security-Policy headers.
|
||||
pub struct CspAnalyzerTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
/// A parsed CSP directive.
|
||||
#[derive(Debug)]
|
||||
struct CspDirective {
|
||||
name: String,
|
||||
values: Vec<String>,
|
||||
}
|
||||
|
||||
impl CspAnalyzerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Parse a CSP header string into directives.
|
||||
fn parse_csp(csp: &str) -> Vec<CspDirective> {
|
||||
let mut directives = Vec::new();
|
||||
for part in csp.split(';') {
|
||||
let trimmed = part.trim();
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let tokens: Vec<&str> = trimmed.split_whitespace().collect();
|
||||
if let Some((name, values)) = tokens.split_first() {
|
||||
directives.push(CspDirective {
|
||||
name: name.to_lowercase(),
|
||||
values: values.iter().map(|v| v.to_string()).collect(),
|
||||
});
|
||||
}
|
||||
}
|
||||
directives
|
||||
}
|
||||
|
||||
/// Check a CSP for common issues and return findings.
|
||||
fn analyze_directives(
|
||||
directives: &[CspDirective],
|
||||
url: &str,
|
||||
target_id: &str,
|
||||
status: u16,
|
||||
csp_raw: &str,
|
||||
) -> Vec<DastFinding> {
|
||||
let mut findings = Vec::new();
|
||||
|
||||
let make_evidence = |snippet: String| DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(snippet),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
// Check for unsafe-inline in script-src
|
||||
for d in directives {
|
||||
if (d.name == "script-src" || d.name == "default-src")
|
||||
&& d.values.iter().any(|v| v == "'unsafe-inline'")
|
||||
{
|
||||
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP allows 'unsafe-inline' in {}", d.name),
|
||||
format!(
|
||||
"The Content-Security-Policy directive '{}' includes 'unsafe-inline', \
|
||||
which defeats the purpose of CSP by allowing inline scripts that \
|
||||
could be exploited via XSS.",
|
||||
d.name
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-79".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove 'unsafe-inline' from script-src. Use nonces or hashes for \
|
||||
legitimate inline scripts instead."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check for unsafe-eval
|
||||
if (d.name == "script-src" || d.name == "default-src")
|
||||
&& d.values.iter().any(|v| v == "'unsafe-eval'")
|
||||
{
|
||||
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP allows 'unsafe-eval' in {}", d.name),
|
||||
format!(
|
||||
"The Content-Security-Policy directive '{}' includes 'unsafe-eval', \
|
||||
which allows the use of eval() and similar dynamic code execution \
|
||||
that can be exploited via XSS.",
|
||||
d.name
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-79".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove 'unsafe-eval' from script-src. Refactor code to avoid eval(), \
|
||||
Function(), and similar constructs."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check for wildcard sources
|
||||
if d.values.iter().any(|v| v == "*") {
|
||||
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP wildcard source in {}", d.name),
|
||||
format!(
|
||||
"The Content-Security-Policy directive '{}' uses a wildcard '*' source, \
|
||||
which allows loading resources from any origin and largely negates CSP protection.",
|
||||
d.name
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Replace the wildcard '*' in {} with specific allowed origins.",
|
||||
d.name
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check for http: sources (non-HTTPS)
|
||||
if d.values.iter().any(|v| v == "http:") {
|
||||
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP allows HTTP sources in {}", d.name),
|
||||
format!(
|
||||
"The Content-Security-Policy directive '{}' allows loading resources \
|
||||
over unencrypted HTTP, which can be exploited via man-in-the-middle attacks.",
|
||||
d.name
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Replace 'http:' with 'https:' in {} to enforce encrypted resource loading.",
|
||||
d.name
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check for data: in script-src (can be used to bypass CSP)
|
||||
if (d.name == "script-src" || d.name == "default-src")
|
||||
&& d.values.iter().any(|v| v == "data:")
|
||||
{
|
||||
let evidence = make_evidence(format!("{}: {}", d.name, d.values.join(" ")));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP allows data: URIs in {}", d.name),
|
||||
format!(
|
||||
"The Content-Security-Policy directive '{}' allows 'data:' URIs, \
|
||||
which can be used to bypass CSP and execute arbitrary scripts.",
|
||||
d.name
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-79".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Remove 'data:' from {}. If data URIs are needed, restrict them to \
|
||||
non-executable content types only (e.g., img-src).",
|
||||
d.name
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for missing important directives
|
||||
let directive_names: Vec<&str> = directives.iter().map(|d| d.name.as_str()).collect();
|
||||
let has_default_src = directive_names.contains(&"default-src");
|
||||
|
||||
let important_directives = [
|
||||
("script-src", "Controls which scripts can execute"),
|
||||
("object-src", "Controls plugins like Flash"),
|
||||
("base-uri", "Controls the base URL for relative URLs"),
|
||||
("form-action", "Controls where forms can submit to"),
|
||||
("frame-ancestors", "Controls who can embed this page in iframes"),
|
||||
];
|
||||
|
||||
for (dir_name, desc) in &important_directives {
|
||||
if !directive_names.contains(dir_name)
|
||||
&& !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action")
|
||||
{
|
||||
let evidence = make_evidence(format!("CSP missing directive: {dir_name}"));
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.to_string(),
|
||||
DastVulnType::CspIssue,
|
||||
format!("CSP missing '{}' directive", dir_name),
|
||||
format!(
|
||||
"The Content-Security-Policy is missing the '{}' directive. {}. \
|
||||
Without this directive{}, the browser may fall back to less restrictive defaults.",
|
||||
dir_name,
|
||||
desc,
|
||||
if has_default_src && (*dir_name == "frame-ancestors" || *dir_name == "base-uri" || *dir_name == "form-action") {
|
||||
" (not covered by default-src)"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
),
|
||||
Severity::Low,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Add '{}: 'none'' or an appropriate restrictive value to your CSP.",
|
||||
dir_name
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
findings
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CspAnalyzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"csp_analyzer"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Analyzes Content-Security-Policy headers. Checks for unsafe-inline, unsafe-eval, \
|
||||
wildcard sources, data: URIs in script-src, missing directives, and other CSP weaknesses."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to fetch and analyze CSP from"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
|
||||
// Check for CSP header
|
||||
let csp_header = response
|
||||
.headers()
|
||||
.get("content-security-policy")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
// Also check for report-only variant
|
||||
let csp_report_only = response
|
||||
.headers()
|
||||
.get("content-security-policy-report-only")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut csp_data = json!({});
|
||||
|
||||
match &csp_header {
|
||||
Some(csp) => {
|
||||
let directives = Self::parse_csp(csp);
|
||||
let directive_map: serde_json::Value = directives
|
||||
.iter()
|
||||
.map(|d| (d.name.clone(), json!(d.values)))
|
||||
.collect::<serde_json::Map<String, serde_json::Value>>()
|
||||
.into();
|
||||
|
||||
csp_data["csp_header"] = json!(csp);
|
||||
csp_data["directives"] = directive_map;
|
||||
|
||||
findings.extend(Self::analyze_directives(
|
||||
&directives,
|
||||
url,
|
||||
&target_id,
|
||||
status,
|
||||
csp,
|
||||
));
|
||||
}
|
||||
None => {
|
||||
csp_data["csp_header"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some("Content-Security-Policy header is missing".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"Missing Content-Security-Policy header".to_string(),
|
||||
format!(
|
||||
"No Content-Security-Policy header is present on {url}. \
|
||||
Without CSP, the browser has no instructions on which sources are \
|
||||
trusted, making XSS exploitation much easier."
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add a Content-Security-Policy header. Start with a restrictive policy like \
|
||||
\"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \
|
||||
object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref report_only) = csp_report_only {
|
||||
csp_data["csp_report_only"] = json!(report_only);
|
||||
// If ONLY report-only exists (no enforcing CSP), warn
|
||||
if csp_header.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Content-Security-Policy-Report-Only: {}",
|
||||
report_only
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"CSP is report-only, not enforcing".to_string(),
|
||||
"A Content-Security-Policy-Report-Only header is present but no enforcing \
|
||||
Content-Security-Policy header exists. Report-only mode only logs violations \
|
||||
but does not block them."
|
||||
.to_string(),
|
||||
Severity::Low,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Once you have verified the CSP policy works correctly in report-only mode, \
|
||||
deploy it as an enforcing Content-Security-Policy header."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CSP analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CSP issues for {url}.")
|
||||
} else {
|
||||
format!("Content-Security-Policy looks good for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: csp_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
401
compliance-dast/src/tools/dmarc_checker.rs
Normal file
401
compliance-dast/src/tools/dmarc_checker.rs
Normal file
@@ -0,0 +1,401 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Tool that checks email security configuration (DMARC and SPF records).
|
||||
pub struct DmarcCheckerTool;
|
||||
|
||||
impl DmarcCheckerTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Query TXT records for a given name using `dig`.
|
||||
async fn query_txt(name: &str) -> Result<Vec<String>, CoreError> {
|
||||
let output = tokio::process::Command::new("dig")
|
||||
.args(["+short", "TXT", name])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("dig command failed: {e}")))?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = stdout
|
||||
.lines()
|
||||
.map(|l| l.trim().trim_matches('"').to_string())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect();
|
||||
Ok(lines)
|
||||
}
|
||||
|
||||
/// Parse a DMARC record string and return the policy value.
|
||||
fn parse_dmarc_policy(record: &str) -> Option<String> {
|
||||
for part in record.split(';') {
|
||||
let part = part.trim();
|
||||
if let Some(val) = part.strip_prefix("p=") {
|
||||
return Some(val.trim().to_lowercase());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse DMARC record for sub-domain policy (sp=).
|
||||
fn parse_dmarc_subdomain_policy(record: &str) -> Option<String> {
|
||||
for part in record.split(';') {
|
||||
let part = part.trim();
|
||||
if let Some(val) = part.strip_prefix("sp=") {
|
||||
return Some(val.trim().to_lowercase());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse DMARC record for reporting URI (rua=).
|
||||
fn parse_dmarc_rua(record: &str) -> Option<String> {
|
||||
for part in record.split(';') {
|
||||
let part = part.trim();
|
||||
if let Some(val) = part.strip_prefix("rua=") {
|
||||
return Some(val.trim().to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if an SPF record is present and parse the policy.
|
||||
fn is_spf_record(record: &str) -> bool {
|
||||
record.starts_with("v=spf1")
|
||||
}
|
||||
|
||||
/// Evaluate SPF record strength.
|
||||
fn spf_uses_soft_fail(record: &str) -> bool {
|
||||
record.contains("~all")
|
||||
}
|
||||
|
||||
fn spf_allows_all(record: &str) -> bool {
|
||||
record.contains("+all")
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for DmarcCheckerTool {
|
||||
fn name(&self) -> &str {
|
||||
"dmarc_checker"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Checks email security configuration for a domain. Queries DMARC and SPF records and \
|
||||
evaluates policy strength. Reports missing or weak email authentication settings."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": "The domain to check (e.g., 'example.com')"
|
||||
}
|
||||
},
|
||||
"required": ["domain"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut email_data = json!({});
|
||||
|
||||
// ---- DMARC check ----
|
||||
let dmarc_domain = format!("_dmarc.{domain}");
|
||||
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
|
||||
|
||||
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
|
||||
|
||||
match dmarc_record {
|
||||
Some(record) => {
|
||||
email_data["dmarc_record"] = json!(record);
|
||||
|
||||
let policy = Self::parse_dmarc_policy(record);
|
||||
let sp = Self::parse_dmarc_subdomain_policy(record);
|
||||
let rua = Self::parse_dmarc_rua(record);
|
||||
|
||||
email_data["dmarc_policy"] = json!(policy);
|
||||
email_data["dmarc_subdomain_policy"] = json!(sp);
|
||||
email_data["dmarc_rua"] = json!(rua);
|
||||
|
||||
// Warn on weak policy
|
||||
if let Some(ref p) = policy {
|
||||
if p == "none" {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Weak DMARC policy for {domain}"),
|
||||
format!(
|
||||
"The DMARC policy for {domain} is set to 'none', which only monitors \
|
||||
but does not enforce email authentication. Attackers can spoof emails \
|
||||
from this domain."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
|
||||
'p=reject' after verifying legitimate email flows are properly \
|
||||
authenticated."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "DMARC policy is 'none'");
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if no reporting URI
|
||||
if rua.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("DMARC without reporting for {domain}"),
|
||||
format!(
|
||||
"The DMARC record for {domain} does not include a reporting URI (rua=). \
|
||||
Without reporting, you will not receive aggregate feedback about email \
|
||||
authentication failures."
|
||||
),
|
||||
Severity::Info,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-778".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
|
||||
Example: 'rua=mailto:dmarc-reports@example.com'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
email_data["dmarc_record"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No DMARC record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Missing DMARC record for {domain}"),
|
||||
format!(
|
||||
"No DMARC record was found for {domain}. Without DMARC, there is no \
|
||||
policy to prevent email spoofing and phishing attacks using this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
|
||||
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No DMARC record found");
|
||||
}
|
||||
}
|
||||
|
||||
// ---- SPF check ----
|
||||
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
|
||||
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
|
||||
|
||||
match spf_record {
|
||||
Some(record) => {
|
||||
email_data["spf_record"] = json!(record);
|
||||
|
||||
if Self::spf_allows_all(record) {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("SPF allows all senders for {domain}"),
|
||||
format!(
|
||||
"The SPF record for {domain} uses '+all' which allows any server to \
|
||||
send email on behalf of this domain, completely negating SPF protection."
|
||||
),
|
||||
Severity::Critical,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \
|
||||
Only list authorized mail servers."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
} else if Self::spf_uses_soft_fail(record) {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("SPF soft fail for {domain}"),
|
||||
format!(
|
||||
"The SPF record for {domain} uses '~all' (soft fail) instead of \
|
||||
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
|
||||
but does not reject them."
|
||||
),
|
||||
Severity::Low,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Consider changing '~all' to '-all' in your SPF record once you have \
|
||||
confirmed all legitimate mail sources are listed."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
email_data["spf_record"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No SPF record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Missing SPF record for {domain}"),
|
||||
format!(
|
||||
"No SPF record was found for {domain}. Without SPF, any server can claim \
|
||||
to send email on behalf of this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Create an SPF TXT record for your domain. Example: \
|
||||
'v=spf1 include:_spf.google.com -all'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No SPF record found");
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DMARC/SPF check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} email security issues for {domain}.")
|
||||
} else {
|
||||
format!("Email security configuration looks good for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: email_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
389
compliance-dast/src/tools/dns_checker.rs
Normal file
389
compliance-dast/src/tools/dns_checker.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tokio::net::lookup_host;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Tool that checks DNS configuration for security issues.
|
||||
///
|
||||
/// Resolves A, AAAA, MX, TXT, CNAME, NS records using the system resolver
|
||||
/// via `tokio::net::lookup_host` and `std::net::ToSocketAddrs`. For TXT-based
|
||||
/// records (SPF, DMARC, CAA, DNSSEC) it uses a simple TXT query via the
|
||||
/// `tokio::process::Command` wrapper around `dig` where available.
|
||||
pub struct DnsCheckerTool;
|
||||
|
||||
impl DnsCheckerTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
/// Run a `dig` query and return the answer lines.
|
||||
async fn dig_query(domain: &str, record_type: &str) -> Result<Vec<String>, CoreError> {
|
||||
let output = tokio::process::Command::new("dig")
|
||||
.args(["+short", record_type, domain])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("dig command failed: {e}")))?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let lines: Vec<String> = stdout
|
||||
.lines()
|
||||
.map(|l| l.trim().to_string())
|
||||
.filter(|l| !l.is_empty())
|
||||
.collect();
|
||||
Ok(lines)
|
||||
}
|
||||
|
||||
/// Resolve A/AAAA records using tokio lookup.
|
||||
async fn resolve_addresses(domain: &str) -> Result<(Vec<String>, Vec<String>), CoreError> {
|
||||
let mut ipv4 = Vec::new();
|
||||
let mut ipv6 = Vec::new();
|
||||
|
||||
let addr_str = format!("{domain}:443");
|
||||
match lookup_host(&addr_str).await {
|
||||
Ok(addrs) => {
|
||||
for addr in addrs {
|
||||
match addr {
|
||||
std::net::SocketAddr::V4(v4) => ipv4.push(v4.ip().to_string()),
|
||||
std::net::SocketAddr::V6(v6) => ipv6.push(v6.ip().to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}")));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((ipv4, ipv6))
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for DnsCheckerTool {
|
||||
fn name(&self) -> &str {
|
||||
"dns_checker"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Checks DNS configuration for a domain. Resolves A, AAAA, MX, TXT, CNAME, NS records. \
|
||||
Checks for DNSSEC, CAA records, and potential subdomain takeover via dangling CNAME/NS."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": "The domain to check (e.g., 'example.com')"
|
||||
},
|
||||
"subdomains": {
|
||||
"type": "array",
|
||||
"description": "Optional list of subdomains to also check (e.g., ['www', 'api', 'mail'])",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["domain"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
|
||||
|
||||
let subdomains: Vec<String> = input
|
||||
.get("subdomains")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// --- A / AAAA records ---
|
||||
match Self::resolve_addresses(domain).await {
|
||||
Ok((ipv4, ipv6)) => {
|
||||
dns_data.insert("a_records".to_string(), json!(ipv4));
|
||||
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- MX records ---
|
||||
match Self::dig_query(domain, "MX").await {
|
||||
Ok(mx) => {
|
||||
dns_data.insert("mx_records".to_string(), json!(mx));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- NS records ---
|
||||
let ns_records = match Self::dig_query(domain, "NS").await {
|
||||
Ok(ns) => {
|
||||
dns_data.insert("ns_records".to_string(), json!(ns));
|
||||
ns
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// --- TXT records ---
|
||||
match Self::dig_query(domain, "TXT").await {
|
||||
Ok(txt) => {
|
||||
dns_data.insert("txt_records".to_string(), json!(txt));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- CNAME records (for subdomains) ---
|
||||
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
|
||||
let mut domains_to_check = vec![domain.to_string()];
|
||||
for sub in &subdomains {
|
||||
domains_to_check.push(format!("{sub}.{domain}"));
|
||||
}
|
||||
|
||||
for fqdn in &domains_to_check {
|
||||
match Self::dig_query(fqdn, "CNAME").await {
|
||||
Ok(cnames) if !cnames.is_empty() => {
|
||||
// Check for dangling CNAME
|
||||
for cname in &cnames {
|
||||
let cname_clean = cname.trim_end_matches('.');
|
||||
let check_addr = format!("{cname_clean}:443");
|
||||
let is_dangling = lookup_host(&check_addr).await.is_err();
|
||||
if is_dangling {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: fqdn.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"CNAME {fqdn} -> {cname} (target does not resolve)"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("Dangling CNAME on {fqdn}"),
|
||||
format!(
|
||||
"The subdomain {fqdn} has a CNAME record pointing to {cname} which does not resolve. \
|
||||
This may allow subdomain takeover if an attacker can claim the target hostname."
|
||||
),
|
||||
Severity::High,
|
||||
fqdn.clone(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove dangling CNAME records or ensure the target hostname is \
|
||||
properly configured and resolvable."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover");
|
||||
}
|
||||
}
|
||||
cname_data.insert(fqdn.clone(), cnames);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if !cname_data.is_empty() {
|
||||
dns_data.insert("cname_records".to_string(), json!(cname_data));
|
||||
}
|
||||
|
||||
// --- CAA records ---
|
||||
match Self::dig_query(domain, "CAA").await {
|
||||
Ok(caa) => {
|
||||
if caa.is_empty() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No CAA records found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("Missing CAA records for {domain}"),
|
||||
format!(
|
||||
"No CAA (Certificate Authority Authorization) records are set for {domain}. \
|
||||
Without CAA records, any certificate authority can issue certificates for this domain."
|
||||
),
|
||||
Severity::Low,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add CAA DNS records to restrict which certificate authorities can issue \
|
||||
certificates for your domain. Example: '0 issue \"letsencrypt.org\"'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
dns_data.insert("caa_records".to_string(), json!(caa));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- DNSSEC check ---
|
||||
let dnssec_output = tokio::process::Command::new("dig")
|
||||
.args(["+dnssec", "+short", "DNSKEY", domain])
|
||||
.output()
|
||||
.await;
|
||||
|
||||
match dnssec_output {
|
||||
Ok(output) => {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let has_dnssec = !stdout.trim().is_empty();
|
||||
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
|
||||
|
||||
if !has_dnssec {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("DNSSEC not enabled for {domain}"),
|
||||
format!(
|
||||
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
|
||||
can be spoofed, allowing man-in-the-middle attacks."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-350".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
|
||||
with your DNS provider and domain registrar."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
|
||||
}
|
||||
}
|
||||
|
||||
// --- Check NS records for dangling ---
|
||||
for ns in &ns_records {
|
||||
let ns_clean = ns.trim_end_matches('.');
|
||||
let check_addr = format!("{ns_clean}:53");
|
||||
if lookup_host(&check_addr).await.is_err() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"NS record {ns} does not resolve"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("Dangling NS record for {domain}"),
|
||||
format!(
|
||||
"The NS record {ns} for {domain} does not resolve. \
|
||||
This could allow domain takeover if an attacker can claim the nameserver hostname."
|
||||
),
|
||||
Severity::Critical,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove dangling NS records or ensure the nameserver hostname is properly \
|
||||
configured. Dangling NS records can lead to full domain takeover."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, ns, "Dangling NS record detected - potential domain takeover");
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DNS check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} DNS configuration issues for {domain}.")
|
||||
} else {
|
||||
format!("No DNS configuration issues found for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(dns_data),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
141
compliance-dast/src/tools/mod.rs
Normal file
141
compliance-dast/src/tools/mod.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
pub mod api_fuzzer;
|
||||
pub mod auth_bypass;
|
||||
pub mod console_log_detector;
|
||||
pub mod cookie_analyzer;
|
||||
pub mod cors_checker;
|
||||
pub mod csp_analyzer;
|
||||
pub mod dmarc_checker;
|
||||
pub mod dns_checker;
|
||||
pub mod openapi_parser;
|
||||
pub mod rate_limit_tester;
|
||||
pub mod recon;
|
||||
pub mod security_headers;
|
||||
pub mod sql_injection;
|
||||
pub mod ssrf;
|
||||
pub mod tls_analyzer;
|
||||
pub mod xss;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use compliance_core::traits::pentest_tool::PentestTool;
|
||||
|
||||
/// A definition describing a tool for LLM tool_use registration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Registry that holds all available pentest tools and provides
|
||||
/// look-up by name.
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn PentestTool>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Create a new registry with all built-in tools pre-registered.
|
||||
pub fn new() -> Self {
|
||||
let http = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.redirect(reqwest::redirect::Policy::limited(5))
|
||||
.build()
|
||||
.expect("failed to build HTTP client");
|
||||
|
||||
let mut tools: HashMap<String, Box<dyn PentestTool>> = HashMap::new();
|
||||
|
||||
// Agent-wrapping tools
|
||||
let register = |tools: &mut HashMap<String, Box<dyn PentestTool>>,
|
||||
tool: Box<dyn PentestTool>| {
|
||||
tools.insert(tool.name().to_string(), tool);
|
||||
};
|
||||
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(sql_injection::SqlInjectionTool::new(http.clone())),
|
||||
);
|
||||
register(&mut tools, Box::new(xss::XssTool::new(http.clone())));
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(auth_bypass::AuthBypassTool::new(http.clone())),
|
||||
);
|
||||
register(&mut tools, Box::new(ssrf::SsrfTool::new(http.clone())));
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(api_fuzzer::ApiFuzzerTool::new(http.clone())),
|
||||
);
|
||||
|
||||
// New infrastructure / analysis tools
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(dns_checker::DnsCheckerTool::new()),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(dmarc_checker::DmarcCheckerTool::new()),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(tls_analyzer::TlsAnalyzerTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(security_headers::SecurityHeadersTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(cookie_analyzer::CookieAnalyzerTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(csp_analyzer::CspAnalyzerTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(rate_limit_tester::RateLimitTesterTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(console_log_detector::ConsoleLogDetectorTool::new(
|
||||
http.clone(),
|
||||
)),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(cors_checker::CorsCheckerTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(openapi_parser::OpenApiParserTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(recon::ReconTool::new(http)),
|
||||
);
|
||||
|
||||
Self { tools }
|
||||
}
|
||||
|
||||
/// Look up a tool by name.
|
||||
pub fn get(&self, name: &str) -> Option<&dyn PentestTool> {
|
||||
self.tools.get(name).map(|b| b.as_ref())
|
||||
}
|
||||
|
||||
/// Return definitions for every registered tool.
|
||||
pub fn all_definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|t| ToolDefinition {
|
||||
name: t.name().to_string(),
|
||||
description: t.description().to_string(),
|
||||
input_schema: t.input_schema(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return the names of all registered tools.
|
||||
pub fn list_names(&self) -> Vec<String> {
|
||||
self.tools.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
422
compliance-dast/src/tools/openapi_parser.rs
Normal file
422
compliance-dast/src/tools/openapi_parser.rs
Normal file
@@ -0,0 +1,422 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
/// Tool that discovers and parses OpenAPI/Swagger specification files.
|
||||
///
|
||||
/// Returns structured endpoint definitions for the LLM to use when planning
|
||||
/// further tests. This tool produces data rather than security findings.
|
||||
pub struct OpenApiParserTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
/// A parsed endpoint from an OpenAPI spec.
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParsedEndpoint {
|
||||
path: String,
|
||||
method: String,
|
||||
operation_id: Option<String>,
|
||||
summary: Option<String>,
|
||||
parameters: Vec<ParsedParameter>,
|
||||
request_body_content_type: Option<String>,
|
||||
response_codes: Vec<String>,
|
||||
security: Vec<String>,
|
||||
tags: Vec<String>,
|
||||
}
|
||||
|
||||
/// A parsed parameter from an OpenAPI spec.
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParsedParameter {
|
||||
name: String,
|
||||
location: String,
|
||||
required: bool,
|
||||
param_type: Option<String>,
|
||||
description: Option<String>,
|
||||
}
|
||||
|
||||
impl OpenApiParserTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Common paths where OpenAPI/Swagger specs are typically served.
|
||||
fn common_spec_paths() -> Vec<&'static str> {
|
||||
vec![
|
||||
"/openapi.json",
|
||||
"/openapi.yaml",
|
||||
"/swagger.json",
|
||||
"/swagger.yaml",
|
||||
"/api-docs",
|
||||
"/api-docs.json",
|
||||
"/v2/api-docs",
|
||||
"/v3/api-docs",
|
||||
"/docs/openapi.json",
|
||||
"/api/swagger.json",
|
||||
"/api/openapi.json",
|
||||
"/api/v1/openapi.json",
|
||||
"/api/v2/openapi.json",
|
||||
"/.well-known/openapi.json",
|
||||
]
|
||||
}
|
||||
|
||||
/// Try to fetch a spec from a URL and return the JSON value if successful.
|
||||
async fn try_fetch_spec(
|
||||
http: &reqwest::Client,
|
||||
url: &str,
|
||||
) -> Option<(String, serde_json::Value)> {
|
||||
let resp = http.get(url).send().await.ok()?;
|
||||
if !resp.status().is_success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content_type = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let body = resp.text().await.ok()?;
|
||||
|
||||
// Try JSON first
|
||||
if let Ok(val) = serde_json::from_str::<serde_json::Value>(&body) {
|
||||
// Verify it looks like an OpenAPI / Swagger spec
|
||||
if val.get("openapi").is_some()
|
||||
|| val.get("swagger").is_some()
|
||||
|| val.get("paths").is_some()
|
||||
{
|
||||
return Some((url.to_string(), val));
|
||||
}
|
||||
}
|
||||
|
||||
// If content type suggests YAML, we can't easily parse without a YAML dep,
|
||||
// so just report the URL as found
|
||||
if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") {
|
||||
// Return a minimal JSON indicating YAML was found
|
||||
return Some((
|
||||
url.to_string(),
|
||||
json!({
|
||||
"_note": "YAML spec detected but not parsed. Fetch and convert to JSON.",
|
||||
"_raw_url": url,
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints.
|
||||
fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec<ParsedEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
||||
// Determine base path
|
||||
let base_path = if let Some(servers) = spec.get("servers").and_then(|v| v.as_array()) {
|
||||
servers
|
||||
.first()
|
||||
.and_then(|s| s.get("url"))
|
||||
.and_then(|u| u.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
} else if let Some(bp) = spec.get("basePath").and_then(|v| v.as_str()) {
|
||||
bp.to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let paths = match spec.get("paths").and_then(|v| v.as_object()) {
|
||||
Some(p) => p,
|
||||
None => return endpoints,
|
||||
};
|
||||
|
||||
for (path, path_item) in paths {
|
||||
let path_obj = match path_item.as_object() {
|
||||
Some(o) => o,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Path-level parameters
|
||||
let path_params = path_obj
|
||||
.get("parameters")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
for method in &["get", "post", "put", "patch", "delete", "head", "options"] {
|
||||
let operation = match path_obj.get(*method).and_then(|v| v.as_object()) {
|
||||
Some(o) => o,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let operation_id = operation
|
||||
.get("operationId")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
let summary = operation
|
||||
.get("summary")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
let tags: Vec<String> = operation
|
||||
.get("tags")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|t| t.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
// Merge path-level and operation-level parameters
|
||||
let mut parameters = Vec::new();
|
||||
let op_params = operation
|
||||
.get("parameters")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
for param_val in path_params.iter().chain(op_params.iter()) {
|
||||
let name = param_val
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let location = param_val
|
||||
.get("in")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string();
|
||||
let required = param_val
|
||||
.get("required")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Type from schema or direct type field
|
||||
let param_type = param_val
|
||||
.get("schema")
|
||||
.and_then(|s| s.get("type"))
|
||||
.or_else(|| param_val.get("type"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
let description = param_val
|
||||
.get("description")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
parameters.push(ParsedParameter {
|
||||
name,
|
||||
location,
|
||||
required,
|
||||
param_type,
|
||||
description,
|
||||
});
|
||||
}
|
||||
|
||||
// Request body (OpenAPI 3.x)
|
||||
let request_body_content_type = operation
|
||||
.get("requestBody")
|
||||
.and_then(|rb| rb.get("content"))
|
||||
.and_then(|c| c.as_object())
|
||||
.and_then(|obj| obj.keys().next().cloned());
|
||||
|
||||
// Response codes
|
||||
let response_codes: Vec<String> = operation
|
||||
.get("responses")
|
||||
.and_then(|r| r.as_object())
|
||||
.map(|obj| obj.keys().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Security requirements
|
||||
let security: Vec<String> = operation
|
||||
.get("security")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|s| s.as_object())
|
||||
.flat_map(|obj| obj.keys().cloned())
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
endpoints.push(ParsedEndpoint {
|
||||
path: format!("{}{}", base_path, path),
|
||||
method: method.to_uppercase(),
|
||||
operation_id,
|
||||
summary,
|
||||
parameters,
|
||||
request_body_content_type,
|
||||
response_codes,
|
||||
security,
|
||||
tags,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for OpenApiParserTool {
|
||||
fn name(&self) -> &str {
|
||||
"openapi_parser"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Discovers and parses OpenAPI/Swagger specifications. Tries common spec paths and \
|
||||
returns structured endpoint definitions including parameters, methods, and security \
|
||||
requirements. Use this to discover all API endpoints before testing."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"base_url": {
|
||||
"type": "string",
|
||||
"description": "Base URL of the API to discover specs from"
|
||||
},
|
||||
"spec_url": {
|
||||
"type": "string",
|
||||
"description": "Optional explicit URL of the OpenAPI/Swagger spec file"
|
||||
}
|
||||
},
|
||||
"required": ["base_url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let base_url = input
|
||||
.get("base_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?;
|
||||
|
||||
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
|
||||
|
||||
let base_url_trimmed = base_url.trim_end_matches('/');
|
||||
|
||||
// If an explicit spec URL is provided, try it first
|
||||
let mut spec_result: Option<(String, serde_json::Value)> = None;
|
||||
if let Some(spec_url) = explicit_spec_url {
|
||||
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
|
||||
}
|
||||
|
||||
// If no explicit URL or it failed, try common paths
|
||||
if spec_result.is_none() {
|
||||
for path in Self::common_spec_paths() {
|
||||
let url = format!("{base_url_trimmed}{path}");
|
||||
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
|
||||
spec_result = Some(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match spec_result {
|
||||
Some((spec_url, spec)) => {
|
||||
let spec_version = spec
|
||||
.get("openapi")
|
||||
.or_else(|| spec.get("swagger"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let api_title = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("title"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("Unknown API");
|
||||
|
||||
let api_version = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("version"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
|
||||
|
||||
let endpoint_data: Vec<serde_json::Value> = endpoints
|
||||
.iter()
|
||||
.map(|ep| {
|
||||
let params: Vec<serde_json::Value> = ep
|
||||
.parameters
|
||||
.iter()
|
||||
.map(|p| {
|
||||
json!({
|
||||
"name": p.name,
|
||||
"in": p.location,
|
||||
"required": p.required,
|
||||
"type": p.param_type,
|
||||
"description": p.description,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"path": ep.path,
|
||||
"method": ep.method,
|
||||
"operation_id": ep.operation_id,
|
||||
"summary": ep.summary,
|
||||
"parameters": params,
|
||||
"request_body_content_type": ep.request_body_content_type,
|
||||
"response_codes": ep.response_codes,
|
||||
"security": ep.security,
|
||||
"tags": ep.tags,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let endpoint_count = endpoints.len();
|
||||
info!(
|
||||
spec_url = %spec_url,
|
||||
spec_version,
|
||||
api_title,
|
||||
endpoints = endpoint_count,
|
||||
"OpenAPI spec parsed"
|
||||
);
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
|
||||
API: {api_title} v{api_version}. \
|
||||
Parsed {endpoint_count} endpoints."
|
||||
),
|
||||
findings: Vec::new(), // This tool produces data, not findings
|
||||
data: json!({
|
||||
"spec_url": spec_url,
|
||||
"spec_version": spec_version,
|
||||
"api_title": api_title,
|
||||
"api_version": api_version,
|
||||
"endpoint_count": endpoint_count,
|
||||
"endpoints": endpoint_data,
|
||||
"security_schemes": spec.get("components")
|
||||
.and_then(|c| c.get("securitySchemes"))
|
||||
.or_else(|| spec.get("securityDefinitions")),
|
||||
}),
|
||||
})
|
||||
}
|
||||
None => {
|
||||
info!(base_url, "No OpenAPI spec found");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"No OpenAPI/Swagger specification found for {base_url}. \
|
||||
Tried {} common paths.",
|
||||
Self::common_spec_paths().len()
|
||||
),
|
||||
findings: Vec::new(),
|
||||
data: json!({
|
||||
"spec_found": false,
|
||||
"paths_tried": Self::common_spec_paths(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
285
compliance-dast/src/tools/rate_limit_tester.rs
Normal file
285
compliance-dast/src/tools/rate_limit_tester.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Tool that tests whether a target enforces rate limiting.
|
||||
pub struct RateLimitTesterTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl RateLimitTesterTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for RateLimitTesterTool {
|
||||
fn name(&self) -> &str {
|
||||
"rate_limit_tester"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Tests whether an endpoint enforces rate limiting by sending rapid sequential requests. \
|
||||
Checks for 429 responses and measures response time degradation."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the endpoint to test for rate limiting"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "HTTP method to use",
|
||||
"enum": ["GET", "POST", "PUT", "PATCH", "DELETE"],
|
||||
"default": "GET"
|
||||
},
|
||||
"request_count": {
|
||||
"type": "integer",
|
||||
"description": "Number of rapid requests to send (default: 50)",
|
||||
"default": 50,
|
||||
"minimum": 10,
|
||||
"maximum": 200
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Optional request body (for POST/PUT/PATCH)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let method = input
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
|
||||
let request_count = input
|
||||
.get("request_count")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(50)
|
||||
.min(200) as usize;
|
||||
|
||||
let body = input.get("body").and_then(|v| v.as_str());
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Respect the context rate limit if set
|
||||
let max_requests = if context.rate_limit > 0 {
|
||||
request_count.min(context.rate_limit as usize * 5)
|
||||
} else {
|
||||
request_count
|
||||
};
|
||||
|
||||
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
|
||||
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
|
||||
let mut got_429 = false;
|
||||
let mut rate_limit_at_request: Option<usize> = None;
|
||||
|
||||
for i in 0..max_requests {
|
||||
let start = Instant::now();
|
||||
|
||||
let request = match method {
|
||||
"POST" => {
|
||||
let mut req = self.http.post(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PUT" => {
|
||||
let mut req = self.http.put(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PATCH" => {
|
||||
let mut req = self.http.patch(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"DELETE" => self.http.delete(url),
|
||||
_ => self.http.get(url),
|
||||
};
|
||||
|
||||
match request.send().await {
|
||||
Ok(resp) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
let status = resp.status().as_u16();
|
||||
status_codes.push(status);
|
||||
response_times.push(elapsed);
|
||||
|
||||
if status == 429 && !got_429 {
|
||||
got_429 = true;
|
||||
rate_limit_at_request = Some(i + 1);
|
||||
info!(url, request_num = i + 1, "Rate limit triggered (429)");
|
||||
}
|
||||
|
||||
// Check for rate limit headers even on 200
|
||||
if !got_429 {
|
||||
let headers = resp.headers();
|
||||
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|
||||
|| headers.contains_key("x-ratelimit-remaining")
|
||||
|| headers.contains_key("ratelimit-limit")
|
||||
|| headers.contains_key("ratelimit-remaining")
|
||||
|| headers.contains_key("retry-after");
|
||||
|
||||
if has_rate_headers && rate_limit_at_request.is_none() {
|
||||
// Server has rate limit headers but hasn't blocked yet
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
status_codes.push(0);
|
||||
response_times.push(elapsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let total_sent = status_codes.len();
|
||||
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
|
||||
let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count();
|
||||
|
||||
// Calculate response time statistics
|
||||
let avg_time = if !response_times.is_empty() {
|
||||
response_times.iter().sum::<u128>() / response_times.len() as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let first_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[..half].iter().sum::<u128>() / half as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
let second_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
// Significant time degradation suggests possible (weak) rate limiting
|
||||
let time_degradation = if first_half_avg > 0 {
|
||||
(second_half_avg as f64 / first_half_avg as f64) - 1.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let rate_data = json!({
|
||||
"total_requests_sent": total_sent,
|
||||
"status_429_count": count_429,
|
||||
"success_count": count_success,
|
||||
"rate_limit_at_request": rate_limit_at_request,
|
||||
"avg_response_time_ms": avg_time,
|
||||
"first_half_avg_ms": first_half_avg,
|
||||
"second_half_avg_ms": second_half_avg,
|
||||
"time_degradation_pct": (time_degradation * 100.0).round(),
|
||||
});
|
||||
|
||||
if !got_429 && count_success == total_sent {
|
||||
// No rate limiting detected at all
|
||||
let evidence = DastEvidence {
|
||||
request_method: method.to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: body.map(String::from),
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Sent {total_sent} rapid requests. All returned success (2xx). \
|
||||
No 429 responses received. Avg response time: {avg_time}ms."
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: Some(avg_time as u64),
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::RateLimitAbsent,
|
||||
format!("No rate limiting on {} {}", method, url),
|
||||
format!(
|
||||
"The endpoint {} {} does not enforce rate limiting. \
|
||||
{total_sent} rapid requests were all accepted with no 429 responses \
|
||||
or noticeable degradation. This makes the endpoint vulnerable to \
|
||||
brute force attacks and abuse.",
|
||||
method, url
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
method.to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-770".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
|
||||
algorithms. Return 429 Too Many Requests with a Retry-After header when the \
|
||||
limit is exceeded."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(url, method, total_sent, "No rate limiting detected");
|
||||
} else if got_429 {
|
||||
info!(
|
||||
url,
|
||||
method,
|
||||
rate_limit_at = ?rate_limit_at_request,
|
||||
"Rate limiting is enforced"
|
||||
);
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if got_429 {
|
||||
format!(
|
||||
"Rate limiting is enforced on {method} {url}. \
|
||||
429 response received after {} requests.",
|
||||
rate_limit_at_request.unwrap_or(0)
|
||||
)
|
||||
} else if count > 0 {
|
||||
format!(
|
||||
"No rate limiting detected on {method} {url} after {total_sent} requests."
|
||||
)
|
||||
} else {
|
||||
format!("Rate limit testing complete for {method} {url}.")
|
||||
},
|
||||
findings,
|
||||
data: rate_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
125
compliance-dast/src/tools/recon.rs
Normal file
125
compliance-dast/src/tools/recon.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
use crate::recon::ReconAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing ReconAgent.
|
||||
///
|
||||
/// Performs HTTP header fingerprinting and technology detection.
|
||||
/// Returns structured recon data for the LLM to use when planning attacks.
|
||||
pub struct ReconTool {
|
||||
http: reqwest::Client,
|
||||
agent: ReconAgent,
|
||||
}
|
||||
|
||||
impl ReconTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = ReconAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for ReconTool {
|
||||
fn name(&self) -> &str {
|
||||
"recon"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Performs reconnaissance on a target URL. Fingerprints HTTP headers, detects server \
|
||||
technologies and frameworks. Returns structured data about the target's technology stack."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Base URL to perform reconnaissance on"
|
||||
},
|
||||
"additional_paths": {
|
||||
"type": "array",
|
||||
"description": "Optional additional paths to probe for technology fingerprinting",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_paths: Vec<String> = input
|
||||
.get("additional_paths")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let result = self.agent.scan(url).await?;
|
||||
|
||||
// Scan additional paths for more technology signals
|
||||
let mut extra_technologies: Vec<String> = Vec::new();
|
||||
let mut extra_headers: HashMap<String, String> = HashMap::new();
|
||||
|
||||
let base_url = url.trim_end_matches('/');
|
||||
for path in &additional_paths {
|
||||
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
|
||||
if let Ok(resp) = self.http.get(&probe_url).send().await {
|
||||
for (key, value) in resp.headers() {
|
||||
let k = key.to_string().to_lowercase();
|
||||
let v = value.to_str().unwrap_or("").to_string();
|
||||
|
||||
// Look for technology indicators
|
||||
if k == "x-powered-by" || k == "server" || k == "x-generator" {
|
||||
if !result.technologies.contains(&v) && !extra_technologies.contains(&v) {
|
||||
extra_technologies.push(v.clone());
|
||||
}
|
||||
}
|
||||
extra_headers.insert(format!("{probe_url} -> {k}"), v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut all_technologies = result.technologies.clone();
|
||||
all_technologies.extend(extra_technologies);
|
||||
all_technologies.dedup();
|
||||
|
||||
let tech_count = all_technologies.len();
|
||||
info!(url, technologies = tech_count, "Recon complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Recon complete for {url}. Detected {} technologies. Server: {}.",
|
||||
tech_count,
|
||||
result.server.as_deref().unwrap_or("unknown")
|
||||
),
|
||||
findings: Vec::new(), // Recon produces data, not findings
|
||||
data: json!({
|
||||
"base_url": url,
|
||||
"server": result.server,
|
||||
"technologies": all_technologies,
|
||||
"interesting_headers": result.interesting_headers,
|
||||
"extra_headers": extra_headers,
|
||||
"open_ports": result.open_ports,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
300
compliance-dast/src/tools/security_headers.rs
Normal file
300
compliance-dast/src/tools/security_headers.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tracing::info;
|
||||
|
||||
/// Tool that checks for the presence and correctness of security headers.
|
||||
pub struct SecurityHeadersTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
/// A security header we expect to be present and its metadata.
|
||||
struct ExpectedHeader {
|
||||
name: &'static str,
|
||||
description: &'static str,
|
||||
severity: Severity,
|
||||
cwe: &'static str,
|
||||
remediation: &'static str,
|
||||
/// If present, the value must contain one of these substrings to be considered valid.
|
||||
valid_values: Option<Vec<&'static str>>,
|
||||
}
|
||||
|
||||
impl SecurityHeadersTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
fn expected_headers() -> Vec<ExpectedHeader> {
|
||||
vec![
|
||||
ExpectedHeader {
|
||||
name: "strict-transport-security",
|
||||
description: "HTTP Strict Transport Security (HSTS) forces browsers to use HTTPS",
|
||||
severity: Severity::Medium,
|
||||
cwe: "CWE-319",
|
||||
remediation: "Add 'Strict-Transport-Security: max-age=31536000; includeSubDomains' header.",
|
||||
valid_values: None,
|
||||
},
|
||||
ExpectedHeader {
|
||||
name: "x-content-type-options",
|
||||
description: "Prevents MIME type sniffing",
|
||||
severity: Severity::Low,
|
||||
cwe: "CWE-16",
|
||||
remediation: "Add 'X-Content-Type-Options: nosniff' header.",
|
||||
valid_values: Some(vec!["nosniff"]),
|
||||
},
|
||||
ExpectedHeader {
|
||||
name: "x-frame-options",
|
||||
description: "Prevents clickjacking by controlling iframe embedding",
|
||||
severity: Severity::Medium,
|
||||
cwe: "CWE-1021",
|
||||
remediation: "Add 'X-Frame-Options: DENY' or 'X-Frame-Options: SAMEORIGIN' header.",
|
||||
valid_values: Some(vec!["deny", "sameorigin"]),
|
||||
},
|
||||
ExpectedHeader {
|
||||
name: "x-xss-protection",
|
||||
description: "Enables browser XSS filtering (legacy but still recommended)",
|
||||
severity: Severity::Low,
|
||||
cwe: "CWE-79",
|
||||
remediation: "Add 'X-XSS-Protection: 1; mode=block' header.",
|
||||
valid_values: None,
|
||||
},
|
||||
ExpectedHeader {
|
||||
name: "referrer-policy",
|
||||
description: "Controls how much referrer information is shared",
|
||||
severity: Severity::Low,
|
||||
cwe: "CWE-200",
|
||||
remediation: "Add 'Referrer-Policy: strict-origin-when-cross-origin' or 'no-referrer' header.",
|
||||
valid_values: None,
|
||||
},
|
||||
ExpectedHeader {
|
||||
name: "permissions-policy",
|
||||
description: "Controls browser feature access (camera, microphone, geolocation, etc.)",
|
||||
severity: Severity::Low,
|
||||
cwe: "CWE-16",
|
||||
remediation: "Add a Permissions-Policy header to restrict browser feature access. \
|
||||
Example: 'Permissions-Policy: camera=(), microphone=(), geolocation=()'.",
|
||||
valid_values: None,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for SecurityHeadersTool {
|
||||
fn name(&self) -> &str {
|
||||
"security_headers"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Checks a URL for the presence and correctness of security headers: HSTS, \
|
||||
X-Content-Type-Options, X-Frame-Options, X-XSS-Protection, Referrer-Policy, \
|
||||
and Permissions-Policy."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to check security headers for"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let response_headers: HashMap<String, String> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
|
||||
for expected in Self::expected_headers() {
|
||||
let header_value = response_headers.get(expected.name);
|
||||
|
||||
match header_value {
|
||||
Some(value) => {
|
||||
let mut is_valid = true;
|
||||
if let Some(ref valid) = expected.valid_values {
|
||||
let lower = value.to_lowercase();
|
||||
is_valid = valid.iter().any(|v| lower.contains(v));
|
||||
}
|
||||
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": true,
|
||||
"value": value,
|
||||
"valid": is_valid,
|
||||
}),
|
||||
);
|
||||
|
||||
if !is_valid {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{}: {}", expected.name, value)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Invalid {} header value", expected.name),
|
||||
format!(
|
||||
"The {} header is present but has an invalid or weak value: '{}'. \
|
||||
{}",
|
||||
expected.name, value, expected.description
|
||||
),
|
||||
expected.severity.clone(),
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some(expected.cwe.to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(expected.remediation.to_string());
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": false,
|
||||
"value": null,
|
||||
"valid": false,
|
||||
}),
|
||||
);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{} header is missing", expected.name)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Missing {} header", expected.name),
|
||||
format!(
|
||||
"The {} header is not present in the response. {}",
|
||||
expected.name, expected.description
|
||||
),
|
||||
expected.severity.clone(),
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some(expected.cwe.to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(expected.remediation.to_string());
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for information disclosure headers
|
||||
let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"];
|
||||
for h in &disclosure_headers {
|
||||
if let Some(value) = response_headers.get(*h) {
|
||||
header_results.insert(
|
||||
format!("{h}_disclosure"),
|
||||
json!({ "present": true, "value": value }),
|
||||
);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{h}: {value}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Information disclosure via {h} header"),
|
||||
format!(
|
||||
"The {h} header exposes server technology information: '{value}'. \
|
||||
This helps attackers fingerprint the server and find known vulnerabilities."
|
||||
),
|
||||
Severity::Info,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-200".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Remove or suppress the {h} header in your server configuration."
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Security headers check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} security header issues for {url}.")
|
||||
} else {
|
||||
format!("All checked security headers are present and valid for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(header_results),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
138
compliance-dast/src/tools/sql_injection.rs
Normal file
138
compliance-dast/src/tools/sql_injection.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::injection::SqlInjectionAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing SqlInjectionAgent.
|
||||
pub struct SqlInjectionTool {
|
||||
http: reqwest::Client,
|
||||
agent: SqlInjectionAgent,
|
||||
}
|
||||
|
||||
impl SqlInjectionTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = SqlInjectionAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string();
|
||||
let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from);
|
||||
let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from);
|
||||
parameters.push(EndpointParameter {
|
||||
name,
|
||||
location,
|
||||
param_type,
|
||||
example_value,
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for SqlInjectionTool {
|
||||
fn name(&self) -> &str {
|
||||
"sql_injection_scanner"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Tests endpoints for SQL injection vulnerabilities using error-based, boolean-based, \
|
||||
time-based, and union-based techniques. Provide endpoints with their parameters to test."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"endpoints": {
|
||||
"type": "array",
|
||||
"description": "Endpoints to test for SQL injection",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string", "description": "Full URL of the endpoint" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"location": { "type": "string", "enum": ["query", "body", "header", "path", "cookie"] },
|
||||
"param_type": { "type": "string" },
|
||||
"example_value": { "type": "string" }
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["url", "method", "parameters"]
|
||||
}
|
||||
},
|
||||
"custom_payloads": {
|
||||
"type": "array",
|
||||
"description": "Optional additional SQL injection payloads to test",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["endpoints"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SQL injection vulnerabilities.")
|
||||
} else {
|
||||
"No SQL injection vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
134
compliance-dast/src/tools/ssrf.rs
Normal file
134
compliance-dast/src/tools/ssrf.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::ssrf::SsrfAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing SsrfAgent.
|
||||
pub struct SsrfTool {
|
||||
http: reqwest::Client,
|
||||
agent: SsrfAgent,
|
||||
}
|
||||
|
||||
impl SsrfTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = SsrfAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for SsrfTool {
|
||||
fn name(&self) -> &str {
|
||||
"ssrf_scanner"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Tests endpoints for Server-Side Request Forgery (SSRF) vulnerabilities. Checks if \
|
||||
parameters accepting URLs can be exploited to access internal resources."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"endpoints": {
|
||||
"type": "array",
|
||||
"description": "Endpoints to test for SSRF (focus on those accepting URL parameters)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"location": { "type": "string" },
|
||||
"param_type": { "type": "string" },
|
||||
"example_value": { "type": "string" }
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["url", "method", "parameters"]
|
||||
}
|
||||
},
|
||||
"custom_payloads": {
|
||||
"type": "array",
|
||||
"description": "Optional additional SSRF payloads (internal URLs to try)",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["endpoints"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SSRF vulnerabilities.")
|
||||
} else {
|
||||
"No SSRF vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
442
compliance-dast/src/tools/tls_analyzer.rs
Normal file
442
compliance-dast/src/tools/tls_analyzer.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType};
|
||||
use compliance_core::models::Severity;
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Tool that analyzes TLS configuration of a target.
|
||||
///
|
||||
/// Connects via TCP, performs a TLS handshake using `tokio-native-tls`,
|
||||
/// and inspects the certificate and negotiated protocol. Also checks
|
||||
/// for common TLS misconfigurations.
|
||||
pub struct TlsAnalyzerTool {
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl TlsAnalyzerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
Self { http }
|
||||
}
|
||||
|
||||
/// Extract the hostname from a URL.
|
||||
fn extract_host(url: &str) -> Option<String> {
|
||||
url::Url::parse(url)
|
||||
.ok()
|
||||
.and_then(|u| u.host_str().map(String::from))
|
||||
}
|
||||
|
||||
/// Extract port from a URL (defaults to 443 for https).
|
||||
fn extract_port(url: &str) -> u16 {
|
||||
url::Url::parse(url)
|
||||
.ok()
|
||||
.and_then(|u| u.port())
|
||||
.unwrap_or(443)
|
||||
}
|
||||
|
||||
/// Check if the server accepts a connection on a given port with a weak
|
||||
/// TLS client hello. We test SSLv3 / old protocol support by attempting
|
||||
/// connection with the system's native-tls which typically negotiates the
|
||||
/// best available, then inspect what was negotiated.
|
||||
async fn check_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
) -> Result<TlsInfo, CoreError> {
|
||||
let addr = format!("{host}:{port}");
|
||||
|
||||
let tcp = TcpStream::connect(&addr)
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("TCP connection to {addr} failed: {e}")))?;
|
||||
|
||||
let connector = native_tls::TlsConnector::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.danger_accept_invalid_hostnames(true)
|
||||
.build()
|
||||
.map_err(|e| CoreError::Dast(format!("TLS connector build failed: {e}")))?;
|
||||
|
||||
let connector = tokio_native_tls::TlsConnector::from(connector);
|
||||
|
||||
let tls_stream = connector
|
||||
.connect(host, tcp)
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?;
|
||||
|
||||
let peer_cert = tls_stream.get_ref().peer_certificate()
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?;
|
||||
|
||||
let mut tls_info = TlsInfo {
|
||||
protocol_version: String::new(),
|
||||
cert_subject: String::new(),
|
||||
cert_issuer: String::new(),
|
||||
cert_not_before: String::new(),
|
||||
cert_not_after: String::new(),
|
||||
cert_expired: false,
|
||||
cert_self_signed: false,
|
||||
alpn_protocol: None,
|
||||
san_names: Vec::new(),
|
||||
};
|
||||
|
||||
if let Some(cert) = peer_cert {
|
||||
let der = cert.to_der()
|
||||
.map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?;
|
||||
|
||||
// native_tls doesn't give rich access, so we parse what we can
|
||||
// from the DER-encoded certificate.
|
||||
tls_info.cert_subject = "see DER certificate".to_string();
|
||||
// Attempt to parse with basic DER inspection for dates
|
||||
tls_info = Self::parse_cert_der(&der, tls_info);
|
||||
}
|
||||
|
||||
Ok(tls_info)
|
||||
}
|
||||
|
||||
/// Best-effort parse of DER-encoded X.509 certificate for dates and subject.
|
||||
/// This is a simplified parser; in production you would use a proper x509 crate.
|
||||
fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo {
|
||||
// We rely on the native_tls debug output stored in cert_subject
|
||||
// and just mark fields as "see certificate details"
|
||||
if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") {
|
||||
info.cert_self_signed = true;
|
||||
}
|
||||
info
|
||||
}
|
||||
}
|
||||
|
||||
struct TlsInfo {
|
||||
protocol_version: String,
|
||||
cert_subject: String,
|
||||
cert_issuer: String,
|
||||
cert_not_before: String,
|
||||
cert_not_after: String,
|
||||
cert_expired: bool,
|
||||
cert_self_signed: bool,
|
||||
alpn_protocol: Option<String>,
|
||||
san_names: Vec<String>,
|
||||
}
|
||||
|
||||
impl PentestTool for TlsAnalyzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"tls_analyzer"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Analyzes TLS/SSL configuration of a target. Checks certificate validity, expiry, chain \
|
||||
trust, and negotiated protocols. Reports TLS misconfigurations."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Target URL or hostname to analyze TLS configuration"
|
||||
},
|
||||
"port": {
|
||||
"type": "integer",
|
||||
"description": "Port to connect to (default: 443)",
|
||||
"default": 443
|
||||
},
|
||||
"check_protocols": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to test for old/weak protocol versions",
|
||||
"default": true
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let host = Self::extract_host(url)
|
||||
.unwrap_or_else(|| url.to_string());
|
||||
|
||||
let port = input
|
||||
.get("port")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|p| p as u16)
|
||||
.unwrap_or_else(|| Self::extract_port(url));
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut tls_data = json!({});
|
||||
|
||||
// First check: does the server even support HTTPS?
|
||||
let https_url = if url.starts_with("https://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("http://") {
|
||||
url.replace("http://", "https://")
|
||||
} else {
|
||||
format!("https://{url}")
|
||||
};
|
||||
|
||||
// Check if HTTP redirects to HTTPS
|
||||
let http_url = if url.starts_with("http://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("https://") {
|
||||
url.replace("https://", "http://")
|
||||
} else {
|
||||
format!("http://{url}")
|
||||
};
|
||||
|
||||
match self.http.get(&http_url).send().await {
|
||||
Ok(resp) => {
|
||||
let final_url = resp.url().to_string();
|
||||
let redirects_to_https = final_url.starts_with("https://");
|
||||
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
|
||||
|
||||
if !redirects_to_https {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: http_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!("Final URL: {final_url}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("HTTP does not redirect to HTTPS for {host}"),
|
||||
format!(
|
||||
"HTTP requests to {host} are not redirected to HTTPS. \
|
||||
Users accessing the site via HTTP will have their traffic \
|
||||
transmitted in cleartext."
|
||||
),
|
||||
Severity::Medium,
|
||||
http_url.clone(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Configure the web server to redirect all HTTP requests to HTTPS \
|
||||
using a 301 redirect."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["http_check_error"] = json!("Could not connect via HTTP");
|
||||
}
|
||||
}
|
||||
|
||||
// Perform TLS analysis
|
||||
match Self::check_tls(&host, port).await {
|
||||
Ok(tls_info) => {
|
||||
tls_data["host"] = json!(host);
|
||||
tls_data["port"] = json!(port);
|
||||
tls_data["cert_subject"] = json!(tls_info.cert_subject);
|
||||
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
|
||||
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
|
||||
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
|
||||
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
|
||||
tls_data["san_names"] = json!(tls_info.san_names);
|
||||
|
||||
if tls_info.cert_expired {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Certificate expired. Not After: {}",
|
||||
tls_info.cert_not_after
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Expired TLS certificate for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} has expired. \
|
||||
Browsers will show security warnings to users."
|
||||
),
|
||||
Severity::High,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Renew the TLS certificate. Consider using automated certificate \
|
||||
management with Let's Encrypt or a similar CA."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Expired TLS certificate");
|
||||
}
|
||||
|
||||
if tls_info.cert_self_signed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("Self-signed certificate detected".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Self-signed TLS certificate for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} is self-signed and not issued by a \
|
||||
trusted certificate authority. Browsers will show security warnings."
|
||||
),
|
||||
Severity::Medium,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Replace the self-signed certificate with one issued by a trusted \
|
||||
certificate authority."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Self-signed certificate");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tls_data["tls_error"] = json!(e.to_string());
|
||||
|
||||
// TLS handshake failure itself is a finding
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!("TLS error: {e}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("TLS handshake failure for {host}"),
|
||||
format!(
|
||||
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
|
||||
),
|
||||
Severity::High,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Ensure TLS is properly configured on the server. Check that the \
|
||||
certificate is valid and the server supports modern TLS versions."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
// Check strict transport security via an HTTPS request
|
||||
match self.http.get(&https_url).send().await {
|
||||
Ok(resp) => {
|
||||
let hsts = resp.headers().get("strict-transport-security");
|
||||
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
|
||||
|
||||
if hsts.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: https_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(
|
||||
"Strict-Transport-Security header not present".to_string(),
|
||||
),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Missing HSTS header for {host}"),
|
||||
format!(
|
||||
"The server at {host} does not send a Strict-Transport-Security header. \
|
||||
Without HSTS, browsers may allow HTTP downgrade attacks."
|
||||
),
|
||||
Severity::Medium,
|
||||
https_url.clone(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the Strict-Transport-Security header with an appropriate max-age. \
|
||||
Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(host = %host, findings = count, "TLS analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} TLS configuration issues for {host}.")
|
||||
} else {
|
||||
format!("TLS configuration looks good for {host}.")
|
||||
},
|
||||
findings,
|
||||
data: tls_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
134
compliance-dast/src/tools/xss.rs
Normal file
134
compliance-dast/src/tools/xss.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::agents::xss::XssAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing XssAgent.
|
||||
pub struct XssTool {
|
||||
http: reqwest::Client,
|
||||
agent: XssAgent,
|
||||
}
|
||||
|
||||
impl XssTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = XssAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
endpoints
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for XssTool {
|
||||
fn name(&self) -> &str {
|
||||
"xss_scanner"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Tests endpoints for Cross-Site Scripting (XSS) vulnerabilities including reflected, \
|
||||
stored, and DOM-based XSS. Provide endpoints with parameters to test."
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"endpoints": {
|
||||
"type": "array",
|
||||
"description": "Endpoints to test for XSS",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": { "type": "string" },
|
||||
"method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] },
|
||||
"parameters": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"location": { "type": "string", "enum": ["query", "body", "header", "path", "cookie"] },
|
||||
"param_type": { "type": "string" },
|
||||
"example_value": { "type": "string" }
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["url", "method", "parameters"]
|
||||
}
|
||||
},
|
||||
"custom_payloads": {
|
||||
"type": "array",
|
||||
"description": "Optional additional XSS payloads to test",
|
||||
"items": { "type": "string" }
|
||||
}
|
||||
},
|
||||
"required": ["endpoints"]
|
||||
})
|
||||
}
|
||||
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} XSS vulnerabilities.")
|
||||
} else {
|
||||
"No XSS vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -31,4 +31,16 @@ impl Database {
|
||||
pub fn dast_scan_runs(&self) -> Collection<DastScanRun> {
|
||||
self.inner.collection("dast_scan_runs")
|
||||
}
|
||||
|
||||
pub fn pentest_sessions(&self) -> Collection<PentestSession> {
|
||||
self.inner.collection("pentest_sessions")
|
||||
}
|
||||
|
||||
pub fn attack_chain_nodes(&self) -> Collection<AttackChainNode> {
|
||||
self.inner.collection("attack_chain_nodes")
|
||||
}
|
||||
|
||||
pub fn pentest_messages(&self) -> Collection<PentestMessage> {
|
||||
self.inner.collection("pentest_messages")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use rmcp::{
|
||||
};
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::tools::{dast, findings, sbom};
|
||||
use crate::tools::{dast, findings, pentest, sbom};
|
||||
|
||||
pub struct ComplianceMcpServer {
|
||||
db: Database,
|
||||
@@ -89,6 +89,54 @@ impl ComplianceMcpServer {
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
dast::dast_scan_summary(&self.db, params).await
|
||||
}
|
||||
|
||||
// ── Pentest ─────────────────────────────────────────────
|
||||
|
||||
#[tool(
|
||||
description = "List AI pentest sessions with optional filters for target, status, and strategy"
|
||||
)]
|
||||
async fn list_pentest_sessions(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::ListPentestSessionsParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::list_pentest_sessions(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(description = "Get a single AI pentest session by its ID")]
|
||||
async fn get_pentest_session(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetPentestSessionParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_pentest_session(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Get the attack chain DAG for a pentest session showing each tool invocation, its reasoning, and results"
|
||||
)]
|
||||
async fn get_attack_chain(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetAttackChainParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_attack_chain(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(description = "Get chat messages from a pentest session")]
|
||||
async fn get_pentest_messages(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::GetPentestMessagesParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::get_pentest_messages(&self.db, params).await
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Get aggregated pentest statistics including running sessions, vulnerability counts, and severity distribution"
|
||||
)]
|
||||
async fn pentest_stats(
|
||||
&self,
|
||||
Parameters(params): Parameters<pentest::PentestStatsParams>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
pentest::pentest_stats(&self.db, params).await
|
||||
}
|
||||
}
|
||||
|
||||
#[tool_handler]
|
||||
@@ -101,7 +149,7 @@ impl ServerHandler for ComplianceMcpServer {
|
||||
.build(),
|
||||
server_info: Implementation::from_build_env(),
|
||||
instructions: Some(
|
||||
"Compliance Scanner MCP server. Query security findings, SBOM data, and DAST results."
|
||||
"Compliance Scanner MCP server. Query security findings, SBOM data, DAST results, and AI pentest sessions."
|
||||
.to_string(),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod dast;
|
||||
pub mod findings;
|
||||
pub mod pentest;
|
||||
pub mod sbom;
|
||||
|
||||
261
compliance-mcp/src/tools/pentest.rs
Normal file
261
compliance-mcp/src/tools/pentest.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use mongodb::bson::doc;
|
||||
use rmcp::{model::*, ErrorData as McpError};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::database::Database;
|
||||
|
||||
const MAX_LIMIT: i64 = 200;
|
||||
const DEFAULT_LIMIT: i64 = 50;
|
||||
|
||||
fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
// ── List Pentest Sessions ──────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct ListPentestSessionsParams {
|
||||
/// Filter by target ID
|
||||
pub target_id: Option<String>,
|
||||
/// Filter by status: running, paused, completed, failed
|
||||
pub status: Option<String>,
|
||||
/// Filter by strategy: quick, comprehensive, targeted, aggressive, stealth
|
||||
pub strategy: Option<String>,
|
||||
/// Maximum number of results (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn list_pentest_sessions(
|
||||
db: &Database,
|
||||
params: ListPentestSessionsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let mut filter = doc! {};
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
filter.insert("target_id", target_id);
|
||||
}
|
||||
if let Some(ref status) = params.status {
|
||||
filter.insert("status", status);
|
||||
}
|
||||
if let Some(ref strategy) = params.strategy {
|
||||
filter.insert("strategy", strategy);
|
||||
}
|
||||
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.pentest_sessions()
|
||||
.find(filter)
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let session = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(session);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Pentest Session ────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetPentestSessionParams {
|
||||
/// Pentest session ID (MongoDB ObjectId hex string)
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
pub async fn get_pentest_session(
|
||||
db: &Database,
|
||||
params: GetPentestSessionParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let oid = bson::oid::ObjectId::parse_str(¶ms.id)
|
||||
.map_err(|e| McpError::invalid_params(format!("invalid id: {e}"), None))?;
|
||||
|
||||
let session = db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?
|
||||
.ok_or_else(|| McpError::invalid_params("session not found", None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&session)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Attack Chain ───────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetAttackChainParams {
|
||||
/// Pentest session ID to get the attack chain for
|
||||
pub session_id: String,
|
||||
/// Maximum number of nodes (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn get_attack_chain(
|
||||
db: &Database,
|
||||
params: GetAttackChainParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": ¶ms.session_id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let node = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(node);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Get Pentest Messages ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct GetPentestMessagesParams {
|
||||
/// Pentest session ID
|
||||
pub session_id: String,
|
||||
/// Maximum number of messages (default 50, max 200)
|
||||
pub limit: Option<i64>,
|
||||
}
|
||||
|
||||
pub async fn get_pentest_messages(
|
||||
db: &Database,
|
||||
params: GetPentestMessagesParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let limit = cap_limit(params.limit);
|
||||
|
||||
let mut cursor = db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": ¶ms.session_id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.limit(limit)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
while cursor
|
||||
.advance()
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("cursor error: {e}"), None))?
|
||||
{
|
||||
let msg = cursor
|
||||
.deserialize_current()
|
||||
.map_err(|e| McpError::internal_error(format!("deserialize error: {e}"), None))?;
|
||||
results.push(msg);
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&results)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
// ── Pentest Stats ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct PentestStatsParams {
|
||||
/// Filter stats by target ID
|
||||
pub target_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn pentest_stats(
|
||||
db: &Database,
|
||||
params: PentestStatsParams,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
let mut base_filter = doc! {};
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
base_filter.insert("target_id", target_id);
|
||||
}
|
||||
|
||||
// Count running sessions
|
||||
let mut running_filter = base_filter.clone();
|
||||
running_filter.insert("status", "running");
|
||||
let running = db
|
||||
.pentest_sessions()
|
||||
.count_documents(running_filter)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Count total sessions
|
||||
let total_sessions = db
|
||||
.pentest_sessions()
|
||||
.count_documents(base_filter.clone())
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Get findings for these sessions — query DAST findings with session_id set
|
||||
let mut findings_filter = doc! { "session_id": { "$ne": null } };
|
||||
if let Some(ref target_id) = params.target_id {
|
||||
findings_filter.insert("target_id", target_id);
|
||||
}
|
||||
let total_findings = db
|
||||
.dast_findings()
|
||||
.count_documents(findings_filter.clone())
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
let mut exploitable_filter = findings_filter.clone();
|
||||
exploitable_filter.insert("exploitable", true);
|
||||
let exploitable = db
|
||||
.dast_findings()
|
||||
.count_documents(exploitable_filter)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
|
||||
// Severity counts
|
||||
let mut severity = serde_json::Map::new();
|
||||
for sev in ["critical", "high", "medium", "low", "info"] {
|
||||
let mut sf = findings_filter.clone();
|
||||
sf.insert("severity", sev);
|
||||
let count = db
|
||||
.dast_findings()
|
||||
.count_documents(sf)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(format!("DB error: {e}"), None))?;
|
||||
severity.insert(sev.to_string(), serde_json::json!(count));
|
||||
}
|
||||
|
||||
let summary = serde_json::json!({
|
||||
"running_sessions": running,
|
||||
"total_sessions": total_sessions,
|
||||
"total_findings": total_findings,
|
||||
"exploitable_findings": exploitable,
|
||||
"severity_distribution": severity,
|
||||
});
|
||||
|
||||
let json = serde_json::to_string_pretty(&summary)
|
||||
.map_err(|e| McpError::internal_error(format!("json error: {e}"), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
Reference in New Issue
Block a user