use std::collections::HashMap; use compliance_core::error::CoreError; use compliance_core::models::dast::{DastEvidence, DastFinding, DastVulnType}; use compliance_core::models::Severity; use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; use serde_json::json; use tracing::{info, warn}; /// Tool that checks CORS configuration for security issues. pub struct CorsCheckerTool { http: reqwest::Client, } impl CorsCheckerTool { pub fn new(http: reqwest::Client) -> Self { Self { http } } /// Origins to test against the target. fn test_origins(target_host: &str) -> Vec<(&'static str, String)> { vec![ ("null_origin", "null".to_string()), ("evil_domain", "https://evil.com".to_string()), ( "subdomain_spoof", format!("https://{target_host}.evil.com"), ), ( "prefix_spoof", format!("https://evil-{target_host}"), ), ("http_downgrade", format!("http://{target_host}")), ] } } impl PentestTool for CorsCheckerTool { fn name(&self) -> &str { "cors_checker" } fn description(&self) -> &str { "Checks CORS configuration by sending requests with various Origin headers. Tests for \ wildcard origins, reflected origins, null origin acceptance, and dangerous \ Access-Control-Allow-Credentials combinations." } fn input_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "url": { "type": "string", "description": "URL to test CORS configuration on" }, "additional_origins": { "type": "array", "description": "Optional additional origin values to test", "items": { "type": "string" } } }, "required": ["url"] }) } fn execute<'a>( &'a self, input: serde_json::Value, context: &'a PentestToolContext, ) -> std::pin::Pin> + Send + 'a>> { Box::pin(async move { let url = input .get("url") .and_then(|v| v.as_str()) .ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?; let additional_origins: Vec = input .get("additional_origins") .and_then(|v| v.as_array()) .map(|arr| { arr.iter() .filter_map(|v| v.as_str().map(String::from)) .collect() }) .unwrap_or_default(); let target_id = context .target .id .map(|oid| oid.to_hex()) .unwrap_or_else(|| "unknown".to_string()); let target_host = url::Url::parse(url) .ok() .and_then(|u| u.host_str().map(String::from)) .unwrap_or_else(|| url.to_string()); let mut findings = Vec::new(); let mut cors_data: Vec = Vec::new(); // First, send a request without Origin to get baseline let baseline = self .http .get(url) .send() .await .map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?; let baseline_acao = baseline .headers() .get("access-control-allow-origin") .and_then(|v| v.to_str().ok()) .map(String::from); cors_data.push(json!({ "origin": null, "acao": baseline_acao, })); // Check for wildcard + credentials (dangerous combo) if let Some(ref acao) = baseline_acao { if acao == "*" { let acac = baseline .headers() .get("access-control-allow-credentials") .and_then(|v| v.to_str().ok()) .unwrap_or(""); if acac.to_lowercase() == "true" { let evidence = DastEvidence { request_method: "GET".to_string(), request_url: url.to_string(), request_headers: None, request_body: None, response_status: baseline.status().as_u16(), response_headers: None, response_snippet: Some(format!( "Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true" )), screenshot_path: None, payload: None, response_time_ms: None, }; let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CorsMisconfiguration, "CORS wildcard with credentials".to_string(), format!( "The endpoint {url} returns Access-Control-Allow-Origin: * with \ Access-Control-Allow-Credentials: true. While browsers should block this \ combination, it indicates a serious CORS misconfiguration." ), Severity::High, url.to_string(), "GET".to_string(), ); finding.cwe = Some("CWE-942".to_string()); finding.evidence = vec![evidence]; finding.remediation = Some( "Never combine Access-Control-Allow-Origin: * with \ Access-Control-Allow-Credentials: true. Specify explicit allowed origins." .to_string(), ); findings.push(finding); } } } // Test with various Origin headers let mut test_origins = Self::test_origins(&target_host); for origin in &additional_origins { test_origins.push(("custom", origin.clone())); } for (test_name, origin) in &test_origins { let resp = match self .http .get(url) .header("Origin", origin.as_str()) .send() .await { Ok(r) => r, Err(_) => continue, }; let status = resp.status().as_u16(); let acao = resp .headers() .get("access-control-allow-origin") .and_then(|v| v.to_str().ok()) .map(String::from); let acac = resp .headers() .get("access-control-allow-credentials") .and_then(|v| v.to_str().ok()) .map(String::from); let acam = resp .headers() .get("access-control-allow-methods") .and_then(|v| v.to_str().ok()) .map(String::from); cors_data.push(json!({ "test": test_name, "origin": origin, "acao": acao, "acac": acac, "acam": acam, "status": status, })); // Check if the origin was reflected back if let Some(ref acao_val) = acao { let origin_reflected = acao_val == origin; let credentials_allowed = acac .as_ref() .map(|v| v.to_lowercase() == "true") .unwrap_or(false); if origin_reflected && *test_name != "http_downgrade" { let severity = if credentials_allowed { Severity::Critical } else { Severity::High }; let resp_headers: HashMap = resp .headers() .iter() .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) .collect(); let evidence = DastEvidence { request_method: "GET".to_string(), request_url: url.to_string(), request_headers: Some( [("Origin".to_string(), origin.clone())] .into_iter() .collect(), ), request_body: None, response_status: status, response_headers: Some(resp_headers), response_snippet: Some(format!( "Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\ Access-Control-Allow-Credentials: {}", acac.as_deref().unwrap_or("not set") )), screenshot_path: None, payload: Some(origin.clone()), response_time_ms: None, }; let title = match *test_name { "null_origin" => { "CORS accepts null origin".to_string() } "evil_domain" => { "CORS reflects arbitrary origin".to_string() } "subdomain_spoof" => { "CORS vulnerable to subdomain spoofing".to_string() } "prefix_spoof" => { "CORS vulnerable to prefix spoofing".to_string() } _ => format!("CORS reflects untrusted origin ({test_name})"), }; let cred_note = if credentials_allowed { " Combined with Access-Control-Allow-Credentials: true, this allows \ the attacker to steal authenticated data." } else { "" }; let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CorsMisconfiguration, title, format!( "The endpoint {url} reflects the Origin header '{origin}' back in \ Access-Control-Allow-Origin, allowing cross-origin requests from \ untrusted domains.{cred_note}" ), severity, url.to_string(), "GET".to_string(), ); finding.cwe = Some("CWE-942".to_string()); finding.exploitable = credentials_allowed; finding.evidence = vec![evidence]; finding.remediation = Some( "Validate the Origin header against a whitelist of trusted origins. \ Do not reflect the Origin header value directly. Use specific allowed \ origins instead of wildcards." .to_string(), ); findings.push(finding); warn!( url, test_name, origin, credentials = credentials_allowed, "CORS misconfiguration detected" ); } // Special case: HTTP downgrade if *test_name == "http_downgrade" && origin_reflected && credentials_allowed { let evidence = DastEvidence { request_method: "GET".to_string(), request_url: url.to_string(), request_headers: Some( [("Origin".to_string(), origin.clone())] .into_iter() .collect(), ), request_body: None, response_status: status, response_headers: None, response_snippet: Some(format!( "HTTP origin accepted: {origin} -> ACAO: {acao_val}" )), screenshot_path: None, payload: Some(origin.clone()), response_time_ms: None, }; let mut finding = DastFinding::new( String::new(), target_id.clone(), DastVulnType::CorsMisconfiguration, "CORS allows HTTP origin with credentials".to_string(), format!( "The HTTPS endpoint {url} accepts the HTTP origin {origin} with \ credentials. An attacker performing a man-in-the-middle attack on \ the HTTP version could steal authenticated data." ), Severity::High, url.to_string(), "GET".to_string(), ); finding.cwe = Some("CWE-942".to_string()); finding.evidence = vec![evidence]; finding.remediation = Some( "Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \ origin validation enforces the https:// scheme." .to_string(), ); findings.push(finding); } } } // Also send a preflight OPTIONS request if let Ok(resp) = self .http .request(reqwest::Method::OPTIONS, url) .header("Origin", "https://evil.com") .header("Access-Control-Request-Method", "POST") .header("Access-Control-Request-Headers", "Authorization, Content-Type") .send() .await { let acam = resp .headers() .get("access-control-allow-methods") .and_then(|v| v.to_str().ok()) .map(String::from); let acah = resp .headers() .get("access-control-allow-headers") .and_then(|v| v.to_str().ok()) .map(String::from); cors_data.push(json!({ "test": "preflight", "status": resp.status().as_u16(), "allow_methods": acam, "allow_headers": acah, })); } let count = findings.len(); info!(url, findings = count, "CORS check complete"); Ok(PentestToolResult { summary: if count > 0 { format!("Found {count} CORS misconfiguration issues for {url}.") } else { format!("CORS configuration appears secure for {url}.") }, findings, data: json!({ "tests": cors_data, }), }) }) } }