refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Detect Changes (push) Successful in 5s
CI / Tests (push) Successful in 5m15s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s

This commit was merged in pull request #13.
This commit is contained in:
2026-03-13 08:03:45 +00:00
parent acc5b86aa4
commit 3bb690e5bb
89 changed files with 11884 additions and 6046 deletions

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::api_fuzzer::ApiFuzzerAgent;
/// PentestTool wrapper around the existing ApiFuzzerAgent.
pub struct ApiFuzzerTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: ApiFuzzerAgent,
}
impl ApiFuzzerTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = ApiFuzzerAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl ApiFuzzerTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -98,49 +128,51 @@ impl PentestTool for ApiFuzzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let mut endpoints = Self::parse_endpoints(&input);
let mut endpoints = Self::parse_endpoints(&input);
// If a base_url is provided but no endpoints, create a default endpoint
if endpoints.is_empty() {
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
endpoints.push(DiscoveredEndpoint {
url: base.to_string(),
method: "GET".to_string(),
parameters: Vec::new(),
content_type: None,
requires_auth: false,
// If a base_url is provided but no endpoints, create a default endpoint
if endpoints.is_empty() {
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
endpoints.push(DiscoveredEndpoint {
url: base.to_string(),
method: "GET".to_string(),
parameters: Vec::new(),
content_type: None,
requires_auth: false,
});
}
}
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints or base_url provided to fuzz.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
}
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints or base_url provided to fuzz.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} API misconfigurations or information disclosures.")
} else {
"No API misconfigurations detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} API misconfigurations or information disclosures.")
} else {
"No API misconfigurations detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::auth_bypass::AuthBypassAgent;
/// PentestTool wrapper around the existing AuthBypassAgent.
pub struct AuthBypassTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: AuthBypassAgent,
}
impl AuthBypassTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = AuthBypassAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl AuthBypassTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -96,35 +126,37 @@ impl PentestTool for AuthBypassTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} authentication bypass vulnerabilities.")
} else {
"No authentication bypass vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} authentication bypass vulnerabilities.")
} else {
"No authentication bypass vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -54,7 +54,7 @@ impl ConsoleLogDetectorTool {
}
let quote = html.as_bytes().get(abs_start).copied();
let (open, close) = match quote {
let (_open, close) = match quote {
Some(b'"') => ('"', '"'),
Some(b'\'') => ('\'', '\''),
_ => {
@@ -122,6 +122,96 @@ impl ConsoleLogDetectorTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_js_urls_from_html() {
let html = r#"
<html>
<head>
<script src="/static/app.js"></script>
<script src="https://cdn.example.com/lib.js"></script>
<script src='//cdn2.example.com/vendor.js'></script>
</head>
</html>
"#;
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
assert_eq!(urls.len(), 3);
assert!(urls.contains(&"https://example.com/static/app.js".to_string()));
assert!(urls.contains(&"https://cdn.example.com/lib.js".to_string()));
assert!(urls.contains(&"https://cdn2.example.com/vendor.js".to_string()));
}
#[test]
fn extract_js_urls_no_scripts() {
let html = "<html><body><p>Hello</p></body></html>";
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
assert!(urls.is_empty());
}
#[test]
fn extract_js_urls_filters_non_js() {
let html = r#"<link src="/style.css"><script src="/app.js"></script>"#;
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
// Only .js files should be extracted
assert_eq!(urls.len(), 1);
assert!(urls[0].ends_with("/app.js"));
}
#[test]
fn scan_js_content_finds_console_log() {
let js = r#"
function init() {
console.log("debug info");
doStuff();
}
"#;
let matches = ConsoleLogDetectorTool::scan_js_content(js, "https://example.com/app.js");
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].pattern, "console.log");
assert_eq!(matches[0].line_number, Some(3));
}
#[test]
fn scan_js_content_finds_multiple_patterns() {
let js =
"console.log('a');\nconsole.debug('b');\nconsole.error('c');\ndebugger;\nalert('x');";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
assert_eq!(matches.len(), 5);
}
#[test]
fn scan_js_content_skips_comments() {
let js = "// console.log('commented out');\n* console.log('also comment');\n/* console.log('block comment') */";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
assert!(matches.is_empty());
}
#[test]
fn scan_js_content_one_match_per_line() {
let js = "console.log('a'); console.debug('b');";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
// Only one match per line
assert_eq!(matches.len(), 1);
}
#[test]
fn scan_js_content_empty_input() {
let matches = ConsoleLogDetectorTool::scan_js_content("", "test.js");
assert!(matches.is_empty());
}
#[test]
fn patterns_list_is_not_empty() {
let patterns = ConsoleLogDetectorTool::patterns();
assert!(patterns.len() >= 8);
assert!(patterns.contains(&"console.log("));
assert!(patterns.contains(&"debugger;"));
}
}
impl PentestTool for ConsoleLogDetectorTool {
fn name(&self) -> &str {
"console_log_detector"
@@ -154,173 +244,180 @@ impl PentestTool for ConsoleLogDetectorTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_js: Vec<String> = input
.get("additional_js_urls")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Fetch the main page
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let html = response.text().await.unwrap_or_default();
// Scan inline scripts in the HTML
let mut all_matches = Vec::new();
let inline_matches = Self::scan_js_content(&html, url);
all_matches.extend(inline_matches);
// Extract JS file URLs from the HTML
let mut js_urls = Self::extract_js_urls(&html, url);
js_urls.extend(additional_js);
js_urls.dedup();
// Fetch and scan each JS file
for js_url in &js_urls {
match self.http.get(js_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
let js_content = resp.text().await.unwrap_or_default();
// Only scan non-minified-looking files or files where we can still
// find patterns (minifiers typically strip console calls, but not always)
let file_matches = Self::scan_js_content(&js_content, js_url);
all_matches.extend(file_matches);
}
}
Err(_) => continue,
}
}
let mut findings = Vec::new();
let match_data: Vec<serde_json::Value> = all_matches
.iter()
.map(|m| {
json!({
"pattern": m.pattern,
"file": m.file_url,
"line": m.line_number,
"snippet": m.line_snippet,
let additional_js: Vec<String> = input
.get("additional_js_urls")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
})
.collect();
.unwrap_or_default();
if !all_matches.is_empty() {
// Group by file for the finding
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
std::collections::HashMap::new();
for m in &all_matches {
by_file.entry(&m.file_url).or_default().push(m);
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Fetch the main page
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let html = response.text().await.unwrap_or_default();
// Scan inline scripts in the HTML
let mut all_matches = Vec::new();
let inline_matches = Self::scan_js_content(&html, url);
all_matches.extend(inline_matches);
// Extract JS file URLs from the HTML
let mut js_urls = Self::extract_js_urls(&html, url);
js_urls.extend(additional_js);
js_urls.dedup();
// Fetch and scan each JS file
for js_url in &js_urls {
match self.http.get(js_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
let js_content = resp.text().await.unwrap_or_default();
// Only scan non-minified-looking files or files where we can still
// find patterns (minifiers typically strip console calls, but not always)
let file_matches = Self::scan_js_content(&js_content, js_url);
all_matches.extend(file_matches);
}
}
Err(_) => continue,
}
}
for (file_url, matches) in &by_file {
let pattern_summary: Vec<String> = matches
.iter()
.take(5)
.map(|m| {
format!(
" Line {}: {} - {}",
m.line_number.unwrap_or(0),
m.pattern,
if m.line_snippet.len() > 80 {
format!("{}...", &m.line_snippet[..80])
} else {
m.line_snippet.clone()
}
)
let mut findings = Vec::new();
let match_data: Vec<serde_json::Value> = all_matches
.iter()
.map(|m| {
json!({
"pattern": m.pattern,
"file": m.file_url,
"line": m.line_number,
"snippet": m.line_snippet,
})
.collect();
})
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: file_url.to_string(),
request_headers: None,
request_body: None,
response_status: 200,
response_headers: None,
response_snippet: Some(pattern_summary.join("\n")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if !all_matches.is_empty() {
// Group by file for the finding
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
std::collections::HashMap::new();
for m in &all_matches {
by_file.entry(&m.file_url).or_default().push(m);
}
let total = matches.len();
let extra = if total > 5 {
format!(" (and {} more)", total - 5)
} else {
String::new()
};
for (file_url, matches) in &by_file {
let pattern_summary: Vec<String> = matches
.iter()
.take(5)
.map(|m| {
format!(
" Line {}: {} - {}",
m.line_number.unwrap_or(0),
m.pattern,
if m.line_snippet.len() > 80 {
format!("{}...", &m.line_snippet[..80])
} else {
m.line_snippet.clone()
}
)
})
.collect();
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::ConsoleLogLeakage,
format!("Console/debug statements in {}", file_url),
format!(
"Found {total} console/debug statements in {file_url}{extra}. \
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: file_url.to_string(),
request_headers: None,
request_body: None,
response_status: 200,
response_headers: None,
response_snippet: Some(pattern_summary.join("\n")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let total = matches.len();
let extra = if total > 5 {
format!(" (and {} more)", total - 5)
} else {
String::new()
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::ConsoleLogLeakage,
format!("Console/debug statements in {}", file_url),
format!(
"Found {total} console/debug statements in {file_url}{extra}. \
These can leak sensitive information such as API responses, user data, \
or internal state to anyone with browser developer tools open."
),
Severity::Low,
file_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-532".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove console.log/debug/error statements from production code. \
),
Severity::Low,
file_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-532".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove console.log/debug/error statements from production code. \
Use a build step (e.g., babel plugin, terser) to strip console calls \
during the production build."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
let total_matches = all_matches.len();
let count = findings.len();
info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete");
let total_matches = all_matches.len();
let count = findings.len();
info!(
url,
js_files = js_urls.len(),
total_matches,
"Console log detection complete"
);
Ok(PentestToolResult {
summary: if total_matches > 0 {
format!(
"Found {total_matches} console/debug statements across {} files.",
count
)
} else {
format!(
"No console/debug statements found in HTML or {} JS files.",
js_urls.len()
)
},
findings,
data: json!({
"total_matches": total_matches,
"js_files_scanned": js_urls.len(),
"matches": match_data,
}),
})
Ok(PentestToolResult {
summary: if total_matches > 0 {
format!(
"Found {total_matches} console/debug statements across {} files.",
count
)
} else {
format!(
"No console/debug statements found in HTML or {} JS files.",
js_urls.len()
)
},
findings,
data: json!({
"total_matches": total_matches,
"js_files_scanned": js_urls.len(),
"matches": match_data,
}),
})
})
}
}

View File

@@ -14,6 +14,7 @@ pub struct CookieAnalyzerTool {
#[derive(Debug)]
struct ParsedCookie {
name: String,
#[allow(dead_code)]
value: String,
secure: bool,
http_only: bool,
@@ -92,6 +93,81 @@ impl CookieAnalyzerTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_simple_cookie() {
let cookie = CookieAnalyzerTool::parse_set_cookie("session_id=abc123");
assert_eq!(cookie.name, "session_id");
assert_eq!(cookie.value, "abc123");
assert!(!cookie.secure);
assert!(!cookie.http_only);
assert!(cookie.same_site.is_none());
assert!(cookie.domain.is_none());
assert!(cookie.path.is_none());
}
#[test]
fn parse_cookie_with_all_attributes() {
let raw = "token=xyz; Secure; HttpOnly; SameSite=Strict; Domain=.example.com; Path=/api";
let cookie = CookieAnalyzerTool::parse_set_cookie(raw);
assert_eq!(cookie.name, "token");
assert_eq!(cookie.value, "xyz");
assert!(cookie.secure);
assert!(cookie.http_only);
assert_eq!(cookie.same_site.as_deref(), Some("strict"));
assert_eq!(cookie.domain.as_deref(), Some(".example.com"));
assert_eq!(cookie.path.as_deref(), Some("/api"));
assert_eq!(cookie.raw, raw);
}
#[test]
fn parse_cookie_samesite_none() {
let cookie = CookieAnalyzerTool::parse_set_cookie("id=1; SameSite=None; Secure");
assert_eq!(cookie.same_site.as_deref(), Some("none"));
assert!(cookie.secure);
}
#[test]
fn parse_cookie_with_equals_in_value() {
let cookie = CookieAnalyzerTool::parse_set_cookie("data=a=b=c; HttpOnly");
assert_eq!(cookie.name, "data");
assert_eq!(cookie.value, "a=b=c");
assert!(cookie.http_only);
}
#[test]
fn is_sensitive_cookie_known_names() {
assert!(CookieAnalyzerTool::is_sensitive_cookie("session_id"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("PHPSESSID"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("JSESSIONID"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("connect.sid"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("asp.net_sessionid"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("auth_token"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("jwt_access"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("csrf_token"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("my_sess_cookie"));
assert!(CookieAnalyzerTool::is_sensitive_cookie("SID"));
}
#[test]
fn is_sensitive_cookie_non_sensitive() {
assert!(!CookieAnalyzerTool::is_sensitive_cookie("theme"));
assert!(!CookieAnalyzerTool::is_sensitive_cookie("language"));
assert!(!CookieAnalyzerTool::is_sensitive_cookie("_ga"));
assert!(!CookieAnalyzerTool::is_sensitive_cookie("tracking"));
}
#[test]
fn parse_empty_cookie_header() {
let cookie = CookieAnalyzerTool::parse_set_cookie("");
assert_eq!(cookie.name, "");
assert_eq!(cookie.value, "");
}
}
impl PentestTool for CookieAnalyzerTool {
fn name(&self) -> &str {
"cookie_analyzer"
@@ -123,96 +199,96 @@ impl PentestTool for CookieAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let login_url = input.get("login_url").and_then(|v| v.as_str());
let login_url = input.get("login_url").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut cookie_data = Vec::new();
let mut findings = Vec::new();
let mut cookie_data = Vec::new();
// Collect Set-Cookie headers from the main URL and optional login URL
let urls_to_check: Vec<&str> = std::iter::once(url)
.chain(login_url.into_iter())
.collect();
// Collect Set-Cookie headers from the main URL and optional login URL
let urls_to_check: Vec<&str> = std::iter::once(url).chain(login_url).collect();
for check_url in &urls_to_check {
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
let no_redirect_client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(15))
.build()
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
for check_url in &urls_to_check {
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
let no_redirect_client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(15))
.build()
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
let response = match no_redirect_client.get(*check_url).send().await {
Ok(r) => r,
Err(e) => {
// Try with the main client that follows redirects
match self.http.get(*check_url).send().await {
Ok(r) => r,
Err(_) => continue,
let response = match no_redirect_client.get(*check_url).send().await {
Ok(r) => r,
Err(_e) => {
// Try with the main client that follows redirects
match self.http.get(*check_url).send().await {
Ok(r) => r,
Err(_) => continue,
}
}
}
};
};
let status = response.status().as_u16();
let set_cookie_headers: Vec<String> = response
.headers()
.get_all("set-cookie")
.iter()
.filter_map(|v| v.to_str().ok().map(String::from))
.collect();
let status = response.status().as_u16();
let set_cookie_headers: Vec<String> = response
.headers()
.get_all("set-cookie")
.iter()
.filter_map(|v| v.to_str().ok().map(String::from))
.collect();
for raw_cookie in &set_cookie_headers {
let cookie = Self::parse_set_cookie(raw_cookie);
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
let is_https = check_url.starts_with("https://");
for raw_cookie in &set_cookie_headers {
let cookie = Self::parse_set_cookie(raw_cookie);
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
let is_https = check_url.starts_with("https://");
let cookie_info = json!({
"name": cookie.name,
"secure": cookie.secure,
"http_only": cookie.http_only,
"same_site": cookie.same_site,
"domain": cookie.domain,
"path": cookie.path,
"is_sensitive": is_sensitive,
"url": check_url,
});
cookie_data.push(cookie_info);
let cookie_info = json!({
"name": cookie.name,
"secure": cookie.secure,
"http_only": cookie.http_only,
"same_site": cookie.same_site,
"domain": cookie.domain,
"path": cookie.path,
"is_sensitive": is_sensitive,
"url": check_url,
});
cookie_data.push(cookie_info);
// Check: missing Secure flag
if !cookie.secure && (is_https || is_sensitive) {
let severity = if is_sensitive {
Severity::High
} else {
Severity::Medium
};
// Check: missing Secure flag
if !cookie.secure && (is_https || is_sensitive) {
let severity = if is_sensitive {
Severity::High
} else {
Severity::Medium
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
@@ -226,32 +302,32 @@ impl PentestTool for CookieAnalyzerTool {
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-614".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
finding.cwe = Some("CWE-614".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
cookie is only sent over HTTPS connections."
.to_string(),
);
findings.push(finding);
}
.to_string(),
);
findings.push(finding);
}
// Check: missing HttpOnly flag on sensitive cookies
if !cookie.http_only && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// Check: missing HttpOnly flag on sensitive cookies
if !cookie.http_only && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
@@ -265,137 +341,137 @@ impl PentestTool for CookieAnalyzerTool {
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
JavaScript access to the cookie."
.to_string(),
);
findings.push(finding);
}
.to_string(),
);
findings.push(finding);
}
// Check: missing or weak SameSite
if is_sensitive {
let weak_same_site = match &cookie.same_site {
None => true,
Some(ss) => ss == "none",
};
if weak_same_site {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
// Check: missing or weak SameSite
if is_sensitive {
let weak_same_site = match &cookie.same_site {
None => true,
Some(ss) => ss == "none",
};
let desc = if cookie.same_site.is_none() {
format!(
if weak_same_site {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let desc = if cookie.same_site.is_none() {
format!(
"The session/auth cookie '{}' does not have a SameSite attribute. \
This may allow cross-site request forgery (CSRF) attacks.",
cookie.name
)
} else {
format!(
} else {
format!(
"The session/auth cookie '{}' has SameSite=None, which allows it \
to be sent in cross-site requests, enabling CSRF attacks.",
cookie.name
)
};
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing or weak SameSite", cookie.name),
desc,
Severity::Medium,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1275".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing or weak SameSite", cookie.name),
desc,
Severity::Medium,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1275".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
to prevent cross-site request inclusion."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
// Check: overly broad domain
if let Some(ref domain) = cookie.domain {
// A domain starting with a dot applies to all subdomains
let dot_domain = domain.starts_with('.');
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
if dot_domain && parts.len() <= 2 && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// Check: overly broad domain
if let Some(ref domain) = cookie.domain {
// A domain starting with a dot applies to all subdomains
let dot_domain = domain.starts_with('.');
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
if dot_domain && parts.len() <= 2 && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' has overly broad domain", cookie.name),
format!(
"The cookie '{}' is scoped to domain '{}' which includes all \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' has overly broad domain", cookie.name),
format!(
"The cookie '{}' is scoped to domain '{}' which includes all \
subdomains. If any subdomain is compromised, the attacker can \
access this cookie.",
cookie.name, domain
),
Severity::Low,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
cookie.name, domain
),
Severity::Low,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Restrict the cookie domain to the specific subdomain that needs it \
rather than the entire parent domain."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
}
}
let count = findings.len();
info!(url, findings = count, "Cookie analysis complete");
let count = findings.len();
info!(url, findings = count, "Cookie analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} cookie security issues.")
} else if cookie_data.is_empty() {
"No cookies were set by the target.".to_string()
} else {
"All cookies have proper security attributes.".to_string()
},
findings,
data: json!({
"cookies": cookie_data,
"total_cookies": cookie_data.len(),
}),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} cookie security issues.")
} else if cookie_data.is_empty() {
"No cookies were set by the target.".to_string()
} else {
"All cookies have proper security attributes.".to_string()
},
findings,
data: json!({
"cookies": cookie_data,
"total_cookies": cookie_data.len(),
}),
})
})
}
}

View File

@@ -22,19 +22,60 @@ impl CorsCheckerTool {
vec![
("null_origin", "null".to_string()),
("evil_domain", "https://evil.com".to_string()),
(
"subdomain_spoof",
format!("https://{target_host}.evil.com"),
),
(
"prefix_spoof",
format!("https://evil-{target_host}"),
),
("subdomain_spoof", format!("https://{target_host}.evil.com")),
("prefix_spoof", format!("https://evil-{target_host}")),
("http_downgrade", format!("http://{target_host}")),
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Look up the origin value registered under `name`, if any.
    fn origin_for<'a>(origins: &'a [(&'a str, String)], name: &str) -> Option<&'a str> {
        origins
            .iter()
            .find(|(n, _)| *n == name)
            .map(|(_, v)| v.as_str())
    }

    #[test]
    fn test_origins_contains_expected_variants() {
        let origins = CorsCheckerTool::test_origins("example.com");
        assert_eq!(origins.len(), 5);
        for expected in [
            "null_origin",
            "evil_domain",
            "subdomain_spoof",
            "prefix_spoof",
            "http_downgrade",
        ] {
            assert!(
                origins.iter().any(|(name, _)| *name == expected),
                "missing variant: {expected}"
            );
        }
    }

    #[test]
    fn test_origins_uses_target_host() {
        let origins = CorsCheckerTool::test_origins("myapp.io");
        assert_eq!(
            origin_for(&origins, "subdomain_spoof"),
            Some("https://myapp.io.evil.com")
        );
        assert_eq!(
            origin_for(&origins, "prefix_spoof"),
            Some("https://evil-myapp.io")
        );
        assert_eq!(
            origin_for(&origins, "http_downgrade"),
            Some("http://myapp.io")
        );
    }

    #[test]
    fn test_origins_null_and_evil_are_static() {
        let origins = CorsCheckerTool::test_origins("anything.com");
        assert_eq!(origin_for(&origins, "null_origin"), Some("null"));
        assert_eq!(origin_for(&origins, "evil_domain"), Some("https://evil.com"));
    }
}
impl PentestTool for CorsCheckerTool {
fn name(&self) -> &str {
"cors_checker"
@@ -68,82 +109,82 @@ impl PentestTool for CorsCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_origins: Vec<String> = input
.get("additional_origins")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let additional_origins: Vec<String> = input
.get("additional_origins")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_host = url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
.unwrap_or_else(|| url.to_string());
let target_host = url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
.unwrap_or_else(|| url.to_string());
let mut findings = Vec::new();
let mut cors_data: Vec<serde_json::Value> = Vec::new();
let mut findings = Vec::new();
let mut cors_data: Vec<serde_json::Value> = Vec::new();
// First, send a request without Origin to get baseline
let baseline = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
// First, send a request without Origin to get baseline
let baseline = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let baseline_acao = baseline
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let baseline_acao = baseline
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"origin": null,
"acao": baseline_acao,
}));
cors_data.push(json!({
"origin": null,
"acao": baseline_acao,
}));
// Check for wildcard + credentials (dangerous combo)
if let Some(ref acao) = baseline_acao {
if acao == "*" {
let acac = baseline
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Check for wildcard + credentials (dangerous combo)
if let Some(ref acao) = baseline_acao {
if acao == "*" {
let acac = baseline
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if acac.to_lowercase() == "true" {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: baseline.status().as_u16(),
response_headers: None,
response_snippet: Some(format!(
"Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if acac.to_lowercase() == "true" {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: baseline.status().as_u16(),
response_headers: None,
response_snippet: Some("Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
@@ -157,254 +198,251 @@ impl PentestTool for CorsCheckerTool {
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Never combine Access-Control-Allow-Origin: * with \
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Never combine Access-Control-Allow-Origin: * with \
Access-Control-Allow-Credentials: true. Specify explicit allowed origins."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
}
// Test with various Origin headers
let mut test_origins = Self::test_origins(&target_host);
for origin in &additional_origins {
test_origins.push(("custom", origin.clone()));
}
// Test with various Origin headers
let mut test_origins = Self::test_origins(&target_host);
for origin in &additional_origins {
test_origins.push(("custom", origin.clone()));
}
for (test_name, origin) in &test_origins {
let resp = match self
for (test_name, origin) in &test_origins {
let resp = match self
.http
.get(url)
.header("Origin", origin.as_str())
.send()
.await
{
Ok(r) => r,
Err(_) => continue,
};
let status = resp.status().as_u16();
let acao = resp
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acac = resp
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": test_name,
"origin": origin,
"acao": acao,
"acac": acac,
"acam": acam,
"status": status,
}));
// Check if the origin was reflected back
if let Some(ref acao_val) = acao {
let origin_reflected = acao_val == origin;
let credentials_allowed = acac
.as_ref()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
if origin_reflected && *test_name != "http_downgrade" {
let severity = if credentials_allowed {
Severity::Critical
} else {
Severity::High
};
let resp_headers: HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: Some(resp_headers),
response_snippet: Some(format!(
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
Access-Control-Allow-Credentials: {}",
acac.as_deref().unwrap_or("not set")
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let title = match *test_name {
"null_origin" => "CORS accepts null origin".to_string(),
"evil_domain" => "CORS reflects arbitrary origin".to_string(),
"subdomain_spoof" => {
"CORS vulnerable to subdomain spoofing".to_string()
}
"prefix_spoof" => "CORS vulnerable to prefix spoofing".to_string(),
_ => format!("CORS reflects untrusted origin ({test_name})"),
};
let cred_note = if credentials_allowed {
" Combined with Access-Control-Allow-Credentials: true, this allows \
the attacker to steal authenticated data."
} else {
""
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
title,
format!(
"The endpoint {url} reflects the Origin header '{origin}' back in \
Access-Control-Allow-Origin, allowing cross-origin requests from \
untrusted domains.{cred_note}"
),
severity,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.exploitable = credentials_allowed;
finding.evidence = vec![evidence];
finding.remediation = Some(
"Validate the Origin header against a whitelist of trusted origins. \
Do not reflect the Origin header value directly. Use specific allowed \
origins instead of wildcards."
.to_string(),
);
findings.push(finding);
warn!(
url,
test_name,
origin,
credentials = credentials_allowed,
"CORS misconfiguration detected"
);
}
// Special case: HTTP downgrade
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS allows HTTP origin with credentials".to_string(),
format!(
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
credentials. An attacker performing a man-in-the-middle attack on \
the HTTP version could steal authenticated data."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
origin validation enforces the https:// scheme."
.to_string(),
);
findings.push(finding);
}
}
}
// Also send a preflight OPTIONS request
if let Ok(resp) = self
.http
.get(url)
.header("Origin", origin.as_str())
.request(reqwest::Method::OPTIONS, url)
.header("Origin", "https://evil.com")
.header("Access-Control-Request-Method", "POST")
.header(
"Access-Control-Request-Headers",
"Authorization, Content-Type",
)
.send()
.await
{
Ok(r) => r,
Err(_) => continue,
};
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
let status = resp.status().as_u16();
let acao = resp
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acah = resp
.headers()
.get("access-control-allow-headers")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acac = resp
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": test_name,
"origin": origin,
"acao": acao,
"acac": acac,
"acam": acam,
"status": status,
}));
// Check if the origin was reflected back
if let Some(ref acao_val) = acao {
let origin_reflected = acao_val == origin;
let credentials_allowed = acac
.as_ref()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
if origin_reflected && *test_name != "http_downgrade" {
let severity = if credentials_allowed {
Severity::Critical
} else {
Severity::High
};
let resp_headers: HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: Some(resp_headers),
response_snippet: Some(format!(
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
Access-Control-Allow-Credentials: {}",
acac.as_deref().unwrap_or("not set")
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let title = match *test_name {
"null_origin" => {
"CORS accepts null origin".to_string()
}
"evil_domain" => {
"CORS reflects arbitrary origin".to_string()
}
"subdomain_spoof" => {
"CORS vulnerable to subdomain spoofing".to_string()
}
"prefix_spoof" => {
"CORS vulnerable to prefix spoofing".to_string()
}
_ => format!("CORS reflects untrusted origin ({test_name})"),
};
let cred_note = if credentials_allowed {
" Combined with Access-Control-Allow-Credentials: true, this allows \
the attacker to steal authenticated data."
} else {
""
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
title,
format!(
"The endpoint {url} reflects the Origin header '{origin}' back in \
Access-Control-Allow-Origin, allowing cross-origin requests from \
untrusted domains.{cred_note}"
),
severity,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.exploitable = credentials_allowed;
finding.evidence = vec![evidence];
finding.remediation = Some(
"Validate the Origin header against a whitelist of trusted origins. \
Do not reflect the Origin header value directly. Use specific allowed \
origins instead of wildcards."
.to_string(),
);
findings.push(finding);
warn!(
url,
test_name,
origin,
credentials = credentials_allowed,
"CORS misconfiguration detected"
);
}
// Special case: HTTP downgrade
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS allows HTTP origin with credentials".to_string(),
format!(
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
credentials. An attacker performing a man-in-the-middle attack on \
the HTTP version could steal authenticated data."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
origin validation enforces the https:// scheme."
.to_string(),
);
findings.push(finding);
}
cors_data.push(json!({
"test": "preflight",
"status": resp.status().as_u16(),
"allow_methods": acam,
"allow_headers": acah,
}));
}
}
// Also send a preflight OPTIONS request
if let Ok(resp) = self
.http
.request(reqwest::Method::OPTIONS, url)
.header("Origin", "https://evil.com")
.header("Access-Control-Request-Method", "POST")
.header("Access-Control-Request-Headers", "Authorization, Content-Type")
.send()
.await
{
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
let count = findings.len();
info!(url, findings = count, "CORS check complete");
let acah = resp
.headers()
.get("access-control-allow-headers")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": "preflight",
"status": resp.status().as_u16(),
"allow_methods": acam,
"allow_headers": acah,
}));
}
let count = findings.len();
info!(url, findings = count, "CORS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CORS misconfiguration issues for {url}.")
} else {
format!("CORS configuration appears secure for {url}.")
},
findings,
data: json!({
"tests": cors_data,
}),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CORS misconfiguration issues for {url}.")
} else {
format!("CORS configuration appears secure for {url}.")
},
findings,
data: json!({
"tests": cors_data,
}),
})
})
}
}

View File

@@ -47,7 +47,7 @@ impl CspAnalyzerTool {
url: &str,
target_id: &str,
status: u16,
csp_raw: &str,
_csp_raw: &str,
) -> Vec<DastFinding> {
let mut findings = Vec::new();
@@ -216,12 +216,18 @@ impl CspAnalyzerTool {
("object-src", "Controls plugins like Flash"),
("base-uri", "Controls the base URL for relative URLs"),
("form-action", "Controls where forms can submit to"),
("frame-ancestors", "Controls who can embed this page in iframes"),
(
"frame-ancestors",
"Controls who can embed this page in iframes",
),
];
for (dir_name, desc) in &important_directives {
if !directive_names.contains(dir_name)
&& !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action")
&& (!has_default_src
|| *dir_name == "frame-ancestors"
|| *dir_name == "base-uri"
|| *dir_name == "form-action")
{
let evidence = make_evidence(format!("CSP missing directive: {dir_name}"));
let mut finding = DastFinding::new(
@@ -258,6 +264,125 @@ impl CspAnalyzerTool {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `csp` and run the directive analysis against a fixed dummy
    /// target, returning the findings it produces.
    fn analyze(csp: &str) -> Vec<DastFinding> {
        let directives = CspAnalyzerTool::parse_csp(csp);
        CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "")
    }

    #[test]
    fn parse_csp_basic() {
        let directives = CspAnalyzerTool::parse_csp(
            "default-src 'self'; script-src 'self' https://cdn.example.com",
        );
        assert_eq!(directives.len(), 2);
        assert_eq!(directives[0].name, "default-src");
        assert_eq!(directives[0].values, ["'self'"]);
        assert_eq!(directives[1].name, "script-src");
        assert_eq!(directives[1].values, ["'self'", "https://cdn.example.com"]);
    }

    #[test]
    fn parse_csp_empty() {
        assert!(CspAnalyzerTool::parse_csp("").is_empty());
    }

    #[test]
    fn parse_csp_trailing_semicolons() {
        let directives = CspAnalyzerTool::parse_csp("default-src 'none';;;");
        assert_eq!(directives.len(), 1);
        assert_eq!(directives[0].name, "default-src");
        assert_eq!(directives[0].values, ["'none'"]);
    }

    #[test]
    fn parse_csp_directive_without_value() {
        let directives = CspAnalyzerTool::parse_csp("upgrade-insecure-requests");
        assert_eq!(directives.len(), 1);
        assert_eq!(directives[0].name, "upgrade-insecure-requests");
        assert!(directives[0].values.is_empty());
    }

    #[test]
    fn parse_csp_names_are_lowercased() {
        let directives = CspAnalyzerTool::parse_csp("Script-Src 'self'");
        assert_eq!(directives[0].name, "script-src");
    }

    #[test]
    fn analyze_detects_unsafe_inline() {
        let findings = analyze("script-src 'self' 'unsafe-inline'");
        assert!(findings.iter().any(|f| f.title.contains("unsafe-inline")));
    }

    #[test]
    fn analyze_detects_unsafe_eval() {
        let findings = analyze("script-src 'self' 'unsafe-eval'");
        assert!(findings.iter().any(|f| f.title.contains("unsafe-eval")));
    }

    #[test]
    fn analyze_detects_wildcard() {
        let findings = analyze("img-src *");
        assert!(findings.iter().any(|f| f.title.contains("wildcard")));
    }

    #[test]
    fn analyze_detects_data_uri_in_script_src() {
        let findings = analyze("script-src 'self' data:");
        assert!(findings.iter().any(|f| f.title.contains("data:")));
    }

    #[test]
    fn analyze_detects_http_sources() {
        let findings = analyze("script-src http:");
        assert!(findings.iter().any(|f| f.title.contains("HTTP sources")));
    }

    #[test]
    fn analyze_detects_missing_directives_without_default_src() {
        // With only img-src set, the analyzer should flag at least
        // script-src, object-src, base-uri, form-action, frame-ancestors.
        let missing = analyze("img-src 'self'")
            .into_iter()
            .filter(|f| f.title.contains("missing"))
            .count();
        assert!(missing >= 4);
    }

    #[test]
    fn analyze_good_csp_no_unsafe_findings() {
        // A well-configured CSP should not produce unsafe-inline/eval/wildcard findings.
        let findings = analyze(
            "default-src 'none'; script-src 'self'; style-src 'self'; \
             img-src 'self'; object-src 'none'; base-uri 'self'; \
             form-action 'self'; frame-ancestors 'none'",
        );
        for f in &findings {
            assert!(!f.title.contains("unsafe-inline"));
            assert!(!f.title.contains("unsafe-eval"));
            assert!(!f.title.contains("wildcard"));
        }
    }
}
impl PentestTool for CspAnalyzerTool {
fn name(&self) -> &str {
"csp_analyzer"
@@ -285,163 +410,167 @@ impl PentestTool for CspAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
let status = response.status().as_u16();
// Check for CSP header
let csp_header = response
.headers()
.get("content-security-policy")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Check for CSP header
let csp_header = response
.headers()
.get("content-security-policy")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Also check for report-only variant
let csp_report_only = response
.headers()
.get("content-security-policy-report-only")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Also check for report-only variant
let csp_report_only = response
.headers()
.get("content-security-policy-report-only")
.and_then(|v| v.to_str().ok())
.map(String::from);
let mut findings = Vec::new();
let mut csp_data = json!({});
let mut findings = Vec::new();
let mut csp_data = json!({});
match &csp_header {
Some(csp) => {
let directives = Self::parse_csp(csp);
let directive_map: serde_json::Value = directives
.iter()
.map(|d| (d.name.clone(), json!(d.values)))
.collect::<serde_json::Map<String, serde_json::Value>>()
.into();
match &csp_header {
Some(csp) => {
let directives = Self::parse_csp(csp);
let directive_map: serde_json::Value = directives
.iter()
.map(|d| (d.name.clone(), json!(d.values)))
.collect::<serde_json::Map<String, serde_json::Value>>()
.into();
csp_data["csp_header"] = json!(csp);
csp_data["directives"] = directive_map;
csp_data["csp_header"] = json!(csp);
csp_data["directives"] = directive_map;
findings.extend(Self::analyze_directives(
&directives,
url,
&target_id,
status,
csp,
));
}
None => {
csp_data["csp_header"] = json!(null);
findings.extend(Self::analyze_directives(
&directives,
url,
&target_id,
status,
csp,
));
}
None => {
csp_data["csp_header"] = json!(null);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some("Content-Security-Policy header is missing".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(
"Content-Security-Policy header is missing".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"Missing Content-Security-Policy header".to_string(),
format!(
"No Content-Security-Policy header is present on {url}. \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"Missing Content-Security-Policy header".to_string(),
format!(
"No Content-Security-Policy header is present on {url}. \
Without CSP, the browser has no instructions on which sources are \
trusted, making XSS exploitation much easier."
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a Content-Security-Policy header. Start with a restrictive policy like \
\"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \
object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
if let Some(ref report_only) = csp_report_only {
csp_data["csp_report_only"] = json!(report_only);
// If ONLY report-only exists (no enforcing CSP), warn
if csp_header.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"Content-Security-Policy-Report-Only: {}",
report_only
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if let Some(ref report_only) = csp_report_only {
csp_data["csp_report_only"] = json!(report_only);
// If ONLY report-only exists (no enforcing CSP), warn
if csp_header.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"Content-Security-Policy-Report-Only: {}",
report_only
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"CSP is report-only, not enforcing".to_string(),
"A Content-Security-Policy-Report-Only header is present but no enforcing \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"CSP is report-only, not enforcing".to_string(),
"A Content-Security-Policy-Report-Only header is present but no enforcing \
Content-Security-Policy header exists. Report-only mode only logs violations \
but does not block them."
.to_string(),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
.to_string(),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Once you have verified the CSP policy works correctly in report-only mode, \
deploy it as an enforcing Content-Security-Policy header."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
let count = findings.len();
info!(url, findings = count, "CSP analysis complete");
let count = findings.len();
info!(url, findings = count, "CSP analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CSP issues for {url}.")
} else {
format!("Content-Security-Policy looks good for {url}.")
},
findings,
data: csp_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CSP issues for {url}.")
} else {
format!("Content-Security-Policy looks good for {url}.")
},
findings,
data: csp_data,
})
})
}
}

View File

@@ -8,6 +8,12 @@ use tracing::{info, warn};
/// Tool that checks email security configuration (DMARC and SPF records).
pub struct DmarcCheckerTool;
impl Default for DmarcCheckerTool {
fn default() -> Self {
Self::new()
}
}
impl DmarcCheckerTool {
pub fn new() -> Self {
Self
@@ -78,6 +84,105 @@ impl DmarcCheckerTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_dmarc_policy_reject() {
let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com";
assert_eq!(
DmarcCheckerTool::parse_dmarc_policy(record),
Some("reject".to_string())
);
}
#[test]
fn parse_dmarc_policy_none() {
let record = "v=DMARC1; p=none";
assert_eq!(
DmarcCheckerTool::parse_dmarc_policy(record),
Some("none".to_string())
);
}
#[test]
fn parse_dmarc_policy_quarantine() {
let record = "v=DMARC1; p=quarantine; sp=none";
assert_eq!(
DmarcCheckerTool::parse_dmarc_policy(record),
Some("quarantine".to_string())
);
}
#[test]
fn parse_dmarc_policy_missing() {
let record = "v=DMARC1; rua=mailto:test@example.com";
assert_eq!(DmarcCheckerTool::parse_dmarc_policy(record), None);
}
#[test]
fn parse_dmarc_subdomain_policy() {
let record = "v=DMARC1; p=reject; sp=quarantine";
assert_eq!(
DmarcCheckerTool::parse_dmarc_subdomain_policy(record),
Some("quarantine".to_string())
);
}
#[test]
fn parse_dmarc_subdomain_policy_missing() {
let record = "v=DMARC1; p=reject";
assert_eq!(DmarcCheckerTool::parse_dmarc_subdomain_policy(record), None);
}
#[test]
fn parse_dmarc_rua_present() {
let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com";
assert_eq!(
DmarcCheckerTool::parse_dmarc_rua(record),
Some("mailto:dmarc@example.com".to_string())
);
}
#[test]
fn parse_dmarc_rua_missing() {
let record = "v=DMARC1; p=none";
assert_eq!(DmarcCheckerTool::parse_dmarc_rua(record), None);
}
#[test]
fn is_spf_record_valid() {
assert!(DmarcCheckerTool::is_spf_record(
"v=spf1 include:_spf.google.com -all"
));
assert!(DmarcCheckerTool::is_spf_record("v=spf1 -all"));
}
#[test]
fn is_spf_record_invalid() {
assert!(!DmarcCheckerTool::is_spf_record("v=DMARC1; p=reject"));
assert!(!DmarcCheckerTool::is_spf_record("some random txt record"));
}
#[test]
fn spf_soft_fail_detection() {
assert!(DmarcCheckerTool::spf_uses_soft_fail(
"v=spf1 include:_spf.google.com ~all"
));
assert!(!DmarcCheckerTool::spf_uses_soft_fail(
"v=spf1 include:_spf.google.com -all"
));
}
#[test]
fn spf_allows_all_detection() {
assert!(DmarcCheckerTool::spf_allows_all("v=spf1 +all"));
assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 -all"));
assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 ~all"));
}
}
impl PentestTool for DmarcCheckerTool {
fn name(&self) -> &str {
"dmarc_checker"
@@ -105,43 +210,89 @@ impl PentestTool for DmarcCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'domain' parameter".to_string())
})?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut email_data = json!({});
let mut findings = Vec::new();
let mut email_data = json!({});
// ---- DMARC check ----
let dmarc_domain = format!("_dmarc.{domain}");
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
// ---- DMARC check ----
let dmarc_domain = format!("_dmarc.{domain}");
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
match dmarc_record {
Some(record) => {
email_data["dmarc_record"] = json!(record);
match dmarc_record {
Some(record) => {
email_data["dmarc_record"] = json!(record);
let policy = Self::parse_dmarc_policy(record);
let sp = Self::parse_dmarc_subdomain_policy(record);
let rua = Self::parse_dmarc_rua(record);
let policy = Self::parse_dmarc_policy(record);
let sp = Self::parse_dmarc_subdomain_policy(record);
let rua = Self::parse_dmarc_rua(record);
email_data["dmarc_policy"] = json!(policy);
email_data["dmarc_subdomain_policy"] = json!(sp);
email_data["dmarc_rua"] = json!(rua);
email_data["dmarc_policy"] = json!(policy);
email_data["dmarc_subdomain_policy"] = json!(sp);
email_data["dmarc_rua"] = json!(rua);
// Warn on weak policy
if let Some(ref p) = policy {
if p == "none" {
// Warn on weak policy
if let Some(ref p) = policy {
if p == "none" {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Weak DMARC policy for {domain}"),
format!(
"The DMARC policy for {domain} is set to 'none', which only monitors \
but does not enforce email authentication. Attackers can spoof emails \
from this domain."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
'p=reject' after verifying legitimate email flows are properly \
authenticated."
.to_string(),
);
findings.push(finding);
warn!(domain, "DMARC policy is 'none'");
}
}
// Warn if no reporting URI
if rua.is_none() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
@@ -156,48 +307,6 @@ impl PentestTool for DmarcCheckerTool {
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Weak DMARC policy for {domain}"),
format!(
"The DMARC policy for {domain} is set to 'none', which only monitors \
but does not enforce email authentication. Attackers can spoof emails \
from this domain."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
'p=reject' after verifying legitimate email flows are properly \
authenticated."
.to_string(),
);
findings.push(finding);
warn!(domain, "DMARC policy is 'none'");
}
}
// Warn if no reporting URI
if rua.is_none() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
@@ -211,80 +320,80 @@ impl PentestTool for DmarcCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-778".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
finding.cwe = Some("CWE-778".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
Example: 'rua=mailto:dmarc-reports@example.com'."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
None => {
email_data["dmarc_record"] = json!(null);
None => {
email_data["dmarc_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DMARC record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing DMARC record for {domain}"),
format!(
"No DMARC record was found for {domain}. Without DMARC, there is no \
policy to prevent email spoofing and phishing attacks using this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No DMARC record found");
}
}
// ---- SPF check ----
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
match spf_record {
Some(record) => {
email_data["spf_record"] = json!(record);
if Self::spf_allows_all(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
response_snippet: Some("No DMARC record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing DMARC record for {domain}"),
format!(
"No DMARC record was found for {domain}. Without DMARC, there is no \
policy to prevent email spoofing and phishing attacks using this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No DMARC record found");
}
}
// ---- SPF check ----
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
match spf_record {
Some(record) => {
email_data["spf_record"] = json!(record);
if Self::spf_allows_all(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
@@ -297,15 +406,55 @@ impl PentestTool for DmarcCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \
Only list authorized mail servers."
.to_string(),
);
findings.push(finding);
} else if Self::spf_uses_soft_fail(record) {
findings.push(finding);
} else if Self::spf_uses_soft_fail(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF soft fail for {domain}"),
format!(
"The SPF record for {domain} uses '~all' (soft fail) instead of \
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
but does not reject them."
),
Severity::Low,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Consider changing '~all' to '-all' in your SPF record once you have \
confirmed all legitimate mail sources are listed."
.to_string(),
);
findings.push(finding);
}
}
None => {
email_data["spf_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
@@ -313,7 +462,7 @@ impl PentestTool for DmarcCheckerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
response_snippet: Some("No SPF record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -323,79 +472,39 @@ impl PentestTool for DmarcCheckerTool {
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF soft fail for {domain}"),
format!("Missing SPF record for {domain}"),
format!(
"The SPF record for {domain} uses '~all' (soft fail) instead of \
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
but does not reject them."
),
Severity::Low,
"No SPF record was found for {domain}. Without SPF, any server can claim \
to send email on behalf of this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Consider changing '~all' to '-all' in your SPF record once you have \
confirmed all legitimate mail sources are listed."
"Create an SPF TXT record for your domain. Example: \
'v=spf1 include:_spf.google.com -all'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No SPF record found");
}
}
None => {
email_data["spf_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No SPF record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let count = findings.len();
info!(domain, findings = count, "DMARC/SPF check complete");
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing SPF record for {domain}"),
format!(
"No SPF record was found for {domain}. Without SPF, any server can claim \
to send email on behalf of this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create an SPF TXT record for your domain. Example: \
'v=spf1 include:_spf.google.com -all'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No SPF record found");
}
}
let count = findings.len();
info!(domain, findings = count, "DMARC/SPF check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} email security issues for {domain}.")
} else {
format!("Email security configuration looks good for {domain}.")
},
findings,
data: email_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} email security issues for {domain}.")
} else {
format!("Email security configuration looks good for {domain}.")
},
findings,
data: email_data,
})
})
}
}

View File

@@ -16,6 +16,12 @@ use tracing::{info, warn};
/// `tokio::process::Command` wrapper around `dig` where available.
pub struct DnsCheckerTool;
impl Default for DnsCheckerTool {
fn default() -> Self {
Self::new()
}
}
impl DnsCheckerTool {
pub fn new() -> Self {
Self
@@ -54,7 +60,9 @@ impl DnsCheckerTool {
}
}
Err(e) => {
return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}")));
return Err(CoreError::Dast(format!(
"DNS resolution failed for {domain}: {e}"
)));
}
}
@@ -94,107 +102,111 @@ impl PentestTool for DnsCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'domain' parameter".to_string())
})?;
let subdomains: Vec<String> = input
.get("subdomains")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let subdomains: Vec<String> = input
.get("subdomains")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let mut findings = Vec::new();
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
let mut findings = Vec::new();
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// --- A / AAAA records ---
match Self::resolve_addresses(domain).await {
Ok((ipv4, ipv6)) => {
dns_data.insert("a_records".to_string(), json!(ipv4));
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
// --- A / AAAA records ---
match Self::resolve_addresses(domain).await {
Ok((ipv4, ipv6)) => {
dns_data.insert("a_records".to_string(), json!(ipv4));
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
}
Err(e) => {
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
}
}
// --- MX records ---
match Self::dig_query(domain, "MX").await {
Ok(mx) => {
dns_data.insert("mx_records".to_string(), json!(mx));
// --- MX records ---
match Self::dig_query(domain, "MX").await {
Ok(mx) => {
dns_data.insert("mx_records".to_string(), json!(mx));
}
Err(e) => {
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
}
}
// --- NS records ---
let ns_records = match Self::dig_query(domain, "NS").await {
Ok(ns) => {
dns_data.insert("ns_records".to_string(), json!(ns));
ns
}
Err(e) => {
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
Vec::new()
}
};
// --- NS records ---
let ns_records = match Self::dig_query(domain, "NS").await {
Ok(ns) => {
dns_data.insert("ns_records".to_string(), json!(ns));
ns
}
Err(e) => {
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
Vec::new()
}
};
// --- TXT records ---
match Self::dig_query(domain, "TXT").await {
Ok(txt) => {
dns_data.insert("txt_records".to_string(), json!(txt));
// --- TXT records ---
match Self::dig_query(domain, "TXT").await {
Ok(txt) => {
dns_data.insert("txt_records".to_string(), json!(txt));
}
Err(e) => {
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
// --- CNAME records (for subdomains) ---
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
let mut domains_to_check = vec![domain.to_string()];
for sub in &subdomains {
domains_to_check.push(format!("{sub}.{domain}"));
}
}
// --- CNAME records (for subdomains) ---
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
let mut domains_to_check = vec![domain.to_string()];
for sub in &subdomains {
domains_to_check.push(format!("{sub}.{domain}"));
}
for fqdn in &domains_to_check {
match Self::dig_query(fqdn, "CNAME").await {
Ok(cnames) if !cnames.is_empty() => {
// Check for dangling CNAME
for cname in &cnames {
let cname_clean = cname.trim_end_matches('.');
let check_addr = format!("{cname_clean}:443");
let is_dangling = lookup_host(&check_addr).await.is_err();
if is_dangling {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: fqdn.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"CNAME {fqdn} -> {cname} (target does not resolve)"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
for fqdn in &domains_to_check {
match Self::dig_query(fqdn, "CNAME").await {
Ok(cnames) if !cnames.is_empty() => {
// Check for dangling CNAME
for cname in &cnames {
let cname_clean = cname.trim_end_matches('.');
let check_addr = format!("{cname_clean}:443");
let is_dangling = lookup_host(&check_addr).await.is_err();
if is_dangling {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: fqdn.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"CNAME {fqdn} -> {cname} (target does not resolve)"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -207,44 +219,47 @@ impl PentestTool for DnsCheckerTool {
fqdn.clone(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling CNAME records or ensure the target hostname is \
properly configured and resolvable."
.to_string(),
);
findings.push(finding);
warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover");
findings.push(finding);
warn!(
fqdn,
cname, "Dangling CNAME detected - potential subdomain takeover"
);
}
}
cname_data.insert(fqdn.clone(), cnames);
}
cname_data.insert(fqdn.clone(), cnames);
_ => {}
}
_ => {}
}
}
if !cname_data.is_empty() {
dns_data.insert("cname_records".to_string(), json!(cname_data));
}
if !cname_data.is_empty() {
dns_data.insert("cname_records".to_string(), json!(cname_data));
}
// --- CAA records ---
match Self::dig_query(domain, "CAA").await {
Ok(caa) => {
if caa.is_empty() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No CAA records found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// --- CAA records ---
match Self::dig_query(domain, "CAA").await {
Ok(caa) => {
if caa.is_empty() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No CAA records found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -257,35 +272,83 @@ impl PentestTool for DnsCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add CAA DNS records to restrict which certificate authorities can issue \
certificates for your domain. Example: '0 issue \"letsencrypt.org\"'."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
dns_data.insert("caa_records".to_string(), json!(caa));
}
Err(e) => {
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
}
dns_data.insert("caa_records".to_string(), json!(caa));
}
Err(e) => {
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
// --- DNSSEC check ---
let dnssec_output = tokio::process::Command::new("dig")
.args(["+dnssec", "+short", "DNSKEY", domain])
.output()
.await;
match dnssec_output {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let has_dnssec = !stdout.trim().is_empty();
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
if !has_dnssec {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(
"No DNSKEY records found - DNSSEC not enabled".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("DNSSEC not enabled for {domain}"),
format!(
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
can be spoofed, allowing man-in-the-middle attacks."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-350".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
with your DNS provider and domain registrar."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
}
}
}
// --- DNSSEC check ---
let dnssec_output = tokio::process::Command::new("dig")
.args(["+dnssec", "+short", "DNSKEY", domain])
.output()
.await;
match dnssec_output {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let has_dnssec = !stdout.trim().is_empty();
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
if !has_dnssec {
// --- Check NS records for dangling ---
for ns in &ns_records {
let ns_clean = ns.trim_end_matches('.');
let check_addr = format!("{ns_clean}:53");
if lookup_host(&check_addr).await.is_err() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
@@ -293,61 +356,13 @@ impl PentestTool for DnsCheckerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()),
response_snippet: Some(format!("NS record {ns} does not resolve")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("DNSSEC not enabled for {domain}"),
format!(
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
can be spoofed, allowing man-in-the-middle attacks."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-350".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
with your DNS provider and domain registrar."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
}
}
// --- Check NS records for dangling ---
for ns in &ns_records {
let ns_clean = ns.trim_end_matches('.');
let check_addr = format!("{ns_clean}:53");
if lookup_host(&check_addr).await.is_err() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"NS record {ns} does not resolve"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -360,30 +375,33 @@ impl PentestTool for DnsCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling NS records or ensure the nameserver hostname is properly \
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling NS records or ensure the nameserver hostname is properly \
configured. Dangling NS records can lead to full domain takeover."
.to_string(),
);
findings.push(finding);
warn!(domain, ns, "Dangling NS record detected - potential domain takeover");
.to_string(),
);
findings.push(finding);
warn!(
domain,
ns, "Dangling NS record detected - potential domain takeover"
);
}
}
}
let count = findings.len();
info!(domain, findings = count, "DNS check complete");
let count = findings.len();
info!(domain, findings = count, "DNS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} DNS configuration issues for {domain}.")
} else {
format!("No DNS configuration issues found for {domain}.")
},
findings,
data: json!(dns_data),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} DNS configuration issues for {domain}.")
} else {
format!("No DNS configuration issues found for {domain}.")
},
findings,
data: json!(dns_data),
})
})
}
}

View File

@@ -33,8 +33,15 @@ pub struct ToolRegistry {
tools: HashMap<String, Box<dyn PentestTool>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
/// Create a new registry with all built-in tools pre-registered.
#[allow(clippy::expect_used)]
pub fn new() -> Self {
let http = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
@@ -67,13 +74,10 @@ impl ToolRegistry {
);
// New infrastructure / analysis tools
register(&mut tools, Box::<dns_checker::DnsCheckerTool>::default());
register(
&mut tools,
Box::new(dns_checker::DnsCheckerTool::new()),
);
register(
&mut tools,
Box::new(dmarc_checker::DmarcCheckerTool::new()),
Box::<dmarc_checker::DmarcCheckerTool>::default(),
);
register(
&mut tools,
@@ -109,10 +113,7 @@ impl ToolRegistry {
&mut tools,
Box::new(openapi_parser::OpenApiParserTool::new(http.clone())),
);
register(
&mut tools,
Box::new(recon::ReconTool::new(http)),
);
register(&mut tools, Box::new(recon::ReconTool::new(http)));
Self { tools }
}

View File

@@ -92,7 +92,10 @@ impl OpenApiParserTool {
// If content type suggests YAML, we can't easily parse without a YAML dep,
// so just report the URL as found
if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") {
if content_type.contains("yaml")
|| body.starts_with("openapi:")
|| body.starts_with("swagger:")
{
// Return a minimal JSON indicating YAML was found
return Some((
url.to_string(),
@@ -107,7 +110,7 @@ impl OpenApiParserTool {
}
/// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints.
fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec<ParsedEndpoint> {
fn parse_spec(spec: &serde_json::Value, _base_url: &str) -> Vec<ParsedEndpoint> {
let mut endpoints = Vec::new();
// Determine base path
@@ -258,6 +261,166 @@ impl OpenApiParserTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn common_spec_paths_not_empty() {
let paths = OpenApiParserTool::common_spec_paths();
assert!(paths.len() >= 5);
assert!(paths.contains(&"/openapi.json"));
assert!(paths.contains(&"/swagger.json"));
}
#[test]
fn parse_spec_openapi3_basic() {
let spec = json!({
"openapi": "3.0.0",
"info": { "title": "Test API", "version": "1.0" },
"paths": {
"/users": {
"get": {
"operationId": "listUsers",
"summary": "List all users",
"parameters": [
{
"name": "limit",
"in": "query",
"required": false,
"schema": { "type": "integer" }
}
],
"responses": {
"200": { "description": "OK" },
"401": { "description": "Unauthorized" }
},
"tags": ["users"]
},
"post": {
"operationId": "createUser",
"requestBody": {
"content": {
"application/json": {}
}
},
"responses": { "201": {} },
"security": [{ "bearerAuth": [] }]
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
assert_eq!(endpoints.len(), 2);
let get_ep = endpoints.iter().find(|e| e.method == "GET").unwrap();
assert_eq!(get_ep.path, "/users");
assert_eq!(get_ep.operation_id.as_deref(), Some("listUsers"));
assert_eq!(get_ep.summary.as_deref(), Some("List all users"));
assert_eq!(get_ep.parameters.len(), 1);
assert_eq!(get_ep.parameters[0].name, "limit");
assert_eq!(get_ep.parameters[0].location, "query");
assert!(!get_ep.parameters[0].required);
assert_eq!(get_ep.parameters[0].param_type.as_deref(), Some("integer"));
assert_eq!(get_ep.response_codes.len(), 2);
assert_eq!(get_ep.tags, vec!["users"]);
let post_ep = endpoints.iter().find(|e| e.method == "POST").unwrap();
assert_eq!(
post_ep.request_body_content_type.as_deref(),
Some("application/json")
);
assert_eq!(post_ep.security, vec!["bearerAuth"]);
}
#[test]
fn parse_spec_swagger2_with_base_path() {
let spec = json!({
"swagger": "2.0",
"basePath": "/api/v1",
"paths": {
"/items": {
"get": {
"parameters": [
{ "name": "id", "in": "path", "required": true, "type": "string" }
],
"responses": { "200": {} }
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
assert_eq!(endpoints.len(), 1);
assert_eq!(endpoints[0].path, "/api/v1/items");
assert_eq!(
endpoints[0].parameters[0].param_type.as_deref(),
Some("string")
);
}
#[test]
fn parse_spec_empty_paths() {
let spec = json!({ "openapi": "3.0.0", "paths": {} });
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert!(endpoints.is_empty());
}
#[test]
fn parse_spec_no_paths_key() {
let spec = json!({ "openapi": "3.0.0" });
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert!(endpoints.is_empty());
}
#[test]
fn parse_spec_servers_base_url() {
let spec = json!({
"openapi": "3.0.0",
"servers": [{ "url": "/api/v2" }],
"paths": {
"/health": {
"get": { "responses": { "200": {} } }
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert_eq!(endpoints[0].path, "/api/v2/health");
}
#[test]
fn parse_spec_path_level_parameters_merged() {
let spec = json!({
"openapi": "3.0.0",
"paths": {
"/items/{id}": {
"parameters": [
{ "name": "id", "in": "path", "required": true, "schema": { "type": "string" } }
],
"get": {
"parameters": [
{ "name": "fields", "in": "query", "schema": { "type": "string" } }
],
"responses": { "200": {} }
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert_eq!(endpoints[0].parameters.len(), 2);
assert!(endpoints[0]
.parameters
.iter()
.any(|p| p.name == "id" && p.location == "path"));
assert!(endpoints[0]
.parameters
.iter()
.any(|p| p.name == "fields" && p.location == "query"));
}
}
impl PentestTool for OpenApiParserTool {
fn name(&self) -> &str {
"openapi_parser"
@@ -289,134 +452,138 @@ impl PentestTool for OpenApiParserTool {
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
_context: &'a PentestToolContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let base_url = input
.get("base_url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?;
let base_url = input
.get("base_url")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'base_url' parameter".to_string())
})?;
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
let base_url_trimmed = base_url.trim_end_matches('/');
let base_url_trimmed = base_url.trim_end_matches('/');
// If an explicit spec URL is provided, try it first
let mut spec_result: Option<(String, serde_json::Value)> = None;
if let Some(spec_url) = explicit_spec_url {
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
}
// If an explicit spec URL is provided, try it first
let mut spec_result: Option<(String, serde_json::Value)> = None;
if let Some(spec_url) = explicit_spec_url {
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
}
// If no explicit URL or it failed, try common paths
if spec_result.is_none() {
for path in Self::common_spec_paths() {
let url = format!("{base_url_trimmed}{path}");
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
spec_result = Some(result);
break;
// If no explicit URL or it failed, try common paths
if spec_result.is_none() {
for path in Self::common_spec_paths() {
let url = format!("{base_url_trimmed}{path}");
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
spec_result = Some(result);
break;
}
}
}
}
match spec_result {
Some((spec_url, spec)) => {
let spec_version = spec
.get("openapi")
.or_else(|| spec.get("swagger"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
match spec_result {
Some((spec_url, spec)) => {
let spec_version = spec
.get("openapi")
.or_else(|| spec.get("swagger"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let api_title = spec
.get("info")
.and_then(|i| i.get("title"))
.and_then(|t| t.as_str())
.unwrap_or("Unknown API");
let api_title = spec
.get("info")
.and_then(|i| i.get("title"))
.and_then(|t| t.as_str())
.unwrap_or("Unknown API");
let api_version = spec
.get("info")
.and_then(|i| i.get("version"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let api_version = spec
.get("info")
.and_then(|i| i.get("version"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
let endpoint_data: Vec<serde_json::Value> = endpoints
.iter()
.map(|ep| {
let params: Vec<serde_json::Value> = ep
.parameters
.iter()
.map(|p| {
json!({
"name": p.name,
"in": p.location,
"required": p.required,
"type": p.param_type,
"description": p.description,
let endpoint_data: Vec<serde_json::Value> = endpoints
.iter()
.map(|ep| {
let params: Vec<serde_json::Value> = ep
.parameters
.iter()
.map(|p| {
json!({
"name": p.name,
"in": p.location,
"required": p.required,
"type": p.param_type,
"description": p.description,
})
})
.collect();
json!({
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"parameters": params,
"request_body_content_type": ep.request_body_content_type,
"response_codes": ep.response_codes,
"security": ep.security,
"tags": ep.tags,
})
.collect();
json!({
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"parameters": params,
"request_body_content_type": ep.request_body_content_type,
"response_codes": ep.response_codes,
"security": ep.security,
"tags": ep.tags,
})
})
.collect();
.collect();
let endpoint_count = endpoints.len();
info!(
spec_url = %spec_url,
spec_version,
api_title,
endpoints = endpoint_count,
"OpenAPI spec parsed"
);
let endpoint_count = endpoints.len();
info!(
spec_url = %spec_url,
spec_version,
api_title,
endpoints = endpoint_count,
"OpenAPI spec parsed"
);
Ok(PentestToolResult {
summary: format!(
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
Ok(PentestToolResult {
summary: format!(
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
API: {api_title} v{api_version}. \
Parsed {endpoint_count} endpoints."
),
findings: Vec::new(), // This tool produces data, not findings
data: json!({
"spec_url": spec_url,
"spec_version": spec_version,
"api_title": api_title,
"api_version": api_version,
"endpoint_count": endpoint_count,
"endpoints": endpoint_data,
"security_schemes": spec.get("components")
.and_then(|c| c.get("securitySchemes"))
.or_else(|| spec.get("securityDefinitions")),
}),
})
}
None => {
info!(base_url, "No OpenAPI spec found");
),
findings: Vec::new(), // This tool produces data, not findings
data: json!({
"spec_url": spec_url,
"spec_version": spec_version,
"api_title": api_title,
"api_version": api_version,
"endpoint_count": endpoint_count,
"endpoints": endpoint_data,
"security_schemes": spec.get("components")
.and_then(|c| c.get("securitySchemes"))
.or_else(|| spec.get("securityDefinitions")),
}),
})
}
None => {
info!(base_url, "No OpenAPI spec found");
Ok(PentestToolResult {
summary: format!(
"No OpenAPI/Swagger specification found for {base_url}. \
Ok(PentestToolResult {
summary: format!(
"No OpenAPI/Swagger specification found for {base_url}. \
Tried {} common paths.",
Self::common_spec_paths().len()
),
findings: Vec::new(),
data: json!({
"spec_found": false,
"paths_tried": Self::common_spec_paths(),
}),
})
Self::common_spec_paths().len()
),
findings: Vec::new(),
data: json!({
"spec_found": false,
"paths_tried": Self::common_spec_paths(),
}),
})
}
}
}
})
}
}

View File

@@ -62,224 +62,229 @@ impl PentestTool for RateLimitTesterTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let method = input
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let method = input
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let request_count = input
.get("request_count")
.and_then(|v| v.as_u64())
.unwrap_or(50)
.min(200) as usize;
let request_count = input
.get("request_count")
.and_then(|v| v.as_u64())
.unwrap_or(50)
.min(200) as usize;
let body = input.get("body").and_then(|v| v.as_str());
let body = input.get("body").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Respect the context rate limit if set
let max_requests = if context.rate_limit > 0 {
request_count.min(context.rate_limit as usize * 5)
} else {
request_count
};
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
let mut got_429 = false;
let mut rate_limit_at_request: Option<usize> = None;
for i in 0..max_requests {
let start = Instant::now();
let request = match method {
"POST" => {
let mut req = self.http.post(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PUT" => {
let mut req = self.http.put(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PATCH" => {
let mut req = self.http.patch(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"DELETE" => self.http.delete(url),
_ => self.http.get(url),
// Respect the context rate limit if set
let max_requests = if context.rate_limit > 0 {
request_count.min(context.rate_limit as usize * 5)
} else {
request_count
};
match request.send().await {
Ok(resp) => {
let elapsed = start.elapsed().as_millis();
let status = resp.status().as_u16();
status_codes.push(status);
response_times.push(elapsed);
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
let mut got_429 = false;
let mut rate_limit_at_request: Option<usize> = None;
if status == 429 && !got_429 {
got_429 = true;
rate_limit_at_request = Some(i + 1);
info!(url, request_num = i + 1, "Rate limit triggered (429)");
for i in 0..max_requests {
let start = Instant::now();
let request = match method {
"POST" => {
let mut req = self.http.post(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PUT" => {
let mut req = self.http.put(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PATCH" => {
let mut req = self.http.patch(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"DELETE" => self.http.delete(url),
_ => self.http.get(url),
};
// Check for rate limit headers even on 200
if !got_429 {
let headers = resp.headers();
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|| headers.contains_key("x-ratelimit-remaining")
|| headers.contains_key("ratelimit-limit")
|| headers.contains_key("ratelimit-remaining")
|| headers.contains_key("retry-after");
match request.send().await {
Ok(resp) => {
let elapsed = start.elapsed().as_millis();
let status = resp.status().as_u16();
status_codes.push(status);
response_times.push(elapsed);
if has_rate_headers && rate_limit_at_request.is_none() {
// Server has rate limit headers but hasn't blocked yet
if status == 429 && !got_429 {
got_429 = true;
rate_limit_at_request = Some(i + 1);
info!(url, request_num = i + 1, "Rate limit triggered (429)");
}
// Check for rate limit headers even on 200
if !got_429 {
let headers = resp.headers();
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|| headers.contains_key("x-ratelimit-remaining")
|| headers.contains_key("ratelimit-limit")
|| headers.contains_key("ratelimit-remaining")
|| headers.contains_key("retry-after");
if has_rate_headers && rate_limit_at_request.is_none() {
// Server has rate limit headers but hasn't blocked yet
}
}
}
}
Err(e) => {
let elapsed = start.elapsed().as_millis();
status_codes.push(0);
response_times.push(elapsed);
Err(_e) => {
let elapsed = start.elapsed().as_millis();
status_codes.push(0);
response_times.push(elapsed);
}
}
}
}
let mut findings = Vec::new();
let total_sent = status_codes.len();
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count();
let mut findings = Vec::new();
let total_sent = status_codes.len();
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
let count_success = status_codes
.iter()
.filter(|&&s| (200..300).contains(&s))
.count();
// Calculate response time statistics
let avg_time = if !response_times.is_empty() {
response_times.iter().sum::<u128>() / response_times.len() as u128
} else {
0
};
let first_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[..half].iter().sum::<u128>() / half as u128
} else {
avg_time
};
let second_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
} else {
avg_time
};
// Significant time degradation suggests possible (weak) rate limiting
let time_degradation = if first_half_avg > 0 {
(second_half_avg as f64 / first_half_avg as f64) - 1.0
} else {
0.0
};
let rate_data = json!({
"total_requests_sent": total_sent,
"status_429_count": count_429,
"success_count": count_success,
"rate_limit_at_request": rate_limit_at_request,
"avg_response_time_ms": avg_time,
"first_half_avg_ms": first_half_avg,
"second_half_avg_ms": second_half_avg,
"time_degradation_pct": (time_degradation * 100.0).round(),
});
if !got_429 && count_success == total_sent {
// No rate limiting detected at all
let evidence = DastEvidence {
request_method: method.to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: body.map(String::from),
response_status: 200,
response_headers: None,
response_snippet: Some(format!(
"Sent {total_sent} rapid requests. All returned success (2xx). \
No 429 responses received. Avg response time: {avg_time}ms."
)),
screenshot_path: None,
payload: None,
response_time_ms: Some(avg_time as u64),
// Calculate response time statistics
let avg_time = if !response_times.is_empty() {
response_times.iter().sum::<u128>() / response_times.len() as u128
} else {
0
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::RateLimitAbsent,
format!("No rate limiting on {} {}", method, url),
format!(
"The endpoint {} {} does not enforce rate limiting. \
let first_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[..half].iter().sum::<u128>() / half as u128
} else {
avg_time
};
let second_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
} else {
avg_time
};
// Significant time degradation suggests possible (weak) rate limiting
let time_degradation = if first_half_avg > 0 {
(second_half_avg as f64 / first_half_avg as f64) - 1.0
} else {
0.0
};
let rate_data = json!({
"total_requests_sent": total_sent,
"status_429_count": count_429,
"success_count": count_success,
"rate_limit_at_request": rate_limit_at_request,
"avg_response_time_ms": avg_time,
"first_half_avg_ms": first_half_avg,
"second_half_avg_ms": second_half_avg,
"time_degradation_pct": (time_degradation * 100.0).round(),
});
if !got_429 && count_success == total_sent {
// No rate limiting detected at all
let evidence = DastEvidence {
request_method: method.to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: body.map(String::from),
response_status: 200,
response_headers: None,
response_snippet: Some(format!(
"Sent {total_sent} rapid requests. All returned success (2xx). \
No 429 responses received. Avg response time: {avg_time}ms."
)),
screenshot_path: None,
payload: None,
response_time_ms: Some(avg_time as u64),
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::RateLimitAbsent,
format!("No rate limiting on {} {}", method, url),
format!(
"The endpoint {} {} does not enforce rate limiting. \
{total_sent} rapid requests were all accepted with no 429 responses \
or noticeable degradation. This makes the endpoint vulnerable to \
brute force attacks and abuse.",
method, url
),
Severity::Medium,
url.to_string(),
method.to_string(),
);
finding.cwe = Some("CWE-770".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
method, url
),
Severity::Medium,
url.to_string(),
method.to_string(),
);
finding.cwe = Some("CWE-770".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
algorithms. Return 429 Too Many Requests with a Retry-After header when the \
limit is exceeded."
.to_string(),
);
findings.push(finding);
warn!(url, method, total_sent, "No rate limiting detected");
} else if got_429 {
info!(
url,
method,
rate_limit_at = ?rate_limit_at_request,
"Rate limiting is enforced"
);
}
.to_string(),
);
findings.push(finding);
warn!(url, method, total_sent, "No rate limiting detected");
} else if got_429 {
info!(
url,
method,
rate_limit_at = ?rate_limit_at_request,
"Rate limiting is enforced"
);
}
let count = findings.len();
let count = findings.len();
Ok(PentestToolResult {
summary: if got_429 {
format!(
"Rate limiting is enforced on {method} {url}. \
Ok(PentestToolResult {
summary: if got_429 {
format!(
"Rate limiting is enforced on {method} {url}. \
429 response received after {} requests.",
rate_limit_at_request.unwrap_or(0)
)
} else if count > 0 {
format!(
"No rate limiting detected on {method} {url} after {total_sent} requests."
)
} else {
format!("Rate limit testing complete for {method} {url}.")
},
findings,
data: rate_data,
})
rate_limit_at_request.unwrap_or(0)
)
} else if count > 0 {
format!(
"No rate limiting detected on {method} {url} after {total_sent} requests."
)
} else {
format!("Rate limit testing complete for {method} {url}.")
},
findings,
data: rate_data,
})
})
}
}

View File

@@ -54,72 +54,75 @@ impl PentestTool for ReconTool {
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
_context: &'a PentestToolContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_paths: Vec<String> = input
.get("additional_paths")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let additional_paths: Vec<String> = input
.get("additional_paths")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let result = self.agent.scan(url).await?;
let result = self.agent.scan(url).await?;
// Scan additional paths for more technology signals
let mut extra_technologies: Vec<String> = Vec::new();
let mut extra_headers: HashMap<String, String> = HashMap::new();
// Scan additional paths for more technology signals
let mut extra_technologies: Vec<String> = Vec::new();
let mut extra_headers: HashMap<String, String> = HashMap::new();
let base_url = url.trim_end_matches('/');
for path in &additional_paths {
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
if let Ok(resp) = self.http.get(&probe_url).send().await {
for (key, value) in resp.headers() {
let k = key.to_string().to_lowercase();
let v = value.to_str().unwrap_or("").to_string();
let base_url = url.trim_end_matches('/');
for path in &additional_paths {
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
if let Ok(resp) = self.http.get(&probe_url).send().await {
for (key, value) in resp.headers() {
let k = key.to_string().to_lowercase();
let v = value.to_str().unwrap_or("").to_string();
// Look for technology indicators
if k == "x-powered-by" || k == "server" || k == "x-generator" {
if !result.technologies.contains(&v) && !extra_technologies.contains(&v) {
// Look for technology indicators
if (k == "x-powered-by" || k == "server" || k == "x-generator")
&& !result.technologies.contains(&v)
&& !extra_technologies.contains(&v)
{
extra_technologies.push(v.clone());
}
extra_headers.insert(format!("{probe_url} -> {k}"), v);
}
extra_headers.insert(format!("{probe_url} -> {k}"), v);
}
}
}
let mut all_technologies = result.technologies.clone();
all_technologies.extend(extra_technologies);
all_technologies.dedup();
let mut all_technologies = result.technologies.clone();
all_technologies.extend(extra_technologies);
all_technologies.dedup();
let tech_count = all_technologies.len();
info!(url, technologies = tech_count, "Recon complete");
let tech_count = all_technologies.len();
info!(url, technologies = tech_count, "Recon complete");
Ok(PentestToolResult {
summary: format!(
"Recon complete for {url}. Detected {} technologies. Server: {}.",
tech_count,
result.server.as_deref().unwrap_or("unknown")
),
findings: Vec::new(), // Recon produces data, not findings
data: json!({
"base_url": url,
"server": result.server,
"technologies": all_technologies,
"interesting_headers": result.interesting_headers,
"extra_headers": extra_headers,
"open_ports": result.open_ports,
}),
})
Ok(PentestToolResult {
summary: format!(
"Recon complete for {url}. Detected {} technologies. Server: {}.",
tech_count,
result.server.as_deref().unwrap_or("unknown")
),
findings: Vec::new(), // Recon produces data, not findings
data: json!({
"base_url": url,
"server": result.server,
"technologies": all_technologies,
"interesting_headers": result.interesting_headers,
"extra_headers": extra_headers,
"open_ports": result.open_ports,
}),
})
})
}
}

View File

@@ -111,57 +111,107 @@ impl PentestTool for SecurityHeadersTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string()))
.collect();
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| {
(
k.to_string().to_lowercase(),
v.to_str().unwrap_or("").to_string(),
)
})
.collect();
let mut findings = Vec::new();
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
let mut findings = Vec::new();
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
for expected in Self::expected_headers() {
let header_value = response_headers.get(expected.name);
for expected in Self::expected_headers() {
let header_value = response_headers.get(expected.name);
match header_value {
Some(value) => {
let mut is_valid = true;
if let Some(ref valid) = expected.valid_values {
let lower = value.to_lowercase();
is_valid = valid.iter().any(|v| lower.contains(v));
match header_value {
Some(value) => {
let mut is_valid = true;
if let Some(ref valid) = expected.valid_values {
let lower = value.to_lowercase();
is_valid = valid.iter().any(|v| lower.contains(v));
}
header_results.insert(
expected.name.to_string(),
json!({
"present": true,
"value": value,
"valid": is_valid,
}),
);
if !is_valid {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{}: {}", expected.name, value)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Invalid {} header value", expected.name),
format!(
"The {} header is present but has an invalid or weak value: '{}'. \
{}",
expected.name, value, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
None => {
header_results.insert(
expected.name.to_string(),
json!({
"present": false,
"value": null,
"valid": false,
}),
);
header_results.insert(
expected.name.to_string(),
json!({
"present": true,
"value": value,
"valid": is_valid,
}),
);
if !is_valid {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
@@ -169,7 +219,7 @@ impl PentestTool for SecurityHeadersTool {
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{}: {}", expected.name, value)),
response_snippet: Some(format!("{} header is missing", expected.name)),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -179,11 +229,10 @@ impl PentestTool for SecurityHeadersTool {
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Invalid {} header value", expected.name),
format!("Missing {} header", expected.name),
format!(
"The {} header is present but has an invalid or weak value: '{}'. \
{}",
expected.name, value, expected.description
"The {} header is not present in the response. {}",
expected.name, expected.description
),
expected.severity.clone(),
url.to_string(),
@@ -195,14 +244,20 @@ impl PentestTool for SecurityHeadersTool {
findings.push(finding);
}
}
None => {
}
// Also check for information disclosure headers
let disclosure_headers = [
"server",
"x-powered-by",
"x-aspnet-version",
"x-aspnetmvc-version",
];
for h in &disclosure_headers {
if let Some(value) = response_headers.get(*h) {
header_results.insert(
expected.name.to_string(),
json!({
"present": false,
"value": null,
"valid": false,
}),
format!("{h}_disclosure"),
json!({ "present": true, "value": value }),
);
let evidence = DastEvidence {
@@ -212,56 +267,13 @@ impl PentestTool for SecurityHeadersTool {
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{} header is missing", expected.name)),
response_snippet: Some(format!("{h}: {value}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Missing {} header", expected.name),
format!(
"The {} header is not present in the response. {}",
expected.name, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
}
// Also check for information disclosure headers
let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"];
for h in &disclosure_headers {
if let Some(value) = response_headers.get(*h) {
header_results.insert(
format!("{h}_disclosure"),
json!({ "present": true, "value": value }),
);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{h}: {value}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
@@ -274,27 +286,27 @@ impl PentestTool for SecurityHeadersTool {
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-200".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove or suppress the {h} header in your server configuration."
));
findings.push(finding);
finding.cwe = Some("CWE-200".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove or suppress the {h} header in your server configuration."
));
findings.push(finding);
}
}
}
let count = findings.len();
info!(url, findings = count, "Security headers check complete");
let count = findings.len();
info!(url, findings = count, "Security headers check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} security header issues for {url}.")
} else {
format!("All checked security headers are present and valid for {url}.")
},
findings,
data: json!(header_results),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} security header issues for {url}.")
} else {
format!("All checked security headers are present and valid for {url}.")
},
findings,
data: json!(header_results),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,29 +9,51 @@ use crate::agents::injection::SqlInjectionAgent;
/// PentestTool wrapper around the existing SqlInjectionAgent.
pub struct SqlInjectionTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: SqlInjectionAgent,
}
impl SqlInjectionTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SqlInjectionAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string();
let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from);
let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from);
let name = p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let location = p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string();
let param_type = p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from);
let example_value = p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from);
parameters.push(EndpointParameter {
name,
location,
@@ -42,8 +66,14 @@ impl SqlInjectionTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -51,6 +81,62 @@ impl SqlInjectionTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_endpoints_basic() {
let input = json!({
"endpoints": [{
"url": "https://example.com/api/users",
"method": "POST",
"parameters": [
{ "name": "id", "location": "body", "param_type": "integer" }
]
}]
});
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints.len(), 1);
assert_eq!(endpoints[0].url, "https://example.com/api/users");
assert_eq!(endpoints[0].method, "POST");
assert_eq!(endpoints[0].parameters[0].name, "id");
assert_eq!(endpoints[0].parameters[0].location, "body");
assert_eq!(
endpoints[0].parameters[0].param_type.as_deref(),
Some("integer")
);
}
#[test]
fn parse_endpoints_empty_input() {
assert!(SqlInjectionTool::parse_endpoints(&json!({})).is_empty());
assert!(SqlInjectionTool::parse_endpoints(&json!({ "endpoints": [] })).is_empty());
}
#[test]
fn parse_endpoints_multiple() {
let input = json!({
"endpoints": [
{ "url": "https://a.com/1", "method": "GET", "parameters": [] },
{ "url": "https://b.com/2", "method": "DELETE", "parameters": [] }
]
});
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints.len(), 2);
assert_eq!(endpoints[0].url, "https://a.com/1");
assert_eq!(endpoints[1].method, "DELETE");
}
#[test]
fn parse_endpoints_default_method() {
let input = json!({ "endpoints": [{ "url": "https://x.com", "parameters": [] }] });
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints[0].method, "GET");
}
}
impl PentestTool for SqlInjectionTool {
fn name(&self) -> &str {
"sql_injection_scanner"
@@ -104,35 +190,37 @@ impl PentestTool for SqlInjectionTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SQL injection vulnerabilities.")
} else {
"No SQL injection vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SQL injection vulnerabilities.")
} else {
"No SQL injection vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::ssrf::SsrfAgent;
/// PentestTool wrapper around the existing SsrfAgent.
pub struct SsrfTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: SsrfAgent,
}
impl SsrfTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SsrfAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl SsrfTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -100,35 +130,37 @@ impl PentestTool for SsrfTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SSRF vulnerabilities.")
} else {
"No SSRF vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SSRF vulnerabilities.")
} else {
"No SSRF vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -39,10 +39,7 @@ impl TlsAnalyzerTool {
/// TLS client hello. We test SSLv3 / old protocol support by attempting
/// connection with the system's native-tls which typically negotiates the
/// best available, then inspect what was negotiated.
async fn check_tls(
host: &str,
port: u16,
) -> Result<TlsInfo, CoreError> {
async fn check_tls(host: &str, port: u16) -> Result<TlsInfo, CoreError> {
let addr = format!("{host}:{port}");
let tcp = TcpStream::connect(&addr)
@@ -62,11 +59,13 @@ impl TlsAnalyzerTool {
.await
.map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?;
let peer_cert = tls_stream.get_ref().peer_certificate()
let peer_cert = tls_stream
.get_ref()
.peer_certificate()
.map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?;
let mut tls_info = TlsInfo {
protocol_version: String::new(),
_protocol_version: String::new(),
cert_subject: String::new(),
cert_issuer: String::new(),
cert_not_before: String::new(),
@@ -78,7 +77,8 @@ impl TlsAnalyzerTool {
};
if let Some(cert) = peer_cert {
let der = cert.to_der()
let der = cert
.to_der()
.map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?;
// native_tls doesn't give rich access, so we parse what we can
@@ -93,7 +93,7 @@ impl TlsAnalyzerTool {
/// Best-effort parse of DER-encoded X.509 certificate for dates and subject.
/// This is a simplified parser; in production you would use a proper x509 crate.
fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo {
fn parse_cert_der(_der: &[u8], mut info: TlsInfo) -> TlsInfo {
// We rely on the native_tls debug output stored in cert_subject
// and just mark fields as "see certificate details"
if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") {
@@ -104,7 +104,7 @@ impl TlsAnalyzerTool {
}
struct TlsInfo {
protocol_version: String,
_protocol_version: String,
cert_subject: String,
cert_issuer: String,
cert_not_before: String,
@@ -152,111 +152,194 @@ impl PentestTool for TlsAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let host = Self::extract_host(url)
.unwrap_or_else(|| url.to_string());
let host = Self::extract_host(url).unwrap_or_else(|| url.to_string());
let port = input
.get("port")
.and_then(|v| v.as_u64())
.map(|p| p as u16)
.unwrap_or_else(|| Self::extract_port(url));
let port = input
.get("port")
.and_then(|v| v.as_u64())
.map(|p| p as u16)
.unwrap_or_else(|| Self::extract_port(url));
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut tls_data = json!({});
let mut findings = Vec::new();
let mut tls_data = json!({});
// First check: does the server even support HTTPS?
let https_url = if url.starts_with("https://") {
url.to_string()
} else if url.starts_with("http://") {
url.replace("http://", "https://")
} else {
format!("https://{url}")
};
// First check: does the server even support HTTPS?
let https_url = if url.starts_with("https://") {
url.to_string()
} else if url.starts_with("http://") {
url.replace("http://", "https://")
} else {
format!("https://{url}")
};
// Check if HTTP redirects to HTTPS
let http_url = if url.starts_with("http://") {
url.to_string()
} else if url.starts_with("https://") {
url.replace("https://", "http://")
} else {
format!("http://{url}")
};
// Check if HTTP redirects to HTTPS
let http_url = if url.starts_with("http://") {
url.to_string()
} else if url.starts_with("https://") {
url.replace("https://", "http://")
} else {
format!("http://{url}")
};
match self.http.get(&http_url).send().await {
Ok(resp) => {
let final_url = resp.url().to_string();
let redirects_to_https = final_url.starts_with("https://");
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
match self.http.get(&http_url).send().await {
Ok(resp) => {
let final_url = resp.url().to_string();
let redirects_to_https = final_url.starts_with("https://");
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
if !redirects_to_https {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: http_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(format!("Final URL: {final_url}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if !redirects_to_https {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: http_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(format!("Final URL: {final_url}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("HTTP does not redirect to HTTPS for {host}"),
format!(
"HTTP requests to {host} are not redirected to HTTPS. \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("HTTP does not redirect to HTTPS for {host}"),
format!(
"HTTP requests to {host} are not redirected to HTTPS. \
Users accessing the site via HTTP will have their traffic \
transmitted in cleartext."
),
Severity::Medium,
http_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Configure the web server to redirect all HTTP requests to HTTPS \
),
Severity::Medium,
http_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Configure the web server to redirect all HTTP requests to HTTPS \
using a 301 redirect."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
tls_data["http_check_error"] = json!("Could not connect via HTTP");
}
}
Err(_) => {
tls_data["http_check_error"] = json!("Could not connect via HTTP");
}
}
// Perform TLS analysis
match Self::check_tls(&host, port).await {
Ok(tls_info) => {
tls_data["host"] = json!(host);
tls_data["port"] = json!(port);
tls_data["cert_subject"] = json!(tls_info.cert_subject);
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
tls_data["san_names"] = json!(tls_info.san_names);
// Perform TLS analysis
match Self::check_tls(&host, port).await {
Ok(tls_info) => {
tls_data["host"] = json!(host);
tls_data["port"] = json!(port);
tls_data["cert_subject"] = json!(tls_info.cert_subject);
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
tls_data["san_names"] = json!(tls_info.san_names);
if tls_info.cert_expired {
if tls_info.cert_expired {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"Certificate expired. Not After: {}",
tls_info.cert_not_after
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Expired TLS certificate for {host}"),
format!(
"The TLS certificate for {host} has expired. \
Browsers will show security warnings to users."
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Renew the TLS certificate. Consider using automated certificate \
management with Let's Encrypt or a similar CA."
.to_string(),
);
findings.push(finding);
warn!(host, "Expired TLS certificate");
}
if tls_info.cert_self_signed {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("Self-signed certificate detected".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Self-signed TLS certificate for {host}"),
format!(
"The TLS certificate for {host} is self-signed and not issued by a \
trusted certificate authority. Browsers will show security warnings."
),
Severity::Medium,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Replace the self-signed certificate with one issued by a trusted \
certificate authority."
.to_string(),
);
findings.push(finding);
warn!(host, "Self-signed certificate");
}
}
Err(e) => {
tls_data["tls_error"] = json!(e.to_string());
// TLS handshake failure itself is a finding
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
@@ -264,10 +347,7 @@ impl PentestTool for TlsAnalyzerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"Certificate expired. Not After: {}",
tls_info.cert_not_after
)),
response_snippet: Some(format!("TLS error: {e}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -277,10 +357,9 @@ impl PentestTool for TlsAnalyzerTool {
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Expired TLS certificate for {host}"),
format!("TLS handshake failure for {host}"),
format!(
"The TLS certificate for {host} has expired. \
Browsers will show security warnings to users."
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
),
Severity::High,
format!("https://{host}:{port}"),
@@ -289,115 +368,37 @@ impl PentestTool for TlsAnalyzerTool {
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Renew the TLS certificate. Consider using automated certificate \
management with Let's Encrypt or a similar CA."
.to_string(),
);
findings.push(finding);
warn!(host, "Expired TLS certificate");
}
if tls_info.cert_self_signed {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("Self-signed certificate detected".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Self-signed TLS certificate for {host}"),
format!(
"The TLS certificate for {host} is self-signed and not issued by a \
trusted certificate authority. Browsers will show security warnings."
),
Severity::Medium,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Replace the self-signed certificate with one issued by a trusted \
certificate authority."
.to_string(),
);
findings.push(finding);
warn!(host, "Self-signed certificate");
}
}
Err(e) => {
tls_data["tls_error"] = json!(e.to_string());
// TLS handshake failure itself is a finding
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!("TLS error: {e}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("TLS handshake failure for {host}"),
format!(
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Ensure TLS is properly configured on the server. Check that the \
"Ensure TLS is properly configured on the server. Check that the \
certificate is valid and the server supports modern TLS versions."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
// Check strict transport security via an HTTPS request
match self.http.get(&https_url).send().await {
Ok(resp) => {
let hsts = resp.headers().get("strict-transport-security");
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
// Check strict transport security via an HTTPS request
match self.http.get(&https_url).send().await {
Ok(resp) => {
let hsts = resp.headers().get("strict-transport-security");
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
if hsts.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: https_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(
"Strict-Transport-Security header not present".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if hsts.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: https_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(
"Strict-Transport-Security header not present".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
@@ -410,33 +411,33 @@ impl PentestTool for TlsAnalyzerTool {
https_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the Strict-Transport-Security header with an appropriate max-age. \
Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
Err(_) => {
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
}
}
Err(_) => {
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
}
}
let count = findings.len();
info!(host = %host, findings = count, "TLS analysis complete");
let count = findings.len();
info!(host = %host, findings = count, "TLS analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} TLS configuration issues for {host}.")
} else {
format!("TLS configuration looks good for {host}.")
},
findings,
data: tls_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} TLS configuration issues for {host}.")
} else {
format!("TLS configuration looks good for {host}.")
},
findings,
data: tls_data,
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::xss::XssAgent;
/// PentestTool wrapper around the existing XssAgent.
pub struct XssTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: XssAgent,
}
impl XssTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = XssAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl XssTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -47,6 +77,91 @@ impl XssTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_endpoints_basic() {
let input = json!({
"endpoints": [
{
"url": "https://example.com/search",
"method": "GET",
"parameters": [
{ "name": "q", "location": "query" }
]
}
]
});
let endpoints = XssTool::parse_endpoints(&input);
assert_eq!(endpoints.len(), 1);
assert_eq!(endpoints[0].url, "https://example.com/search");
assert_eq!(endpoints[0].method, "GET");
assert_eq!(endpoints[0].parameters.len(), 1);
assert_eq!(endpoints[0].parameters[0].name, "q");
assert_eq!(endpoints[0].parameters[0].location, "query");
}
#[test]
fn parse_endpoints_empty() {
let input = json!({ "endpoints": [] });
assert!(XssTool::parse_endpoints(&input).is_empty());
}
#[test]
fn parse_endpoints_missing_key() {
let input = json!({});
assert!(XssTool::parse_endpoints(&input).is_empty());
}
#[test]
fn parse_endpoints_defaults() {
let input = json!({
"endpoints": [
{ "url": "https://example.com/api", "parameters": [] }
]
});
let endpoints = XssTool::parse_endpoints(&input);
assert_eq!(endpoints[0].method, "GET"); // default
assert!(!endpoints[0].requires_auth); // default false
}
#[test]
fn parse_endpoints_full_params() {
let input = json!({
"endpoints": [{
"url": "https://example.com",
"method": "POST",
"content_type": "application/json",
"requires_auth": true,
"parameters": [{
"name": "body",
"location": "body",
"param_type": "string",
"example_value": "test"
}]
}]
});
let endpoints = XssTool::parse_endpoints(&input);
assert_eq!(endpoints[0].method, "POST");
assert_eq!(
endpoints[0].content_type.as_deref(),
Some("application/json")
);
assert!(endpoints[0].requires_auth);
assert_eq!(
endpoints[0].parameters[0].param_type.as_deref(),
Some("string")
);
assert_eq!(
endpoints[0].parameters[0].example_value.as_deref(),
Some("test")
);
}
}
impl PentestTool for XssTool {
fn name(&self) -> &str {
"xss_scanner"
@@ -100,35 +215,37 @@ impl PentestTool for XssTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} XSS vulnerabilities.")
} else {
"No XSS vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} XSS vulnerabilities.")
} else {
"No XSS vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}