use compliance_core::error::CoreError; use compliance_core::traits::dast_agent::{ DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter, }; use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult}; use serde_json::json; use crate::agents::injection::SqlInjectionAgent; /// PentestTool wrapper around the existing SqlInjectionAgent. pub struct SqlInjectionTool { _http: reqwest::Client, agent: SqlInjectionAgent, } impl SqlInjectionTool { pub fn new(http: reqwest::Client) -> Self { let agent = SqlInjectionAgent::new(http.clone()); Self { _http: http, agent } } fn parse_endpoints(input: &serde_json::Value) -> Vec { let mut endpoints = Vec::new(); if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) { for ep in arr { let url = ep .get("url") .and_then(|v| v.as_str()) .unwrap_or_default() .to_string(); let method = ep .get("method") .and_then(|v| v.as_str()) .unwrap_or("GET") .to_string(); let mut parameters = Vec::new(); if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) { for p in params { let name = p .get("name") .and_then(|v| v.as_str()) .unwrap_or_default() .to_string(); let location = p .get("location") .and_then(|v| v.as_str()) .unwrap_or("query") .to_string(); let param_type = p .get("param_type") .and_then(|v| v.as_str()) .map(String::from); let example_value = p .get("example_value") .and_then(|v| v.as_str()) .map(String::from); parameters.push(EndpointParameter { name, location, param_type, example_value, }); } } endpoints.push(DiscoveredEndpoint { url, method, parameters, content_type: ep .get("content_type") .and_then(|v| v.as_str()) .map(String::from), requires_auth: ep .get("requires_auth") .and_then(|v| v.as_bool()) .unwrap_or(false), }); } } endpoints } } #[cfg(test)] mod tests { use super::*; use serde_json::json; #[test] fn parse_endpoints_basic() { let input = json!({ "endpoints": [{ "url": "https://example.com/api/users", "method": "POST", "parameters": [ { "name": "id", "location": "body", "param_type": "integer" } ] }] }); let endpoints = SqlInjectionTool::parse_endpoints(&input); assert_eq!(endpoints.len(), 1); assert_eq!(endpoints[0].url, "https://example.com/api/users"); assert_eq!(endpoints[0].method, "POST"); assert_eq!(endpoints[0].parameters[0].name, "id"); assert_eq!(endpoints[0].parameters[0].location, "body"); assert_eq!( endpoints[0].parameters[0].param_type.as_deref(), Some("integer") ); } #[test] fn parse_endpoints_empty_input() { assert!(SqlInjectionTool::parse_endpoints(&json!({})).is_empty()); assert!(SqlInjectionTool::parse_endpoints(&json!({ "endpoints": [] })).is_empty()); } #[test] fn parse_endpoints_multiple() { let input = json!({ "endpoints": [ { "url": "https://a.com/1", "method": "GET", "parameters": [] }, { "url": "https://b.com/2", "method": "DELETE", "parameters": [] } ] }); let endpoints = SqlInjectionTool::parse_endpoints(&input); assert_eq!(endpoints.len(), 2); assert_eq!(endpoints[0].url, "https://a.com/1"); assert_eq!(endpoints[1].method, "DELETE"); } #[test] fn parse_endpoints_default_method() { let input = json!({ "endpoints": [{ "url": "https://x.com", "parameters": [] }] }); let endpoints = SqlInjectionTool::parse_endpoints(&input); assert_eq!(endpoints[0].method, "GET"); } } impl PentestTool for SqlInjectionTool { fn name(&self) -> &str { "sql_injection_scanner" } fn description(&self) -> &str { "Tests endpoints for SQL injection vulnerabilities using error-based, boolean-based, \ time-based, and union-based techniques. Provide endpoints with their parameters to test." } fn input_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "endpoints": { "type": "array", "description": "Endpoints to test for SQL injection", "items": { "type": "object", "properties": { "url": { "type": "string", "description": "Full URL of the endpoint" }, "method": { "type": "string", "enum": ["GET", "POST", "PUT", "PATCH", "DELETE"] }, "parameters": { "type": "array", "items": { "type": "object", "properties": { "name": { "type": "string" }, "location": { "type": "string", "enum": ["query", "body", "header", "path", "cookie"] }, "param_type": { "type": "string" }, "example_value": { "type": "string" } }, "required": ["name"] } } }, "required": ["url", "method", "parameters"] } }, "custom_payloads": { "type": "array", "description": "Optional additional SQL injection payloads to test", "items": { "type": "string" } } }, "required": ["endpoints"] }) } fn execute<'a>( &'a self, input: serde_json::Value, context: &'a PentestToolContext, ) -> std::pin::Pin< Box> + Send + 'a>, > { Box::pin(async move { let endpoints = Self::parse_endpoints(&input); if endpoints.is_empty() { return Ok(PentestToolResult { summary: "No endpoints provided to test.".to_string(), findings: Vec::new(), data: json!({}), }); } let dast_context = DastContext { endpoints, technologies: Vec::new(), sast_hints: Vec::new(), }; let findings = self.agent.run(&context.target, &dast_context).await?; let count = findings.len(); Ok(PentestToolResult { summary: if count > 0 { format!("Found {count} SQL injection vulnerabilities.") } else { "No SQL injection vulnerabilities detected.".to_string() }, findings, data: json!({ "endpoints_tested": dast_context.endpoints.len() }), }) }) } }