refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Tests (push) Successful in 5m15s
CI / Detect Changes (push) Successful in 5s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Tests (push) Successful in 5m15s
CI / Detect Changes (push) Successful in 5s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s
This commit was merged in pull request #13.
This commit is contained in:
481
compliance-agent/src/api/handlers/dto.rs
Normal file
481
compliance-agent/src/api/handlers/dto.rs
Normal file
@@ -0,0 +1,481 @@
|
||||
use compliance_core::models::TrackerType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use compliance_core::models::ScanRun;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct PaginationParams {
|
||||
#[serde(default = "default_page")]
|
||||
pub page: u64,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: i64,
|
||||
}
|
||||
|
||||
pub(crate) fn default_page() -> u64 {
|
||||
1
|
||||
}
|
||||
pub(crate) fn default_limit() -> i64 {
|
||||
50
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct FindingsFilter {
|
||||
#[serde(default)]
|
||||
pub repo_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub severity: Option<String>,
|
||||
#[serde(default)]
|
||||
pub scan_type: Option<String>,
|
||||
#[serde(default)]
|
||||
pub status: Option<String>,
|
||||
#[serde(default)]
|
||||
pub q: Option<String>,
|
||||
#[serde(default)]
|
||||
pub sort_by: Option<String>,
|
||||
#[serde(default)]
|
||||
pub sort_order: Option<String>,
|
||||
#[serde(default = "default_page")]
|
||||
pub page: u64,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: i64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
pub data: T,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub page: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct OverviewStats {
|
||||
pub total_repositories: u64,
|
||||
pub total_findings: u64,
|
||||
pub critical_findings: u64,
|
||||
pub high_findings: u64,
|
||||
pub medium_findings: u64,
|
||||
pub low_findings: u64,
|
||||
pub total_sbom_entries: u64,
|
||||
pub total_cve_alerts: u64,
|
||||
pub total_issues: u64,
|
||||
pub recent_scans: Vec<ScanRun>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AddRepositoryRequest {
|
||||
pub name: String,
|
||||
pub git_url: String,
|
||||
#[serde(default = "default_branch")]
|
||||
pub default_branch: String,
|
||||
pub auth_token: Option<String>,
|
||||
pub auth_username: Option<String>,
|
||||
pub tracker_type: Option<TrackerType>,
|
||||
pub tracker_owner: Option<String>,
|
||||
pub tracker_repo: Option<String>,
|
||||
pub tracker_token: Option<String>,
|
||||
pub scan_schedule: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateRepositoryRequest {
|
||||
pub name: Option<String>,
|
||||
pub default_branch: Option<String>,
|
||||
pub auth_token: Option<String>,
|
||||
pub auth_username: Option<String>,
|
||||
pub tracker_type: Option<TrackerType>,
|
||||
pub tracker_owner: Option<String>,
|
||||
pub tracker_repo: Option<String>,
|
||||
pub tracker_token: Option<String>,
|
||||
pub scan_schedule: Option<String>,
|
||||
}
|
||||
|
||||
fn default_branch() -> String {
|
||||
"main".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateStatusRequest {
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct BulkUpdateStatusRequest {
|
||||
pub ids: Vec<String>,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UpdateFeedbackRequest {
|
||||
pub feedback: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomFilter {
|
||||
#[serde(default)]
|
||||
pub repo_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub package_manager: Option<String>,
|
||||
#[serde(default)]
|
||||
pub q: Option<String>,
|
||||
#[serde(default)]
|
||||
pub has_vulns: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub license: Option<String>,
|
||||
#[serde(default = "default_page")]
|
||||
pub page: u64,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: i64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomExportParams {
|
||||
pub repo_id: String,
|
||||
#[serde(default = "default_export_format")]
|
||||
pub format: String,
|
||||
}
|
||||
|
||||
fn default_export_format() -> String {
|
||||
"cyclonedx".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct SbomDiffParams {
|
||||
pub repo_a: String,
|
||||
pub repo_b: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LicenseSummary {
|
||||
pub license: String,
|
||||
pub count: u64,
|
||||
pub is_copyleft: bool,
|
||||
pub packages: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomDiffResult {
|
||||
pub only_in_a: Vec<SbomDiffEntry>,
|
||||
pub only_in_b: Vec<SbomDiffEntry>,
|
||||
pub version_changed: Vec<SbomVersionDiff>,
|
||||
pub common_count: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomDiffEntry {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub package_manager: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct SbomVersionDiff {
|
||||
pub name: String,
|
||||
pub package_manager: String,
|
||||
pub version_a: String,
|
||||
pub version_b: String,
|
||||
}
|
||||
|
||||
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
|
||||
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
|
||||
|
||||
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
|
||||
mut cursor: mongodb::Cursor<T>,
|
||||
) -> Vec<T> {
|
||||
use futures_util::StreamExt;
|
||||
let mut items = Vec::new();
|
||||
while let Some(result) = cursor.next().await {
|
||||
match result {
|
||||
Ok(item) => items.push(item),
|
||||
Err(e) => tracing::warn!("Failed to deserialize document: {e}"),
|
||||
}
|
||||
}
|
||||
items
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// ── PaginationParams ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn pagination_params_defaults() {
|
||||
let p: PaginationParams = serde_json::from_str("{}").unwrap();
|
||||
assert_eq!(p.page, 1);
|
||||
assert_eq!(p.limit, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pagination_params_custom_values() {
|
||||
let p: PaginationParams = serde_json::from_str(r#"{"page":3,"limit":10}"#).unwrap();
|
||||
assert_eq!(p.page, 3);
|
||||
assert_eq!(p.limit, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pagination_params_partial_override() {
|
||||
let p: PaginationParams = serde_json::from_str(r#"{"page":5}"#).unwrap();
|
||||
assert_eq!(p.page, 5);
|
||||
assert_eq!(p.limit, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pagination_params_zero_page() {
|
||||
let p: PaginationParams = serde_json::from_str(r#"{"page":0}"#).unwrap();
|
||||
assert_eq!(p.page, 0);
|
||||
}
|
||||
|
||||
// ── FindingsFilter ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn findings_filter_all_defaults() {
|
||||
let f: FindingsFilter = serde_json::from_str("{}").unwrap();
|
||||
assert!(f.repo_id.is_none());
|
||||
assert!(f.severity.is_none());
|
||||
assert!(f.scan_type.is_none());
|
||||
assert!(f.status.is_none());
|
||||
assert!(f.q.is_none());
|
||||
assert!(f.sort_by.is_none());
|
||||
assert!(f.sort_order.is_none());
|
||||
assert_eq!(f.page, 1);
|
||||
assert_eq!(f.limit, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn findings_filter_with_all_fields() {
|
||||
let f: FindingsFilter = serde_json::from_str(
|
||||
r#"{
|
||||
"repo_id": "abc",
|
||||
"severity": "high",
|
||||
"scan_type": "sast",
|
||||
"status": "open",
|
||||
"q": "sql injection",
|
||||
"sort_by": "severity",
|
||||
"sort_order": "desc",
|
||||
"page": 2,
|
||||
"limit": 25
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(f.repo_id.as_deref(), Some("abc"));
|
||||
assert_eq!(f.severity.as_deref(), Some("high"));
|
||||
assert_eq!(f.scan_type.as_deref(), Some("sast"));
|
||||
assert_eq!(f.status.as_deref(), Some("open"));
|
||||
assert_eq!(f.q.as_deref(), Some("sql injection"));
|
||||
assert_eq!(f.sort_by.as_deref(), Some("severity"));
|
||||
assert_eq!(f.sort_order.as_deref(), Some("desc"));
|
||||
assert_eq!(f.page, 2);
|
||||
assert_eq!(f.limit, 25);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn findings_filter_empty_string_fields() {
|
||||
let f: FindingsFilter = serde_json::from_str(r#"{"repo_id":"","severity":""}"#).unwrap();
|
||||
assert_eq!(f.repo_id.as_deref(), Some(""));
|
||||
assert_eq!(f.severity.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
// ── ApiResponse ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn api_response_serializes_with_all_fields() {
|
||||
let resp = ApiResponse {
|
||||
data: vec!["a", "b"],
|
||||
total: Some(100),
|
||||
page: Some(1),
|
||||
};
|
||||
let v = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(v["data"], json!(["a", "b"]));
|
||||
assert_eq!(v["total"], 100);
|
||||
assert_eq!(v["page"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_response_skips_none_fields() {
|
||||
let resp = ApiResponse {
|
||||
data: "hello",
|
||||
total: None,
|
||||
page: None,
|
||||
};
|
||||
let v = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(v["data"], "hello");
|
||||
assert!(v.get("total").is_none());
|
||||
assert!(v.get("page").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_response_with_nested_struct() {
|
||||
#[derive(Serialize)]
|
||||
struct Item {
|
||||
id: u32,
|
||||
}
|
||||
let resp = ApiResponse {
|
||||
data: Item { id: 42 },
|
||||
total: Some(1),
|
||||
page: None,
|
||||
};
|
||||
let v = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(v["data"]["id"], 42);
|
||||
assert_eq!(v["total"], 1);
|
||||
assert!(v.get("page").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn api_response_empty_vec() {
|
||||
let resp: ApiResponse<Vec<String>> = ApiResponse {
|
||||
data: vec![],
|
||||
total: Some(0),
|
||||
page: Some(1),
|
||||
};
|
||||
let v = serde_json::to_value(&resp).unwrap();
|
||||
assert!(v["data"].as_array().unwrap().is_empty());
|
||||
}
|
||||
|
||||
// ── SbomFilter ───────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sbom_filter_defaults() {
|
||||
let f: SbomFilter = serde_json::from_str("{}").unwrap();
|
||||
assert!(f.repo_id.is_none());
|
||||
assert!(f.package_manager.is_none());
|
||||
assert!(f.q.is_none());
|
||||
assert!(f.has_vulns.is_none());
|
||||
assert!(f.license.is_none());
|
||||
assert_eq!(f.page, 1);
|
||||
assert_eq!(f.limit, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_filter_has_vulns_bool() {
|
||||
let f: SbomFilter = serde_json::from_str(r#"{"has_vulns": true}"#).unwrap();
|
||||
assert_eq!(f.has_vulns, Some(true));
|
||||
}
|
||||
|
||||
// ── SbomExportParams ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sbom_export_params_default_format() {
|
||||
let p: SbomExportParams = serde_json::from_str(r#"{"repo_id":"r1"}"#).unwrap();
|
||||
assert_eq!(p.repo_id, "r1");
|
||||
assert_eq!(p.format, "cyclonedx");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_export_params_custom_format() {
|
||||
let p: SbomExportParams =
|
||||
serde_json::from_str(r#"{"repo_id":"r1","format":"spdx"}"#).unwrap();
|
||||
assert_eq!(p.format, "spdx");
|
||||
}
|
||||
|
||||
// ── AddRepositoryRequest ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn add_repository_request_defaults() {
|
||||
let r: AddRepositoryRequest = serde_json::from_str(
|
||||
r#"{
|
||||
"name": "my-repo",
|
||||
"git_url": "https://github.com/x/y.git"
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(r.name, "my-repo");
|
||||
assert_eq!(r.git_url, "https://github.com/x/y.git");
|
||||
assert_eq!(r.default_branch, "main");
|
||||
assert!(r.auth_token.is_none());
|
||||
assert!(r.tracker_type.is_none());
|
||||
assert!(r.scan_schedule.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_repository_request_custom_branch() {
|
||||
let r: AddRepositoryRequest = serde_json::from_str(
|
||||
r#"{
|
||||
"name": "repo",
|
||||
"git_url": "url",
|
||||
"default_branch": "develop"
|
||||
}"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(r.default_branch, "develop");
|
||||
}
|
||||
|
||||
// ── UpdateStatusRequest / BulkUpdateStatusRequest ────────────
|
||||
|
||||
#[test]
|
||||
fn update_status_request() {
|
||||
let r: UpdateStatusRequest = serde_json::from_str(r#"{"status":"resolved"}"#).unwrap();
|
||||
assert_eq!(r.status, "resolved");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bulk_update_status_request() {
|
||||
let r: BulkUpdateStatusRequest =
|
||||
serde_json::from_str(r#"{"ids":["a","b"],"status":"dismissed"}"#).unwrap();
|
||||
assert_eq!(r.ids, vec!["a", "b"]);
|
||||
assert_eq!(r.status, "dismissed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bulk_update_status_empty_ids() {
|
||||
let r: BulkUpdateStatusRequest =
|
||||
serde_json::from_str(r#"{"ids":[],"status":"x"}"#).unwrap();
|
||||
assert!(r.ids.is_empty());
|
||||
}
|
||||
|
||||
// ── SbomDiffResult serialization ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sbom_diff_result_serializes() {
|
||||
let r = SbomDiffResult {
|
||||
only_in_a: vec![SbomDiffEntry {
|
||||
name: "pkg-a".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
package_manager: "npm".to_string(),
|
||||
}],
|
||||
only_in_b: vec![],
|
||||
version_changed: vec![SbomVersionDiff {
|
||||
name: "shared".to_string(),
|
||||
package_manager: "cargo".to_string(),
|
||||
version_a: "0.1".to_string(),
|
||||
version_b: "0.2".to_string(),
|
||||
}],
|
||||
common_count: 10,
|
||||
};
|
||||
let v = serde_json::to_value(&r).unwrap();
|
||||
assert_eq!(v["only_in_a"].as_array().unwrap().len(), 1);
|
||||
assert_eq!(v["only_in_b"].as_array().unwrap().len(), 0);
|
||||
assert_eq!(v["version_changed"][0]["version_a"], "0.1");
|
||||
assert_eq!(v["common_count"], 10);
|
||||
}
|
||||
|
||||
// ── LicenseSummary ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn license_summary_serializes() {
|
||||
let ls = LicenseSummary {
|
||||
license: "MIT".to_string(),
|
||||
count: 42,
|
||||
is_copyleft: false,
|
||||
packages: vec!["serde".to_string()],
|
||||
};
|
||||
let v = serde_json::to_value(&ls).unwrap();
|
||||
assert_eq!(v["license"], "MIT");
|
||||
assert_eq!(v["is_copyleft"], false);
|
||||
assert_eq!(v["count"], 42);
|
||||
}
|
||||
|
||||
// ── Default helper functions ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn default_page_returns_1() {
|
||||
assert_eq!(default_page(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_limit_returns_50() {
|
||||
assert_eq!(default_limit(), 50);
|
||||
}
|
||||
}
|
||||
172
compliance-agent/src/api/handlers/findings.rs
Normal file
172
compliance-agent/src/api/handlers/findings.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use axum::extract::{Extension, Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::Finding;
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
|
||||
pub async fn list_findings(
|
||||
Extension(agent): AgentExt,
|
||||
Query(filter): Query<FindingsFilter>,
|
||||
) -> ApiResult<Vec<Finding>> {
|
||||
let db = &agent.db;
|
||||
let mut query = doc! {};
|
||||
if let Some(repo_id) = &filter.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
}
|
||||
if let Some(severity) = &filter.severity {
|
||||
query.insert("severity", severity);
|
||||
}
|
||||
if let Some(scan_type) = &filter.scan_type {
|
||||
query.insert("scan_type", scan_type);
|
||||
}
|
||||
if let Some(status) = &filter.status {
|
||||
query.insert("status", status);
|
||||
}
|
||||
// Text search across title, description, file_path, rule_id
|
||||
if let Some(q) = &filter.q {
|
||||
if !q.is_empty() {
|
||||
let regex = doc! { "$regex": q, "$options": "i" };
|
||||
query.insert(
|
||||
"$or",
|
||||
mongodb::bson::bson!([
|
||||
{ "title": regex.clone() },
|
||||
{ "description": regex.clone() },
|
||||
{ "file_path": regex.clone() },
|
||||
{ "rule_id": regex },
|
||||
]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Dynamic sort
|
||||
let sort_field = filter.sort_by.as_deref().unwrap_or("created_at");
|
||||
let sort_dir: i32 = match filter.sort_order.as_deref() {
|
||||
Some("asc") => 1,
|
||||
_ => -1,
|
||||
};
|
||||
let sort_doc = doc! { sort_field: sort_dir };
|
||||
|
||||
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
|
||||
let total = db
|
||||
.findings()
|
||||
.count_documents(query.clone())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let findings = match db
|
||||
.findings()
|
||||
.find(query)
|
||||
.sort(sort_doc)
|
||||
.skip(skip)
|
||||
.limit(filter.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch findings: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: findings,
|
||||
total: Some(total),
|
||||
page: Some(filter.page),
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(finding_id = %id))]
|
||||
pub async fn get_finding(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let finding = agent
|
||||
.db
|
||||
.findings()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: finding,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(finding_id = %id))]
|
||||
pub async fn update_finding_status(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateStatusRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
agent
|
||||
.db
|
||||
.findings()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "updated" })))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn bulk_update_finding_status(
|
||||
Extension(agent): AgentExt,
|
||||
Json(req): Json<BulkUpdateStatusRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oids: Vec<mongodb::bson::oid::ObjectId> = req
|
||||
.ids
|
||||
.iter()
|
||||
.filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok())
|
||||
.collect();
|
||||
|
||||
if oids.is_empty() {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
.findings()
|
||||
.update_many(
|
||||
doc! { "_id": { "$in": oids } },
|
||||
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(
|
||||
serde_json::json!({ "status": "updated", "modified_count": result.modified_count }),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn update_finding_feedback(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateFeedbackRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
agent
|
||||
.db
|
||||
.findings()
|
||||
.update_one(
|
||||
doc! { "_id": oid },
|
||||
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
|
||||
)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "updated" })))
|
||||
}
|
||||
84
compliance-agent/src/api/handlers/health.rs
Normal file
84
compliance-agent/src/api/handlers/health.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::ScanRun;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn health() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({ "status": "ok" }))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
|
||||
let db = &agent.db;
|
||||
|
||||
let total_repositories = db
|
||||
.repositories()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0);
|
||||
let critical_findings = db
|
||||
.findings()
|
||||
.count_documents(doc! { "severity": "critical" })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let high_findings = db
|
||||
.findings()
|
||||
.count_documents(doc! { "severity": "high" })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let medium_findings = db
|
||||
.findings()
|
||||
.count_documents(doc! { "severity": "medium" })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let low_findings = db
|
||||
.findings()
|
||||
.count_documents(doc! { "severity": "low" })
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let total_sbom_entries = db
|
||||
.sbom_entries()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0);
|
||||
let total_issues = db
|
||||
.tracker_issues()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let recent_scans: Vec<ScanRun> = match db
|
||||
.scan_runs()
|
||||
.find(doc! {})
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.limit(10)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch recent scans: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: OverviewStats {
|
||||
total_repositories,
|
||||
total_findings,
|
||||
critical_findings,
|
||||
high_findings,
|
||||
medium_findings,
|
||||
low_findings,
|
||||
total_sbom_entries,
|
||||
total_cve_alerts,
|
||||
total_issues,
|
||||
recent_scans,
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
41
compliance-agent/src/api/handlers/issues.rs
Normal file
41
compliance-agent/src/api/handlers/issues.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use axum::extract::{Extension, Query};
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::TrackerIssue;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_issues(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<TrackerIssue>> {
|
||||
let db = &agent.db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.tracker_issues()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let issues = match db
|
||||
.tracker_issues()
|
||||
.find(doc! {})
|
||||
.sort(doc! { "created_at": -1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch tracker issues: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: issues,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
131
compliance-agent/src/api/handlers/pentest_handlers/export.rs
Normal file
131
compliance-agent/src/api/handlers/pentest_handlers/export.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::pentest::*;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::collect_cursor_async;
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ExportBody {
|
||||
pub password: String,
|
||||
/// Requester display name (from auth)
|
||||
#[serde(default)]
|
||||
pub requester_name: String,
|
||||
/// Requester email (from auth)
|
||||
#[serde(default)]
|
||||
pub requester_email: String,
|
||||
}
|
||||
|
||||
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn export_session_report(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<ExportBody>,
|
||||
) -> Result<axum::response::Response, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
|
||||
if body.password.len() < 8 {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Password must be at least 8 characters".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch session
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
|
||||
|
||||
// Resolve target name
|
||||
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
|
||||
agent
|
||||
.db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": tid })
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let target_name = target
|
||||
.as_ref()
|
||||
.map(|t| t.name.clone())
|
||||
.unwrap_or_else(|| "Unknown Target".to_string());
|
||||
let target_url = target
|
||||
.as_ref()
|
||||
.map(|t| t.base_url.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Fetch attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch DAST findings for this session
|
||||
let findings: Vec<DastFinding> = match agent
|
||||
.db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "severity": -1, "created_at": -1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let ctx = crate::pentest::report::ReportContext {
|
||||
session,
|
||||
target_name,
|
||||
target_url,
|
||||
findings,
|
||||
attack_chain: nodes,
|
||||
requester_name: if body.requester_name.is_empty() {
|
||||
"Unknown".to_string()
|
||||
} else {
|
||||
body.requester_name
|
||||
},
|
||||
requester_email: body.requester_email,
|
||||
};
|
||||
|
||||
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
let response = serde_json::json!({
|
||||
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
|
||||
"sha256": report.sha256,
|
||||
"filename": format!("pentest-report-{id}.zip"),
|
||||
});
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
mod export;
|
||||
mod session;
|
||||
mod stats;
|
||||
mod stream;
|
||||
|
||||
pub use export::*;
|
||||
pub use session::*;
|
||||
pub use stats::*;
|
||||
pub use stream::*;
|
||||
@@ -2,20 +2,16 @@ use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use futures_util::stream;
|
||||
use mongodb::bson::doc;
|
||||
use serde::Deserialize;
|
||||
|
||||
use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::pentest::*;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
use crate::pentest::PentestOrchestrator;
|
||||
|
||||
use super::{collect_cursor_async, ApiResponse, PaginationParams};
|
||||
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
@@ -160,8 +156,7 @@ pub async fn get_session(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let session = agent
|
||||
.db
|
||||
@@ -210,13 +205,12 @@ pub async fn send_message(
|
||||
}
|
||||
|
||||
// Look up the target
|
||||
let target_oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Invalid target_id in session".to_string(),
|
||||
)
|
||||
})?;
|
||||
let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Invalid target_id in session".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let target = agent
|
||||
.db
|
||||
@@ -261,106 +255,6 @@ pub async fn send_message(
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
|
||||
///
|
||||
/// Returns recent messages as SSE events (polling approach).
|
||||
/// True real-time streaming with broadcast channels will be added in a future iteration.
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn session_stream(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
|
||||
{
|
||||
let oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// Verify session exists
|
||||
let _session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Fetch recent messages for this session
|
||||
let messages: Vec<PentestMessage> = match agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch recent attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Build SSE events from stored data
|
||||
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
|
||||
|
||||
for msg in &messages {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"created_at": msg.created_at.to_rfc3339(),
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("message").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
for node in &nodes {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "tool_execution",
|
||||
"node_id": node.node_id,
|
||||
"tool_name": node.tool_name,
|
||||
"status": node.status,
|
||||
"findings_produced": node.findings_produced,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("tool").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add session status event
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
if let Some(s) = session {
|
||||
let status_data = serde_json::json!({
|
||||
"type": "status",
|
||||
"status": s.status,
|
||||
"findings_count": s.findings_count,
|
||||
"tool_invocations": s.tool_invocations,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&status_data) {
|
||||
events.push(Ok(Event::default().event("status").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Sse::new(stream::iter(events)))
|
||||
}
|
||||
|
||||
/// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn stop_session(
|
||||
@@ -375,7 +269,12 @@ pub async fn stop_session(
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
|
||||
|
||||
if session.status != PentestStatus::Running {
|
||||
@@ -397,15 +296,30 @@ pub async fn stop_session(
|
||||
}},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?;
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
let updated = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found after update".to_string()))?;
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Database error: {e}"),
|
||||
)
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"Session not found after update".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: updated,
|
||||
@@ -420,9 +334,7 @@ pub async fn get_attack_chain(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
|
||||
// Verify the session ID is valid
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let nodes = match agent
|
||||
.db
|
||||
@@ -453,8 +365,7 @@ pub async fn get_messages(
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
@@ -487,95 +398,14 @@ pub async fn get_messages(
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn pentest_stats(
|
||||
Extension(agent): AgentExt,
|
||||
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
|
||||
let running_sessions = db
|
||||
.pentest_sessions()
|
||||
.count_documents(doc! { "status": "running" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Count DAST findings from pentest sessions
|
||||
let total_vulnerabilities = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Aggregate tool invocations from all sessions
|
||||
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
|
||||
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
|
||||
let tool_success_rate = if total_tool_invocations == 0 {
|
||||
100.0
|
||||
} else {
|
||||
(total_successes as f64 / total_tool_invocations as f64) * 100.0
|
||||
};
|
||||
|
||||
// Severity distribution from pentest-related DAST findings
|
||||
let critical = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let high = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let medium = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let low = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let info = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: PentestStats {
|
||||
running_sessions,
|
||||
total_vulnerabilities,
|
||||
total_tool_invocations,
|
||||
tool_success_rate,
|
||||
severity_distribution: SeverityDistribution {
|
||||
critical,
|
||||
high,
|
||||
medium,
|
||||
low,
|
||||
info,
|
||||
},
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn get_session_findings(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
|
||||
let _oid =
|
||||
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
|
||||
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = agent
|
||||
@@ -607,112 +437,3 @@ pub async fn get_session_findings(
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ExportBody {
|
||||
pub password: String,
|
||||
/// Requester display name (from auth)
|
||||
#[serde(default)]
|
||||
pub requester_name: String,
|
||||
/// Requester email (from auth)
|
||||
#[serde(default)]
|
||||
pub requester_email: String,
|
||||
}
|
||||
|
||||
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn export_session_report(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(body): Json<ExportBody>,
|
||||
) -> Result<axum::response::Response, (StatusCode, String)> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
|
||||
|
||||
if body.password.len() < 8 {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Password must be at least 8 characters".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Fetch session
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
|
||||
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
|
||||
|
||||
// Resolve target name
|
||||
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
|
||||
agent
|
||||
.db
|
||||
.dast_targets()
|
||||
.find_one(doc! { "_id": tid })
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let target_name = target
|
||||
.as_ref()
|
||||
.map(|t| t.name.clone())
|
||||
.unwrap_or_else(|| "Unknown Target".to_string());
|
||||
let target_url = target
|
||||
.as_ref()
|
||||
.map(|t| t.base_url.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Fetch attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch DAST findings for this session
|
||||
let findings: Vec<DastFinding> = match agent
|
||||
.db
|
||||
.dast_findings()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "severity": -1, "created_at": -1 })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let ctx = crate::pentest::report::ReportContext {
|
||||
session,
|
||||
target_name,
|
||||
target_url,
|
||||
findings,
|
||||
attack_chain: nodes,
|
||||
requester_name: if body.requester_name.is_empty() {
|
||||
"Unknown".to_string()
|
||||
} else {
|
||||
body.requester_name
|
||||
},
|
||||
requester_email: body.requester_email,
|
||||
};
|
||||
|
||||
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
let response = serde_json::json!({
|
||||
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
|
||||
"sha256": report.sha256,
|
||||
"filename": format!("pentest-report-{id}.zip"),
|
||||
});
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
102
compliance-agent/src/api/handlers/pentest_handlers/stats.rs
Normal file
102
compliance-agent/src/api/handlers/pentest_handlers/stats.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::Extension;
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::pentest::*;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::{collect_cursor_async, ApiResponse};
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn pentest_stats(
|
||||
Extension(agent): AgentExt,
|
||||
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
|
||||
let running_sessions = db
|
||||
.pentest_sessions()
|
||||
.count_documents(doc! { "status": "running" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Count DAST findings from pentest sessions
|
||||
let total_vulnerabilities = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Aggregate tool invocations from all sessions
|
||||
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
|
||||
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
|
||||
let tool_success_rate = if total_tool_invocations == 0 {
|
||||
100.0
|
||||
} else {
|
||||
(total_successes as f64 / total_tool_invocations as f64) * 100.0
|
||||
};
|
||||
|
||||
// Severity distribution from pentest-related DAST findings
|
||||
let critical = db
|
||||
.dast_findings()
|
||||
.count_documents(
|
||||
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" },
|
||||
)
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let high = db
|
||||
.dast_findings()
|
||||
.count_documents(
|
||||
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" },
|
||||
)
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let medium = db
|
||||
.dast_findings()
|
||||
.count_documents(
|
||||
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" },
|
||||
)
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let low = db
|
||||
.dast_findings()
|
||||
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
let info = db
|
||||
.dast_findings()
|
||||
.count_documents(
|
||||
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" },
|
||||
)
|
||||
.await
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: PentestStats {
|
||||
running_sessions,
|
||||
total_vulnerabilities,
|
||||
total_tool_invocations,
|
||||
tool_success_rate,
|
||||
severity_distribution: SeverityDistribution {
|
||||
critical,
|
||||
high,
|
||||
medium,
|
||||
low,
|
||||
info,
|
||||
},
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
116
compliance-agent/src/api/handlers/pentest_handlers/stream.rs
Normal file
116
compliance-agent/src/api/handlers/pentest_handlers/stream.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::extract::{Extension, Path};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, Sse};
|
||||
use futures_util::stream;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::pentest::*;
|
||||
|
||||
use crate::agent::ComplianceAgent;
|
||||
|
||||
use super::super::dto::collect_cursor_async;
|
||||
|
||||
type AgentExt = Extension<Arc<ComplianceAgent>>;
|
||||
|
||||
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
|
||||
///
|
||||
/// Returns recent messages as SSE events (polling approach).
|
||||
/// True real-time streaming with broadcast channels will be added in a future iteration.
|
||||
#[tracing::instrument(skip_all, fields(session_id = %id))]
|
||||
pub async fn session_stream(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<
|
||||
Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>,
|
||||
StatusCode,
|
||||
> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// Verify session exists
|
||||
let _session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Fetch recent messages for this session
|
||||
let messages: Vec<PentestMessage> = match agent
|
||||
.db
|
||||
.pentest_messages()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "created_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Fetch recent attack chain nodes
|
||||
let nodes: Vec<AttackChainNode> = match agent
|
||||
.db
|
||||
.attack_chain_nodes()
|
||||
.find(doc! { "session_id": &id })
|
||||
.sort(doc! { "started_at": 1 })
|
||||
.limit(100)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
||||
// Build SSE events from stored data
|
||||
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
|
||||
|
||||
for msg in &messages {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"created_at": msg.created_at.to_rfc3339(),
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("message").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
for node in &nodes {
|
||||
let event_data = serde_json::json!({
|
||||
"type": "tool_execution",
|
||||
"node_id": node.node_id,
|
||||
"tool_name": node.tool_name,
|
||||
"status": node.status,
|
||||
"findings_produced": node.findings_produced,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&event_data) {
|
||||
events.push(Ok(Event::default().event("tool").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
// Add session status event
|
||||
let session = agent
|
||||
.db
|
||||
.pentest_sessions()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
if let Some(s) = session {
|
||||
let status_data = serde_json::json!({
|
||||
"type": "status",
|
||||
"status": s.status,
|
||||
"findings_count": s.findings_count,
|
||||
"tool_invocations": s.tool_invocations,
|
||||
});
|
||||
if let Ok(data) = serde_json::to_string(&status_data) {
|
||||
events.push(Ok(Event::default().event("status").data(data)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Sse::new(stream::iter(events)))
|
||||
}
|
||||
241
compliance-agent/src/api/handlers/repos.rs
Normal file
241
compliance-agent/src/api/handlers/repos.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
use axum::extract::{Extension, Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::*;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_repositories(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<TrackedRepository>> {
|
||||
let db = &agent.db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db
|
||||
.repositories()
|
||||
.count_documents(doc! {})
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let repos = match db
|
||||
.repositories()
|
||||
.find(doc! {})
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch repositories: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: repos,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn add_repository(
|
||||
Extension(agent): AgentExt,
|
||||
Json(req): Json<AddRepositoryRequest>,
|
||||
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
|
||||
// Validate repository access before saving
|
||||
let creds = crate::pipeline::git::RepoCredentials {
|
||||
ssh_key_path: Some(agent.config.ssh_key_path.clone()),
|
||||
auth_token: req.auth_token.clone(),
|
||||
auth_username: req.auth_username.clone(),
|
||||
};
|
||||
|
||||
if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Cannot access repository: {e}"),
|
||||
));
|
||||
}
|
||||
|
||||
let mut repo = TrackedRepository::new(req.name, req.git_url);
|
||||
repo.default_branch = req.default_branch;
|
||||
repo.auth_token = req.auth_token;
|
||||
repo.auth_username = req.auth_username;
|
||||
repo.tracker_type = req.tracker_type;
|
||||
repo.tracker_owner = req.tracker_owner;
|
||||
repo.tracker_repo = req.tracker_repo;
|
||||
repo.tracker_token = req.tracker_token;
|
||||
repo.scan_schedule = req.scan_schedule;
|
||||
|
||||
agent
|
||||
.db
|
||||
.repositories()
|
||||
.insert_one(&repo)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::CONFLICT,
|
||||
"Repository already exists".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: repo,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %id))]
|
||||
pub async fn update_repository(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateRepositoryRequest>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
|
||||
|
||||
if let Some(name) = &req.name {
|
||||
set_doc.insert("name", name);
|
||||
}
|
||||
if let Some(branch) = &req.default_branch {
|
||||
set_doc.insert("default_branch", branch);
|
||||
}
|
||||
if let Some(token) = &req.auth_token {
|
||||
set_doc.insert("auth_token", token);
|
||||
}
|
||||
if let Some(username) = &req.auth_username {
|
||||
set_doc.insert("auth_username", username);
|
||||
}
|
||||
if let Some(tracker_type) = &req.tracker_type {
|
||||
set_doc.insert("tracker_type", tracker_type.to_string());
|
||||
}
|
||||
if let Some(owner) = &req.tracker_owner {
|
||||
set_doc.insert("tracker_owner", owner);
|
||||
}
|
||||
if let Some(repo) = &req.tracker_repo {
|
||||
set_doc.insert("tracker_repo", repo);
|
||||
}
|
||||
if let Some(token) = &req.tracker_token {
|
||||
set_doc.insert("tracker_token", token);
|
||||
}
|
||||
if let Some(schedule) = &req.scan_schedule {
|
||||
set_doc.insert("scan_schedule", schedule);
|
||||
}
|
||||
|
||||
let result = agent
|
||||
.db
|
||||
.repositories()
|
||||
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!("Failed to update repository: {e}");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
if result.matched_count == 0 {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "updated" })))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn get_ssh_public_key(
|
||||
Extension(agent): AgentExt,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let public_path = format!("{}.pub", agent.config.ssh_key_path);
|
||||
let public_key = std::fs::read_to_string(&public_path).map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
Ok(Json(serde_json::json!({ "public_key": public_key.trim() })))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %id))]
|
||||
pub async fn trigger_scan(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let agent_clone = (*agent).clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
|
||||
tracing::error!("Manual scan failed for {id}: {e}");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
|
||||
}
|
||||
|
||||
/// Return the webhook secret for a repository (used by dashboard to display it)
|
||||
pub async fn get_webhook_config(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let repo = agent
|
||||
.db
|
||||
.repositories()
|
||||
.find_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
let tracker_type = repo
|
||||
.tracker_type
|
||||
.as_ref()
|
||||
.map(|t| t.to_string())
|
||||
.unwrap_or_else(|| "gitea".to_string());
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"webhook_secret": repo.webhook_secret,
|
||||
"tracker_type": tracker_type,
|
||||
})))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %id))]
|
||||
pub async fn delete_repository(
|
||||
Extension(agent): AgentExt,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let db = &agent.db;
|
||||
|
||||
// Delete the repository
|
||||
let result = db
|
||||
.repositories()
|
||||
.delete_one(doc! { "_id": oid })
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if result.deleted_count == 0 {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
|
||||
// Cascade delete all related data
|
||||
let _ = db.findings().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db
|
||||
.tracker_issues()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await;
|
||||
let _ = db
|
||||
.impact_analyses()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db
|
||||
.code_embeddings()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
let _ = db
|
||||
.embedding_builds()
|
||||
.delete_many(doc! { "repo_id": &id })
|
||||
.await;
|
||||
|
||||
Ok(Json(serde_json::json!({ "status": "deleted" })))
|
||||
}
|
||||
379
compliance-agent/src/api/handlers/sbom.rs
Normal file
379
compliance-agent/src/api/handlers/sbom.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
use axum::extract::{Extension, Query};
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::SbomEntry;
|
||||
|
||||
const COPYLEFT_LICENSES: &[&str] = &[
|
||||
"GPL-2.0",
|
||||
"GPL-2.0-only",
|
||||
"GPL-2.0-or-later",
|
||||
"GPL-3.0",
|
||||
"GPL-3.0-only",
|
||||
"GPL-3.0-or-later",
|
||||
"AGPL-3.0",
|
||||
"AGPL-3.0-only",
|
||||
"AGPL-3.0-or-later",
|
||||
"LGPL-2.1",
|
||||
"LGPL-2.1-only",
|
||||
"LGPL-2.1-or-later",
|
||||
"LGPL-3.0",
|
||||
"LGPL-3.0-only",
|
||||
"LGPL-3.0-or-later",
|
||||
"MPL-2.0",
|
||||
];
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn sbom_filters(
|
||||
Extension(agent): AgentExt,
|
||||
) -> Result<Json<serde_json::Value>, StatusCode> {
|
||||
let db = &agent.db;
|
||||
|
||||
let managers: Vec<String> = db
|
||||
.sbom_entries()
|
||||
.distinct("package_manager", doc! {})
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.filter(|s| !s.is_empty() && s != "unknown" && s != "file")
|
||||
.collect();
|
||||
|
||||
let licenses: Vec<String> = db
|
||||
.sbom_entries()
|
||||
.distinct("license", doc! {})
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"package_managers": managers,
|
||||
"licenses": licenses,
|
||||
})))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
|
||||
pub async fn list_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
Query(filter): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<SbomEntry>> {
|
||||
let db = &agent.db;
|
||||
let mut query = doc! {};
|
||||
|
||||
if let Some(repo_id) = &filter.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
}
|
||||
if let Some(pm) = &filter.package_manager {
|
||||
query.insert("package_manager", pm);
|
||||
}
|
||||
if let Some(q) = &filter.q {
|
||||
if !q.is_empty() {
|
||||
query.insert("name", doc! { "$regex": q, "$options": "i" });
|
||||
}
|
||||
}
|
||||
if let Some(has_vulns) = filter.has_vulns {
|
||||
if has_vulns {
|
||||
query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
|
||||
} else {
|
||||
query.insert("known_vulnerabilities", doc! { "$size": 0 });
|
||||
}
|
||||
}
|
||||
if let Some(license) = &filter.license {
|
||||
query.insert("license", license);
|
||||
}
|
||||
|
||||
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
|
||||
let total = db
|
||||
.sbom_entries()
|
||||
.count_documents(query.clone())
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
|
||||
let entries = match db
|
||||
.sbom_entries()
|
||||
.find(query)
|
||||
.sort(doc! { "name": 1 })
|
||||
.skip(skip)
|
||||
.limit(filter.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SBOM entries: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: entries,
|
||||
total: Some(total),
|
||||
page: Some(filter.page),
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn export_sbom(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomExportParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let db = &agent.db;
|
||||
let entries: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_id })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SBOM entries for export: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let body = if params.format == "spdx" {
|
||||
// SPDX 2.3 format
|
||||
let packages: Vec<serde_json::Value> = entries
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, e)| {
|
||||
serde_json::json!({
|
||||
"SPDXID": format!("SPDXRef-Package-{i}"),
|
||||
"name": e.name,
|
||||
"versionInfo": e.version,
|
||||
"downloadLocation": "NOASSERTION",
|
||||
"licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
|
||||
"externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
|
||||
"referenceCategory": "PACKAGE-MANAGER",
|
||||
"referenceType": "purl",
|
||||
"referenceLocator": p,
|
||||
})]).unwrap_or_default(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"spdxVersion": "SPDX-2.3",
|
||||
"dataLicense": "CC0-1.0",
|
||||
"SPDXID": "SPDXRef-DOCUMENT",
|
||||
"name": format!("sbom-{}", params.repo_id),
|
||||
"documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
|
||||
"packages": packages,
|
||||
})
|
||||
} else {
|
||||
// CycloneDX 1.5 format
|
||||
let components: Vec<serde_json::Value> = entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
let mut comp = serde_json::json!({
|
||||
"type": "library",
|
||||
"name": e.name,
|
||||
"version": e.version,
|
||||
"group": e.package_manager,
|
||||
});
|
||||
if let Some(purl) = &e.purl {
|
||||
comp["purl"] = serde_json::Value::String(purl.clone());
|
||||
}
|
||||
if let Some(license) = &e.license {
|
||||
comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
|
||||
}
|
||||
if !e.known_vulnerabilities.is_empty() {
|
||||
comp["vulnerabilities"] = serde_json::json!(
|
||||
e.known_vulnerabilities.iter().map(|v| serde_json::json!({
|
||||
"id": v.id,
|
||||
"source": { "name": v.source },
|
||||
"ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
|
||||
})).collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
comp
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"bomFormat": "CycloneDX",
|
||||
"specVersion": "1.5",
|
||||
"version": 1,
|
||||
"metadata": {
|
||||
"component": {
|
||||
"type": "application",
|
||||
"name": format!("repo-{}", params.repo_id),
|
||||
}
|
||||
},
|
||||
"components": components,
|
||||
})
|
||||
};
|
||||
|
||||
let json_str =
|
||||
serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
let filename = if params.format == "spdx" {
|
||||
format!("sbom-{}-spdx.json", params.repo_id)
|
||||
} else {
|
||||
format!("sbom-{}-cyclonedx.json", params.repo_id)
|
||||
};
|
||||
|
||||
let disposition = format!("attachment; filename=\"{filename}\"");
|
||||
Ok((
|
||||
[
|
||||
(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
),
|
||||
(
|
||||
header::CONTENT_DISPOSITION,
|
||||
header::HeaderValue::from_str(&disposition)
|
||||
.unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
|
||||
),
|
||||
],
|
||||
json_str,
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn license_summary(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomFilter>,
|
||||
) -> ApiResult<Vec<LicenseSummary>> {
|
||||
let db = &agent.db;
|
||||
let mut query = doc! {};
|
||||
if let Some(repo_id) = ¶ms.repo_id {
|
||||
query.insert("repo_id", repo_id);
|
||||
}
|
||||
|
||||
let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SBOM entries for license summary: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let mut license_map: std::collections::HashMap<String, Vec<String>> =
|
||||
std::collections::HashMap::new();
|
||||
for entry in &entries {
|
||||
let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
|
||||
license_map.entry(lic).or_default().push(entry.name.clone());
|
||||
}
|
||||
|
||||
let mut summaries: Vec<LicenseSummary> = license_map
|
||||
.into_iter()
|
||||
.map(|(license, packages)| {
|
||||
let is_copyleft = COPYLEFT_LICENSES
|
||||
.iter()
|
||||
.any(|c| license.to_uppercase().contains(&c.to_uppercase()));
|
||||
LicenseSummary {
|
||||
license,
|
||||
count: packages.len() as u64,
|
||||
is_copyleft,
|
||||
packages,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
summaries.sort_by(|a, b| b.count.cmp(&a.count));
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: summaries,
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn sbom_diff(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<SbomDiffParams>,
|
||||
) -> ApiResult<SbomDiffResult> {
|
||||
let db = &agent.db;
|
||||
|
||||
let entries_a: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_a })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
let entries_b: Vec<SbomEntry> = match db
|
||||
.sbom_entries()
|
||||
.find(doc! { "repo_id": ¶ms.repo_b })
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Build maps by (name, package_manager) -> version
|
||||
let map_a: std::collections::HashMap<(String, String), String> = entries_a
|
||||
.iter()
|
||||
.map(|e| {
|
||||
(
|
||||
(e.name.clone(), e.package_manager.clone()),
|
||||
e.version.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let map_b: std::collections::HashMap<(String, String), String> = entries_b
|
||||
.iter()
|
||||
.map(|e| {
|
||||
(
|
||||
(e.name.clone(), e.package_manager.clone()),
|
||||
e.version.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut only_in_a = Vec::new();
|
||||
let mut version_changed = Vec::new();
|
||||
let mut common_count: u64 = 0;
|
||||
|
||||
for (key, ver_a) in &map_a {
|
||||
match map_b.get(key) {
|
||||
None => only_in_a.push(SbomDiffEntry {
|
||||
name: key.0.clone(),
|
||||
version: ver_a.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
}),
|
||||
Some(ver_b) if ver_a != ver_b => {
|
||||
version_changed.push(SbomVersionDiff {
|
||||
name: key.0.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
version_a: ver_a.clone(),
|
||||
version_b: ver_b.clone(),
|
||||
});
|
||||
}
|
||||
Some(_) => common_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let only_in_b: Vec<SbomDiffEntry> = map_b
|
||||
.iter()
|
||||
.filter(|(key, _)| !map_a.contains_key(key))
|
||||
.map(|(key, ver)| SbomDiffEntry {
|
||||
name: key.0.clone(),
|
||||
version: ver.clone(),
|
||||
package_manager: key.1.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: SbomDiffResult {
|
||||
only_in_a,
|
||||
only_in_b,
|
||||
version_changed,
|
||||
common_count,
|
||||
},
|
||||
total: None,
|
||||
page: None,
|
||||
}))
|
||||
}
|
||||
37
compliance-agent/src/api/handlers/scans.rs
Normal file
37
compliance-agent/src/api/handlers/scans.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use axum::extract::{Extension, Query};
|
||||
use axum::Json;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use super::dto::*;
|
||||
use compliance_core::models::ScanRun;
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn list_scan_runs(
|
||||
Extension(agent): AgentExt,
|
||||
Query(params): Query<PaginationParams>,
|
||||
) -> ApiResult<Vec<ScanRun>> {
|
||||
let db = &agent.db;
|
||||
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
|
||||
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
|
||||
|
||||
let scans = match db
|
||||
.scan_runs()
|
||||
.find(doc! {})
|
||||
.sort(doc! { "started_at": -1 })
|
||||
.skip(skip)
|
||||
.limit(params.limit)
|
||||
.await
|
||||
{
|
||||
Ok(cursor) => collect_cursor_async(cursor).await,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch scan runs: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
data: scans,
|
||||
total: Some(total),
|
||||
page: Some(params.page),
|
||||
}))
|
||||
}
|
||||
@@ -136,7 +136,10 @@ pub fn build_router() -> Router {
|
||||
"/api/v1/pentest/sessions/{id}/export",
|
||||
post(handlers::pentest::export_session_report),
|
||||
)
|
||||
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
|
||||
.route(
|
||||
"/api/v1/pentest/stats",
|
||||
get(handlers::pentest::pentest_stats),
|
||||
)
|
||||
// Webhook endpoints (proxied through dashboard)
|
||||
.route(
|
||||
"/webhook/github/{repo_id}",
|
||||
|
||||
@@ -1,147 +1,17 @@
|
||||
use secrecy::{ExposeSecret, SecretString};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::types::*;
|
||||
use crate::error::AgentError;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LlmClient {
|
||||
base_url: String,
|
||||
api_key: SecretString,
|
||||
model: String,
|
||||
embed_model: String,
|
||||
http: reqwest::Client,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: SecretString,
|
||||
pub(crate) model: String,
|
||||
pub(crate) embed_model: String,
|
||||
pub(crate) http: reqwest::Client,
|
||||
}
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallRequest>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ToolDefinitionPayload>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolDefinitionPayload {
|
||||
r#type: String,
|
||||
function: ToolFunctionPayload,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ToolFunctionPayload {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Response types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<ChatChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatChoice {
|
||||
message: ChatResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ChatResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<ToolCallResponse>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallResponse {
|
||||
id: String,
|
||||
function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ToolCallFunction {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
// ── Public types for tool calling ──────────────────────────────
|
||||
|
||||
/// Definition of a tool that the LLM can invoke
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call request from the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequest {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallRequestFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequestFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Response from the LLM — either content or tool calls
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LlmResponse {
|
||||
Content(String),
|
||||
/// Tool calls with optional reasoning text from the LLM
|
||||
ToolCalls { calls: Vec<LlmToolCall>, reasoning: String },
|
||||
}
|
||||
|
||||
// ── Embedding types ────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
// ── Implementation ─────────────────────────────────────────────
|
||||
|
||||
impl LlmClient {
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
@@ -158,18 +28,14 @@ impl LlmClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
fn chat_url(&self) -> String {
|
||||
pub(crate) fn chat_url(&self) -> String {
|
||||
format!(
|
||||
"{}/v1/chat/completions",
|
||||
self.base_url.trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
|
||||
fn auth_header(&self) -> Option<String> {
|
||||
pub(crate) fn auth_header(&self) -> Option<String> {
|
||||
let key = self.api_key.expose_secret();
|
||||
if key.is_empty() {
|
||||
None
|
||||
@@ -241,12 +107,12 @@ impl LlmClient {
|
||||
tools: None,
|
||||
};
|
||||
|
||||
self.send_chat_request(&request_body).await.map(|resp| {
|
||||
match resp {
|
||||
self.send_chat_request(&request_body)
|
||||
.await
|
||||
.map(|resp| match resp {
|
||||
LlmResponse::Content(c) => c,
|
||||
LlmResponse::ToolCalls { .. } => String::new(),
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Chat with tool definitions — returns either content or tool calls.
|
||||
@@ -292,7 +158,7 @@ impl LlmClient {
|
||||
) -> Result<LlmResponse, AgentError> {
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&self.chat_url())
|
||||
.post(self.chat_url())
|
||||
.header("content-type", "application/json")
|
||||
.json(request_body);
|
||||
|
||||
@@ -345,54 +211,7 @@ impl LlmClient {
|
||||
}
|
||||
|
||||
// Otherwise return content
|
||||
let content = choice
|
||||
.message
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or_default();
|
||||
let content = choice.message.content.clone().unwrap_or_default();
|
||||
Ok(LlmResponse::Content(content))
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
|
||||
74
compliance-agent/src/llm/embedding.rs
Normal file
74
compliance-agent/src/llm/embedding.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::client::LlmClient;
|
||||
use crate::error::AgentError;
|
||||
|
||||
// ── Embedding types ────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
model: String,
|
||||
input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
data: Vec<EmbeddingData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingData {
|
||||
embedding: Vec<f64>,
|
||||
index: usize,
|
||||
}
|
||||
|
||||
// ── Embedding implementation ───────────────────────────────────
|
||||
|
||||
impl LlmClient {
|
||||
pub fn embed_model(&self) -> &str {
|
||||
&self.embed_model
|
||||
}
|
||||
|
||||
/// Generate embeddings for a batch of texts
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
|
||||
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
|
||||
|
||||
let request_body = EmbeddingRequest {
|
||||
model: self.embed_model.clone(),
|
||||
input: texts,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request_body);
|
||||
|
||||
if let Some(auth) = self.auth_header() {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(AgentError::Other(format!(
|
||||
"Embedding API returned {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let body: EmbeddingResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
|
||||
|
||||
let mut data = body.data;
|
||||
data.sort_by_key(|d| d.index);
|
||||
|
||||
Ok(data.into_iter().map(|d| d.embedding).collect())
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,16 @@
|
||||
pub mod client;
|
||||
#[allow(dead_code)]
|
||||
pub mod descriptions;
|
||||
pub mod embedding;
|
||||
#[allow(dead_code)]
|
||||
pub mod fixes;
|
||||
#[allow(dead_code)]
|
||||
pub mod pr_review;
|
||||
pub mod review_prompts;
|
||||
pub mod triage;
|
||||
pub mod types;
|
||||
|
||||
pub use client::LlmClient;
|
||||
pub use types::{
|
||||
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
|
||||
};
|
||||
|
||||
@@ -278,3 +278,220 @@ struct TriageResult {
|
||||
fn default_action() -> String {
|
||||
"confirm".to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::Severity;
|
||||
|
||||
// ── classify_file_path ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn classify_none_path() {
|
||||
assert_eq!(classify_file_path(None), "unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_production_path() {
|
||||
assert_eq!(classify_file_path(Some("src/main.rs")), "production");
|
||||
assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_test_paths() {
|
||||
assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("foo_test.go")), "test");
|
||||
assert_eq!(classify_file_path(Some("bar.test.js")), "test");
|
||||
assert_eq!(classify_file_path(Some("baz.spec.ts")), "test");
|
||||
assert_eq!(
|
||||
classify_file_path(Some("data/fixtures/sample.json")),
|
||||
"test"
|
||||
);
|
||||
assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_example_paths() {
|
||||
assert_eq!(
|
||||
classify_file_path(Some("docs/examples/basic.rs")),
|
||||
"example"
|
||||
);
|
||||
// /example matches because contains("/example")
|
||||
assert_eq!(classify_file_path(Some("src/example/main.py")), "example");
|
||||
assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example");
|
||||
assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_generated_paths() {
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/generated/api.rs")),
|
||||
"generated"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("proto/gen/service.go")),
|
||||
"generated"
|
||||
);
|
||||
assert_eq!(classify_file_path(Some("api.generated.ts")), "generated");
|
||||
assert_eq!(classify_file_path(Some("service.pb.go")), "generated");
|
||||
assert_eq!(classify_file_path(Some("model_generated.rs")), "generated");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_vendored_paths() {
|
||||
// Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes)
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/vendor/lib/foo.go")),
|
||||
"vendored"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/node_modules/pkg/index.js")),
|
||||
"vendored"
|
||||
);
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/third_party/lib.c")),
|
||||
"vendored"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_is_case_insensitive() {
|
||||
assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test");
|
||||
assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored");
|
||||
assert_eq!(
|
||||
classify_file_path(Some("src/GENERATED/foo.ts")),
|
||||
"generated"
|
||||
);
|
||||
}
|
||||
|
||||
// ── adjust_confidence ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_production() {
|
||||
assert_eq!(adjust_confidence(8.0, "production"), 8.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_test() {
|
||||
assert_eq!(adjust_confidence(10.0, "test"), 5.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_example() {
|
||||
assert_eq!(adjust_confidence(10.0, "example"), 6.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_generated() {
|
||||
assert_eq!(adjust_confidence(10.0, "generated"), 3.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_vendored() {
|
||||
assert_eq!(adjust_confidence(10.0, "vendored"), 4.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_unknown_classification() {
|
||||
assert_eq!(adjust_confidence(7.0, "unknown"), 7.0);
|
||||
assert_eq!(adjust_confidence(7.0, "something_else"), 7.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adjust_confidence_zero() {
|
||||
assert_eq!(adjust_confidence(0.0, "test"), 0.0);
|
||||
assert_eq!(adjust_confidence(0.0, "production"), 0.0);
|
||||
}
|
||||
|
||||
// ── downgrade_severity ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn downgrade_severity_all_levels() {
|
||||
assert_eq!(downgrade_severity(&Severity::Critical), Severity::High);
|
||||
assert_eq!(downgrade_severity(&Severity::High), Severity::Medium);
|
||||
assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low);
|
||||
assert_eq!(downgrade_severity(&Severity::Low), Severity::Info);
|
||||
assert_eq!(downgrade_severity(&Severity::Info), Severity::Info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn downgrade_severity_info_is_floor() {
|
||||
// Downgrading Info twice should still be Info
|
||||
let s = downgrade_severity(&Severity::Info);
|
||||
assert_eq!(downgrade_severity(&s), Severity::Info);
|
||||
}
|
||||
|
||||
// ── upgrade_severity ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn upgrade_severity_all_levels() {
|
||||
assert_eq!(upgrade_severity(&Severity::Info), Severity::Low);
|
||||
assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium);
|
||||
assert_eq!(upgrade_severity(&Severity::Medium), Severity::High);
|
||||
assert_eq!(upgrade_severity(&Severity::High), Severity::Critical);
|
||||
assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upgrade_severity_critical_is_ceiling() {
|
||||
let s = upgrade_severity(&Severity::Critical);
|
||||
assert_eq!(upgrade_severity(&s), Severity::Critical);
|
||||
}
|
||||
|
||||
// ── upgrade/downgrade roundtrip ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn upgrade_then_downgrade_is_identity_for_middle_values() {
|
||||
for sev in [Severity::Low, Severity::Medium, Severity::High] {
|
||||
assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev);
|
||||
}
|
||||
}
|
||||
|
||||
// ── TriageResult deserialization ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn triage_result_full() {
|
||||
let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "dismiss");
|
||||
assert_eq!(r.confidence, 8.5);
|
||||
assert_eq!(r.rationale, "false positive");
|
||||
assert_eq!(r.remediation.as_deref(), Some("remove code"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_defaults() {
|
||||
let json = r#"{}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "confirm");
|
||||
assert_eq!(r.confidence, 0.0);
|
||||
assert_eq!(r.rationale, "");
|
||||
assert!(r.remediation.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_partial() {
|
||||
let json = r#"{"action":"downgrade","confidence":6.0}"#;
|
||||
let r: TriageResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(r.action, "downgrade");
|
||||
assert_eq!(r.confidence, 6.0);
|
||||
assert_eq!(r.rationale, "");
|
||||
assert!(r.remediation.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn triage_result_with_markdown_fences() {
|
||||
// Simulate LLM wrapping response in markdown code fences
|
||||
let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```";
|
||||
let cleaned = raw
|
||||
.trim()
|
||||
.trim_start_matches("```json")
|
||||
.trim_start_matches("```")
|
||||
.trim_end_matches("```")
|
||||
.trim();
|
||||
let r: TriageResult = serde_json::from_str(cleaned).unwrap();
|
||||
assert_eq!(r.action, "upgrade");
|
||||
assert_eq!(r.confidence, 9.0);
|
||||
}
|
||||
}
|
||||
|
||||
369
compliance-agent/src/llm/types.rs
Normal file
369
compliance-agent/src/llm/types.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallRequest>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ChatCompletionRequest {
|
||||
pub(crate) model: String,
|
||||
pub(crate) messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(crate) tools: Option<Vec<ToolDefinitionPayload>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ToolDefinitionPayload {
|
||||
pub(crate) r#type: String,
|
||||
pub(crate) function: ToolFunctionPayload,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(crate) struct ToolFunctionPayload {
|
||||
pub(crate) name: String,
|
||||
pub(crate) description: String,
|
||||
pub(crate) parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
// ── Response types ─────────────────────────────────────────────
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatCompletionResponse {
|
||||
pub(crate) choices: Vec<ChatChoice>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatChoice {
|
||||
pub(crate) message: ChatResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ChatResponseMessage {
|
||||
#[serde(default)]
|
||||
pub(crate) content: Option<String>,
|
||||
#[serde(default)]
|
||||
pub(crate) tool_calls: Option<Vec<ToolCallResponse>>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ToolCallResponse {
|
||||
pub(crate) id: String,
|
||||
pub(crate) function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct ToolCallFunction {
|
||||
pub(crate) name: String,
|
||||
pub(crate) arguments: String,
|
||||
}
|
||||
|
||||
// ── Public types for tool calling ──────────────────────────────
|
||||
|
||||
/// Definition of a tool that the LLM can invoke
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call request from the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequest {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallRequestFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRequestFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Response from the LLM — either content or tool calls
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LlmResponse {
|
||||
Content(String),
|
||||
/// Tool calls with optional reasoning text from the LLM
|
||||
ToolCalls {
|
||||
calls: Vec<LlmToolCall>,
|
||||
reasoning: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// ── ChatMessage ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_message_serializes_minimal() {
|
||||
let msg = ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some("hello".to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert_eq!(v["role"], "user");
|
||||
assert_eq!(v["content"], "hello");
|
||||
// None fields with skip_serializing_if should be absent
|
||||
assert!(v.get("tool_calls").is_none());
|
||||
assert!(v.get("tool_call_id").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_serializes_with_tool_calls() {
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallRequest {
|
||||
id: "call_1".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "get_weather".to_string(),
|
||||
arguments: r#"{"city":"NYC"}"#.to_string(),
|
||||
},
|
||||
}]),
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert!(v["tool_calls"].is_array());
|
||||
assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_content_null_when_none() {
|
||||
let msg = ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
let v = serde_json::to_value(&msg).unwrap();
|
||||
assert!(v["content"].is_null());
|
||||
}
|
||||
|
||||
// ── ToolDefinition ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tool_definition_serializes() {
|
||||
let td = ToolDefinition {
|
||||
name: "search".to_string(),
|
||||
description: "Search the web".to_string(),
|
||||
parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
|
||||
};
|
||||
let v = serde_json::to_value(&td).unwrap();
|
||||
assert_eq!(v["name"], "search");
|
||||
assert_eq!(v["parameters"]["type"], "object");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_definition_empty_parameters() {
|
||||
let td = ToolDefinition {
|
||||
name: "noop".to_string(),
|
||||
description: "".to_string(),
|
||||
parameters: json!({}),
|
||||
};
|
||||
let v = serde_json::to_value(&td).unwrap();
|
||||
assert_eq!(v["parameters"], json!({}));
|
||||
}
|
||||
|
||||
// ── LlmToolCall ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn llm_tool_call_roundtrip() {
|
||||
let call = LlmToolCall {
|
||||
id: "tc_42".to_string(),
|
||||
name: "run_scan".to_string(),
|
||||
arguments: json!({"path": "/tmp", "verbose": true}),
|
||||
};
|
||||
let serialized = serde_json::to_string(&call).unwrap();
|
||||
let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.id, "tc_42");
|
||||
assert_eq!(deserialized.name, "run_scan");
|
||||
assert_eq!(deserialized.arguments["path"], "/tmp");
|
||||
assert_eq!(deserialized.arguments["verbose"], true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_tool_call_empty_arguments() {
|
||||
let call = LlmToolCall {
|
||||
id: "tc_0".to_string(),
|
||||
name: "noop".to_string(),
|
||||
arguments: json!({}),
|
||||
};
|
||||
let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap();
|
||||
assert!(rt.arguments.as_object().unwrap().is_empty());
|
||||
}
|
||||
|
||||
// ── ToolCallRequest / ToolCallRequestFunction ────────────────
|
||||
|
||||
#[test]
|
||||
fn tool_call_request_roundtrip() {
|
||||
let req = ToolCallRequest {
|
||||
id: "call_abc".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "my_func".to_string(),
|
||||
arguments: r#"{"x":1}"#.to_string(),
|
||||
},
|
||||
};
|
||||
let json_str = serde_json::to_string(&req).unwrap();
|
||||
let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(back.id, "call_abc");
|
||||
assert_eq!(back.r#type, "function");
|
||||
assert_eq!(back.function.name, "my_func");
|
||||
assert_eq!(back.function.arguments, r#"{"x":1}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_call_request_type_field_serializes_as_type() {
|
||||
let req = ToolCallRequest {
|
||||
id: "id".to_string(),
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: "f".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
// The field should be "type" in JSON, not "r#type"
|
||||
assert!(v.get("type").is_some());
|
||||
assert!(v.get("r#type").is_none());
|
||||
}
|
||||
|
||||
// ── ChatCompletionRequest ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_completion_request_skips_none_fields() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![],
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
tools: None,
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(v["model"], "gpt-4");
|
||||
assert!(v.get("temperature").is_none());
|
||||
assert!(v.get("max_tokens").is_none());
|
||||
assert!(v.get("tools").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_request_includes_set_fields() {
|
||||
let req = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![],
|
||||
temperature: Some(0.7),
|
||||
max_tokens: Some(1024),
|
||||
tools: Some(vec![]),
|
||||
};
|
||||
let v = serde_json::to_value(&req).unwrap();
|
||||
assert_eq!(v["temperature"], 0.7);
|
||||
assert_eq!(v["max_tokens"], 1024);
|
||||
assert!(v["tools"].is_array());
|
||||
}
|
||||
|
||||
// ── ChatCompletionResponse deserialization ───────────────────
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_deserializes_content() {
|
||||
let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
assert_eq!(resp.choices.len(), 1);
|
||||
assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
|
||||
assert!(resp.choices[0].message.tool_calls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_deserializes_tool_calls() {
|
||||
let json_str = r#"{
|
||||
"choices": [{
|
||||
"message": {
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}
|
||||
}]
|
||||
}
|
||||
}]
|
||||
}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
let tc = resp.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tc.len(), 1);
|
||||
assert_eq!(tc[0].id, "call_1");
|
||||
assert_eq!(tc[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_completion_response_defaults_missing_fields() {
|
||||
// content and tool_calls are both missing — should default to None
|
||||
let json_str = r#"{"choices":[{"message":{}}]}"#;
|
||||
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
|
||||
assert!(resp.choices[0].message.content.is_none());
|
||||
assert!(resp.choices[0].message.tool_calls.is_none());
|
||||
}
|
||||
|
||||
// ── LlmResponse ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn llm_response_content_variant() {
|
||||
let resp = LlmResponse::Content("answer".to_string());
|
||||
match resp {
|
||||
LlmResponse::Content(s) => assert_eq!(s, "answer"),
|
||||
_ => panic!("expected Content variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_response_tool_calls_variant() {
|
||||
let resp = LlmResponse::ToolCalls {
|
||||
calls: vec![LlmToolCall {
|
||||
id: "1".to_string(),
|
||||
name: "f".to_string(),
|
||||
arguments: json!({}),
|
||||
}],
|
||||
reasoning: "because".to_string(),
|
||||
};
|
||||
match resp {
|
||||
LlmResponse::ToolCalls { calls, reasoning } => {
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(reasoning, "because");
|
||||
}
|
||||
_ => panic!("expected ToolCalls variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_response_empty_content() {
|
||||
let resp = LlmResponse::Content(String::new());
|
||||
match resp {
|
||||
LlmResponse::Content(s) => assert!(s.is_empty()),
|
||||
_ => panic!("expected Content variant"),
|
||||
}
|
||||
}
|
||||
}
|
||||
150
compliance-agent/src/pentest/context.rs
Normal file
150
compliance-agent/src/pentest/context.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use futures_util::StreamExt;
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::dast::DastTarget;
|
||||
use compliance_core::models::finding::Finding;
|
||||
use compliance_core::models::pentest::CodeContextHint;
|
||||
use compliance_core::models::sbom::SbomEntry;
|
||||
|
||||
use super::orchestrator::PentestOrchestrator;
|
||||
|
||||
impl PentestOrchestrator {
|
||||
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
|
||||
/// for the repo linked to this DAST target.
|
||||
pub(crate) async fn gather_repo_context(
|
||||
&self,
|
||||
target: &DastTarget,
|
||||
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
|
||||
let Some(repo_id) = &target.repo_id else {
|
||||
return (Vec::new(), Vec::new(), Vec::new());
|
||||
};
|
||||
|
||||
let sast_findings = self.fetch_sast_findings(repo_id).await;
|
||||
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
|
||||
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
|
||||
|
||||
tracing::info!(
|
||||
repo_id,
|
||||
sast_findings = sast_findings.len(),
|
||||
vulnerable_deps = sbom_entries.len(),
|
||||
code_hints = code_context.len(),
|
||||
"Gathered code-awareness context for pentest"
|
||||
);
|
||||
|
||||
(sast_findings, sbom_entries, code_context)
|
||||
}
|
||||
|
||||
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
|
||||
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
|
||||
let cursor = self
|
||||
.db
|
||||
.findings()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"status": { "$in": ["open", "triaged"] },
|
||||
})
|
||||
.sort(doc! { "severity": -1 })
|
||||
.limit(100)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(f)) = c.next().await {
|
||||
results.push(f);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch SBOM entries that have known vulnerabilities
|
||||
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
|
||||
let cursor = self
|
||||
.db
|
||||
.sbom_entries()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"known_vulnerabilities": { "$exists": true, "$ne": [] },
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(e)) = c.next().await {
|
||||
results.push(e);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build CodeContextHint objects from the code knowledge graph.
|
||||
/// Maps entry points to their source files and links SAST findings.
|
||||
async fn fetch_code_context(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
sast_findings: &[Finding],
|
||||
) -> Vec<CodeContextHint> {
|
||||
// Get entry point nodes from the code graph
|
||||
let cursor = self
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"is_entry_point": true,
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
let nodes = match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(n)) = c.next().await {
|
||||
results.push(n);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
// Build hints by matching graph nodes to SAST findings by file path
|
||||
nodes
|
||||
.into_iter()
|
||||
.map(|node| {
|
||||
// Find SAST findings in the same file
|
||||
let linked_vulns: Vec<String> = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.file_path.as_deref() == Some(&node.file_path))
|
||||
.map(|f| {
|
||||
format!(
|
||||
"[{}] {}: {} (line {})",
|
||||
f.severity,
|
||||
f.scanner,
|
||||
f.title,
|
||||
f.line_number.unwrap_or(0)
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
CodeContextHint {
|
||||
endpoint_pattern: node.qualified_name.clone(),
|
||||
handler_function: node.name.clone(),
|
||||
file_path: node.file_path.clone(),
|
||||
code_snippet: String::new(), // Could fetch from embeddings
|
||||
known_vulnerabilities: linked_vulns,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
mod context;
|
||||
pub mod orchestrator;
|
||||
mod prompt_builder;
|
||||
pub mod report;
|
||||
|
||||
pub use orchestrator::PentestOrchestrator;
|
||||
|
||||
@@ -1,31 +1,27 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use mongodb::bson::doc;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use compliance_core::models::dast::DastTarget;
|
||||
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::models::sbom::SbomEntry;
|
||||
use compliance_core::traits::pentest_tool::PentestToolContext;
|
||||
use compliance_dast::ToolRegistry;
|
||||
|
||||
use crate::database::Database;
|
||||
use crate::llm::client::{
|
||||
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
|
||||
use crate::llm::{
|
||||
ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
|
||||
};
|
||||
use crate::llm::LlmClient;
|
||||
|
||||
/// Maximum duration for a single pentest session before timeout
|
||||
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
|
||||
|
||||
pub struct PentestOrchestrator {
|
||||
tool_registry: ToolRegistry,
|
||||
llm: Arc<LlmClient>,
|
||||
db: Database,
|
||||
event_tx: broadcast::Sender<PentestEvent>,
|
||||
pub(crate) tool_registry: ToolRegistry,
|
||||
pub(crate) llm: Arc<LlmClient>,
|
||||
pub(crate) db: Database,
|
||||
pub(crate) event_tx: broadcast::Sender<PentestEvent>,
|
||||
}
|
||||
|
||||
impl PentestOrchestrator {
|
||||
@@ -39,10 +35,12 @@ impl PentestOrchestrator {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
|
||||
self.event_tx.subscribe()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
|
||||
self.event_tx.clone()
|
||||
}
|
||||
@@ -111,18 +109,20 @@ impl PentestOrchestrator {
|
||||
target: &DastTarget,
|
||||
initial_message: &str,
|
||||
) -> Result<(), crate::error::AgentError> {
|
||||
let session_id = session
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default();
|
||||
let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default();
|
||||
|
||||
// Gather code-awareness context from linked repo
|
||||
let (sast_findings, sbom_entries, code_context) =
|
||||
self.gather_repo_context(target).await;
|
||||
let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await;
|
||||
|
||||
// Build system prompt with code context
|
||||
let system_prompt = self
|
||||
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
|
||||
.build_system_prompt(
|
||||
session,
|
||||
target,
|
||||
&sast_findings,
|
||||
&sbom_entries,
|
||||
&code_context,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Build tool definitions for LLM
|
||||
@@ -182,8 +182,7 @@ impl PentestOrchestrator {
|
||||
|
||||
match response {
|
||||
LlmResponse::Content(content) => {
|
||||
let msg =
|
||||
PentestMessage::assistant(session_id.clone(), content.clone());
|
||||
let msg = PentestMessage::assistant(session_id.clone(), content.clone());
|
||||
let _ = self.db.pentest_messages().insert_one(&msg).await;
|
||||
let _ = self.event_tx.send(PentestEvent::Message {
|
||||
content: content.clone(),
|
||||
@@ -213,7 +212,10 @@ impl PentestOrchestrator {
|
||||
}
|
||||
break;
|
||||
}
|
||||
LlmResponse::ToolCalls { calls: tool_calls, reasoning } => {
|
||||
LlmResponse::ToolCalls {
|
||||
calls: tool_calls,
|
||||
reasoning,
|
||||
} => {
|
||||
let tc_requests: Vec<ToolCallRequest> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| ToolCallRequest {
|
||||
@@ -221,15 +223,18 @@ impl PentestOrchestrator {
|
||||
r#type: "function".to_string(),
|
||||
function: ToolCallRequestFunction {
|
||||
name: tc.name.clone(),
|
||||
arguments: serde_json::to_string(&tc.arguments)
|
||||
.unwrap_or_default(),
|
||||
arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
messages.push(ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) },
|
||||
content: if reasoning.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(reasoning.clone())
|
||||
},
|
||||
tool_calls: Some(tc_requests),
|
||||
tool_call_id: None,
|
||||
});
|
||||
@@ -274,24 +279,30 @@ impl PentestOrchestrator {
|
||||
let insert_result =
|
||||
self.db.dast_findings().insert_one(&finding).await;
|
||||
if let Ok(res) = &insert_result {
|
||||
finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default());
|
||||
}
|
||||
let _ =
|
||||
self.event_tx.send(PentestEvent::Finding {
|
||||
finding_id: finding
|
||||
.id
|
||||
finding_ids.push(
|
||||
res.inserted_id
|
||||
.as_object_id()
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default(),
|
||||
title: finding.title.clone(),
|
||||
severity: finding.severity.to_string(),
|
||||
});
|
||||
);
|
||||
}
|
||||
let _ = self.event_tx.send(PentestEvent::Finding {
|
||||
finding_id: finding
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_default(),
|
||||
title: finding.title.clone(),
|
||||
severity: finding.severity.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Compute risk score based on findings severity
|
||||
let risk_score: Option<u8> = if findings_count > 0 {
|
||||
Some(std::cmp::min(
|
||||
100,
|
||||
(findings_count as u8).saturating_mul(15).saturating_add(20),
|
||||
(findings_count as u8)
|
||||
.saturating_mul(15)
|
||||
.saturating_add(20),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
@@ -415,347 +426,4 @@ impl PentestOrchestrator {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Code-Awareness: Gather context from linked repo ─────────
|
||||
|
||||
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
|
||||
/// for the repo linked to this DAST target.
|
||||
async fn gather_repo_context(
|
||||
&self,
|
||||
target: &DastTarget,
|
||||
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
|
||||
let Some(repo_id) = &target.repo_id else {
|
||||
return (Vec::new(), Vec::new(), Vec::new());
|
||||
};
|
||||
|
||||
let sast_findings = self.fetch_sast_findings(repo_id).await;
|
||||
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
|
||||
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
|
||||
|
||||
tracing::info!(
|
||||
repo_id,
|
||||
sast_findings = sast_findings.len(),
|
||||
vulnerable_deps = sbom_entries.len(),
|
||||
code_hints = code_context.len(),
|
||||
"Gathered code-awareness context for pentest"
|
||||
);
|
||||
|
||||
(sast_findings, sbom_entries, code_context)
|
||||
}
|
||||
|
||||
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
|
||||
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
|
||||
let cursor = self
|
||||
.db
|
||||
.findings()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"status": { "$in": ["open", "triaged"] },
|
||||
})
|
||||
.sort(doc! { "severity": -1 })
|
||||
.limit(100)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(f)) = c.next().await {
|
||||
results.push(f);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch SBOM entries that have known vulnerabilities
|
||||
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
|
||||
let cursor = self
|
||||
.db
|
||||
.sbom_entries()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"known_vulnerabilities": { "$exists": true, "$ne": [] },
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(e)) = c.next().await {
|
||||
results.push(e);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build CodeContextHint objects from the code knowledge graph.
|
||||
/// Maps entry points to their source files and links SAST findings.
|
||||
async fn fetch_code_context(
|
||||
&self,
|
||||
repo_id: &str,
|
||||
sast_findings: &[Finding],
|
||||
) -> Vec<CodeContextHint> {
|
||||
// Get entry point nodes from the code graph
|
||||
let cursor = self
|
||||
.db
|
||||
.graph_nodes()
|
||||
.find(doc! {
|
||||
"repo_id": repo_id,
|
||||
"is_entry_point": true,
|
||||
})
|
||||
.limit(50)
|
||||
.await;
|
||||
|
||||
let nodes = match cursor {
|
||||
Ok(mut c) => {
|
||||
let mut results = Vec::new();
|
||||
while let Some(Ok(n)) = c.next().await {
|
||||
results.push(n);
|
||||
}
|
||||
results
|
||||
}
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
// Build hints by matching graph nodes to SAST findings by file path
|
||||
nodes
|
||||
.into_iter()
|
||||
.map(|node| {
|
||||
// Find SAST findings in the same file
|
||||
let linked_vulns: Vec<String> = sast_findings
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.file_path.as_deref() == Some(&node.file_path)
|
||||
})
|
||||
.map(|f| {
|
||||
format!(
|
||||
"[{}] {}: {} (line {})",
|
||||
f.severity,
|
||||
f.scanner,
|
||||
f.title,
|
||||
f.line_number.unwrap_or(0)
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
CodeContextHint {
|
||||
endpoint_pattern: node.qualified_name.clone(),
|
||||
handler_function: node.name.clone(),
|
||||
file_path: node.file_path.clone(),
|
||||
code_snippet: String::new(), // Could fetch from embeddings
|
||||
known_vulnerabilities: linked_vulns,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ── System Prompt Builder ───────────────────────────────────
|
||||
|
||||
async fn build_system_prompt(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
sast_findings: &[Finding],
|
||||
sbom_entries: &[SbomEntry],
|
||||
code_context: &[CodeContextHint],
|
||||
) -> String {
|
||||
let tool_names = self.tool_registry.list_names().join(", ");
|
||||
let strategy_guidance = match session.strategy {
|
||||
PentestStrategy::Quick => {
|
||||
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
|
||||
}
|
||||
PentestStrategy::Comprehensive => {
|
||||
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
|
||||
}
|
||||
PentestStrategy::Targeted => {
|
||||
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
|
||||
}
|
||||
PentestStrategy::Aggressive => {
|
||||
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
|
||||
}
|
||||
PentestStrategy::Stealth => {
|
||||
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
|
||||
}
|
||||
};
|
||||
|
||||
// Build SAST findings section
|
||||
let sast_section = if sast_findings.is_empty() {
|
||||
String::from("No SAST findings available for this target.")
|
||||
} else {
|
||||
let critical = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::Critical)
|
||||
.count();
|
||||
let high = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::High)
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} open findings ({} critical, {} high):\n",
|
||||
sast_findings.len(),
|
||||
critical,
|
||||
high
|
||||
);
|
||||
|
||||
// List the most important findings (critical/high first, up to 20)
|
||||
for f in sast_findings.iter().take(20) {
|
||||
let file_info = f
|
||||
.file_path
|
||||
.as_ref()
|
||||
.map(|p| {
|
||||
format!(
|
||||
" in {}:{}",
|
||||
p,
|
||||
f.line_number.unwrap_or(0)
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let status_note = match f.status {
|
||||
FindingStatus::Triaged => " [TRIAGED]",
|
||||
_ => "",
|
||||
};
|
||||
section.push_str(&format!(
|
||||
"- [{sev}] {title}{file}{status}\n",
|
||||
sev = f.severity,
|
||||
title = f.title,
|
||||
file = file_info,
|
||||
status = status_note,
|
||||
));
|
||||
if let Some(cwe) = &f.cwe {
|
||||
section.push_str(&format!(" CWE: {cwe}\n"));
|
||||
}
|
||||
}
|
||||
if sast_findings.len() > 20 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more findings\n",
|
||||
sast_findings.len() - 20
|
||||
));
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
// Build SBOM/CVE section
|
||||
let sbom_section = if sbom_entries.is_empty() {
|
||||
String::from("No vulnerable dependencies identified.")
|
||||
} else {
|
||||
let mut section = format!(
|
||||
"{} dependencies with known vulnerabilities:\n",
|
||||
sbom_entries.len()
|
||||
);
|
||||
for entry in sbom_entries.iter().take(15) {
|
||||
let cve_ids: Vec<&str> = entry
|
||||
.known_vulnerabilities
|
||||
.iter()
|
||||
.map(|v| v.id.as_str())
|
||||
.collect();
|
||||
section.push_str(&format!(
|
||||
"- {} {} ({}): {}\n",
|
||||
entry.name,
|
||||
entry.version,
|
||||
entry.package_manager,
|
||||
cve_ids.join(", ")
|
||||
));
|
||||
}
|
||||
if sbom_entries.len() > 15 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more vulnerable dependencies\n",
|
||||
sbom_entries.len() - 15
|
||||
));
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
// Build code context section
|
||||
let code_section = if code_context.is_empty() {
|
||||
String::from("No code knowledge graph available for this target.")
|
||||
} else {
|
||||
let with_vulns = code_context
|
||||
.iter()
|
||||
.filter(|c| !c.known_vulnerabilities.is_empty())
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} entry points identified ({} with linked SAST findings):\n",
|
||||
code_context.len(),
|
||||
with_vulns
|
||||
);
|
||||
|
||||
for hint in code_context.iter().take(20) {
|
||||
section.push_str(&format!(
|
||||
"- {} ({})\n",
|
||||
hint.endpoint_pattern, hint.file_path
|
||||
));
|
||||
for vuln in &hint.known_vulnerabilities {
|
||||
section.push_str(&format!(" SAST: {vuln}\n"));
|
||||
}
|
||||
}
|
||||
section
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"You are an expert penetration tester conducting an authorized security assessment.
|
||||
|
||||
## Target
|
||||
- **Name**: {target_name}
|
||||
- **URL**: {base_url}
|
||||
- **Type**: {target_type}
|
||||
- **Rate Limit**: {rate_limit} req/s
|
||||
- **Destructive Tests Allowed**: {allow_destructive}
|
||||
- **Linked Repository**: {repo_linked}
|
||||
|
||||
## Strategy
|
||||
{strategy_guidance}
|
||||
|
||||
## SAST Findings (Static Analysis)
|
||||
{sast_section}
|
||||
|
||||
## Vulnerable Dependencies (SBOM)
|
||||
{sbom_section}
|
||||
|
||||
## Code Entry Points (Knowledge Graph)
|
||||
{code_section}
|
||||
|
||||
## Available Tools
|
||||
{tool_names}
|
||||
|
||||
## Instructions
|
||||
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
|
||||
2. Run the OpenAPI parser to discover API endpoints from specs.
|
||||
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
|
||||
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
|
||||
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
|
||||
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
|
||||
7. Test rate limiting on critical endpoints (login, API).
|
||||
8. Check for console.log leakage in frontend JavaScript.
|
||||
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
|
||||
10. When testing is complete, provide a structured summary with severity and remediation.
|
||||
11. Always explain your reasoning before invoking each tool.
|
||||
12. When done, say "Testing complete" followed by a final summary.
|
||||
|
||||
## Important
|
||||
- This is an authorized penetration test. All testing is permitted within the target scope.
|
||||
- Respect the rate limit of {rate_limit} requests per second.
|
||||
- Only use destructive tests if explicitly allowed ({allow_destructive}).
|
||||
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
|
||||
- Use SBOM data to understand what technologies and versions the target runs.
|
||||
"#,
|
||||
target_name = target.name,
|
||||
base_url = target.base_url,
|
||||
target_type = target.target_type,
|
||||
rate_limit = target.rate_limit,
|
||||
allow_destructive = target.allow_destructive,
|
||||
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
504
compliance-agent/src/pentest/prompt_builder.rs
Normal file
504
compliance-agent/src/pentest/prompt_builder.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
use compliance_core::models::dast::DastTarget;
|
||||
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
|
||||
use compliance_core::models::pentest::*;
|
||||
use compliance_core::models::sbom::SbomEntry;
|
||||
|
||||
use super::orchestrator::PentestOrchestrator;
|
||||
|
||||
/// Return strategy guidance text for the given strategy.
|
||||
fn strategy_guidance(strategy: &PentestStrategy) -> &'static str {
|
||||
match strategy {
|
||||
PentestStrategy::Quick => {
|
||||
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
|
||||
}
|
||||
PentestStrategy::Comprehensive => {
|
||||
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
|
||||
}
|
||||
PentestStrategy::Targeted => {
|
||||
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
|
||||
}
|
||||
PentestStrategy::Aggressive => {
|
||||
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
|
||||
}
|
||||
PentestStrategy::Stealth => {
|
||||
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the SAST findings section for the system prompt.
|
||||
fn build_sast_section(sast_findings: &[Finding]) -> String {
|
||||
if sast_findings.is_empty() {
|
||||
return String::from("No SAST findings available for this target.");
|
||||
}
|
||||
|
||||
let critical = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::Critical)
|
||||
.count();
|
||||
let high = sast_findings
|
||||
.iter()
|
||||
.filter(|f| f.severity == Severity::High)
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} open findings ({} critical, {} high):\n",
|
||||
sast_findings.len(),
|
||||
critical,
|
||||
high
|
||||
);
|
||||
|
||||
// List the most important findings (critical/high first, up to 20)
|
||||
for f in sast_findings.iter().take(20) {
|
||||
let file_info = f
|
||||
.file_path
|
||||
.as_ref()
|
||||
.map(|p| format!(" in {}:{}", p, f.line_number.unwrap_or(0)))
|
||||
.unwrap_or_default();
|
||||
let status_note = match f.status {
|
||||
FindingStatus::Triaged => " [TRIAGED]",
|
||||
_ => "",
|
||||
};
|
||||
section.push_str(&format!(
|
||||
"- [{sev}] {title}{file}{status}\n",
|
||||
sev = f.severity,
|
||||
title = f.title,
|
||||
file = file_info,
|
||||
status = status_note,
|
||||
));
|
||||
if let Some(cwe) = &f.cwe {
|
||||
section.push_str(&format!(" CWE: {cwe}\n"));
|
||||
}
|
||||
}
|
||||
if sast_findings.len() > 20 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more findings\n",
|
||||
sast_findings.len() - 20
|
||||
));
|
||||
}
|
||||
section
|
||||
}
|
||||
|
||||
/// Build the SBOM/CVE section for the system prompt.
|
||||
fn build_sbom_section(sbom_entries: &[SbomEntry]) -> String {
|
||||
if sbom_entries.is_empty() {
|
||||
return String::from("No vulnerable dependencies identified.");
|
||||
}
|
||||
|
||||
let mut section = format!(
|
||||
"{} dependencies with known vulnerabilities:\n",
|
||||
sbom_entries.len()
|
||||
);
|
||||
for entry in sbom_entries.iter().take(15) {
|
||||
let cve_ids: Vec<&str> = entry
|
||||
.known_vulnerabilities
|
||||
.iter()
|
||||
.map(|v| v.id.as_str())
|
||||
.collect();
|
||||
section.push_str(&format!(
|
||||
"- {} {} ({}): {}\n",
|
||||
entry.name,
|
||||
entry.version,
|
||||
entry.package_manager,
|
||||
cve_ids.join(", ")
|
||||
));
|
||||
}
|
||||
if sbom_entries.len() > 15 {
|
||||
section.push_str(&format!(
|
||||
"... and {} more vulnerable dependencies\n",
|
||||
sbom_entries.len() - 15
|
||||
));
|
||||
}
|
||||
section
|
||||
}
|
||||
|
||||
/// Build the code context section for the system prompt.
|
||||
fn build_code_section(code_context: &[CodeContextHint]) -> String {
|
||||
if code_context.is_empty() {
|
||||
return String::from("No code knowledge graph available for this target.");
|
||||
}
|
||||
|
||||
let with_vulns = code_context
|
||||
.iter()
|
||||
.filter(|c| !c.known_vulnerabilities.is_empty())
|
||||
.count();
|
||||
|
||||
let mut section = format!(
|
||||
"{} entry points identified ({} with linked SAST findings):\n",
|
||||
code_context.len(),
|
||||
with_vulns
|
||||
);
|
||||
|
||||
for hint in code_context.iter().take(20) {
|
||||
section.push_str(&format!(
|
||||
"- {} ({})\n",
|
||||
hint.endpoint_pattern, hint.file_path
|
||||
));
|
||||
for vuln in &hint.known_vulnerabilities {
|
||||
section.push_str(&format!(" SAST: {vuln}\n"));
|
||||
}
|
||||
}
|
||||
section
|
||||
}
|
||||
|
||||
impl PentestOrchestrator {
|
||||
pub(crate) async fn build_system_prompt(
|
||||
&self,
|
||||
session: &PentestSession,
|
||||
target: &DastTarget,
|
||||
sast_findings: &[Finding],
|
||||
sbom_entries: &[SbomEntry],
|
||||
code_context: &[CodeContextHint],
|
||||
) -> String {
|
||||
let tool_names = self.tool_registry.list_names().join(", ");
|
||||
let guidance = strategy_guidance(&session.strategy);
|
||||
let sast_section = build_sast_section(sast_findings);
|
||||
let sbom_section = build_sbom_section(sbom_entries);
|
||||
let code_section = build_code_section(code_context);
|
||||
|
||||
format!(
|
||||
r#"You are an expert penetration tester conducting an authorized security assessment.
|
||||
|
||||
## Target
|
||||
- **Name**: {target_name}
|
||||
- **URL**: {base_url}
|
||||
- **Type**: {target_type}
|
||||
- **Rate Limit**: {rate_limit} req/s
|
||||
- **Destructive Tests Allowed**: {allow_destructive}
|
||||
- **Linked Repository**: {repo_linked}
|
||||
|
||||
## Strategy
|
||||
{strategy_guidance}
|
||||
|
||||
## SAST Findings (Static Analysis)
|
||||
{sast_section}
|
||||
|
||||
## Vulnerable Dependencies (SBOM)
|
||||
{sbom_section}
|
||||
|
||||
## Code Entry Points (Knowledge Graph)
|
||||
{code_section}
|
||||
|
||||
## Available Tools
|
||||
{tool_names}
|
||||
|
||||
## Instructions
|
||||
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
|
||||
2. Run the OpenAPI parser to discover API endpoints from specs.
|
||||
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
|
||||
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
|
||||
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
|
||||
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
|
||||
7. Test rate limiting on critical endpoints (login, API).
|
||||
8. Check for console.log leakage in frontend JavaScript.
|
||||
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
|
||||
10. When testing is complete, provide a structured summary with severity and remediation.
|
||||
11. Always explain your reasoning before invoking each tool.
|
||||
12. When done, say "Testing complete" followed by a final summary.
|
||||
|
||||
## Important
|
||||
- This is an authorized penetration test. All testing is permitted within the target scope.
|
||||
- Respect the rate limit of {rate_limit} requests per second.
|
||||
- Only use destructive tests if explicitly allowed ({allow_destructive}).
|
||||
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
|
||||
- Use SBOM data to understand what technologies and versions the target runs.
|
||||
"#,
|
||||
target_name = target.name,
|
||||
base_url = target.base_url,
|
||||
target_type = target.target_type,
|
||||
rate_limit = target.rate_limit,
|
||||
allow_destructive = target.allow_destructive,
|
||||
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
|
||||
strategy_guidance = guidance,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::finding::Severity;
|
||||
use compliance_core::models::sbom::VulnRef;
|
||||
use compliance_core::models::scan::ScanType;
|
||||
|
||||
fn make_finding(
|
||||
severity: Severity,
|
||||
title: &str,
|
||||
file_path: Option<&str>,
|
||||
line: Option<u32>,
|
||||
status: FindingStatus,
|
||||
cwe: Option<&str>,
|
||||
) -> Finding {
|
||||
let mut f = Finding::new(
|
||||
"repo-1".into(),
|
||||
format!("fp-{title}"),
|
||||
"semgrep".into(),
|
||||
ScanType::Sast,
|
||||
title.into(),
|
||||
"desc".into(),
|
||||
severity,
|
||||
);
|
||||
f.file_path = file_path.map(|s| s.to_string());
|
||||
f.line_number = line;
|
||||
f.status = status;
|
||||
f.cwe = cwe.map(|s| s.to_string());
|
||||
f
|
||||
}
|
||||
|
||||
fn make_sbom_entry(name: &str, version: &str, cves: &[&str]) -> SbomEntry {
|
||||
let mut entry = SbomEntry::new("repo-1".into(), name.into(), version.into(), "npm".into());
|
||||
entry.known_vulnerabilities = cves
|
||||
.iter()
|
||||
.map(|id| VulnRef {
|
||||
id: id.to_string(),
|
||||
source: "nvd".into(),
|
||||
severity: None,
|
||||
url: None,
|
||||
})
|
||||
.collect();
|
||||
entry
|
||||
}
|
||||
|
||||
fn make_code_hint(endpoint: &str, file: &str, vulns: Vec<String>) -> CodeContextHint {
|
||||
CodeContextHint {
|
||||
endpoint_pattern: endpoint.into(),
|
||||
handler_function: "handler".into(),
|
||||
file_path: file.into(),
|
||||
code_snippet: String::new(),
|
||||
known_vulnerabilities: vulns,
|
||||
}
|
||||
}
|
||||
|
||||
// ── strategy_guidance ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn strategy_guidance_quick() {
|
||||
let g = strategy_guidance(&PentestStrategy::Quick);
|
||||
assert!(g.contains("most common"));
|
||||
assert!(g.contains("quick recon"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strategy_guidance_comprehensive() {
|
||||
let g = strategy_guidance(&PentestStrategy::Comprehensive);
|
||||
assert!(g.contains("thorough assessment"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strategy_guidance_targeted() {
|
||||
let g = strategy_guidance(&PentestStrategy::Targeted);
|
||||
assert!(g.contains("SAST findings"));
|
||||
assert!(g.contains("known CVEs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strategy_guidance_aggressive() {
|
||||
let g = strategy_guidance(&PentestStrategy::Aggressive);
|
||||
assert!(g.contains("aggressively"));
|
||||
assert!(g.contains("full exploitation"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strategy_guidance_stealth() {
|
||||
let g = strategy_guidance(&PentestStrategy::Stealth);
|
||||
assert!(g.contains("Minimize noise"));
|
||||
assert!(g.contains("passive analysis"));
|
||||
}
|
||||
|
||||
// ── build_sast_section ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sast_section_empty() {
|
||||
let section = build_sast_section(&[]);
|
||||
assert_eq!(section, "No SAST findings available for this target.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sast_section_single_critical() {
|
||||
let findings = vec![make_finding(
|
||||
Severity::Critical,
|
||||
"SQL Injection",
|
||||
Some("src/db.rs"),
|
||||
Some(42),
|
||||
FindingStatus::Open,
|
||||
Some("CWE-89"),
|
||||
)];
|
||||
let section = build_sast_section(&findings);
|
||||
assert!(section.contains("1 open findings (1 critical, 0 high)"));
|
||||
assert!(section.contains("[critical] SQL Injection in src/db.rs:42"));
|
||||
assert!(section.contains("CWE: CWE-89"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sast_section_triaged_finding_shows_marker() {
|
||||
let findings = vec![make_finding(
|
||||
Severity::High,
|
||||
"XSS",
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Triaged,
|
||||
None,
|
||||
)];
|
||||
let section = build_sast_section(&findings);
|
||||
assert!(section.contains("[TRIAGED]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sast_section_no_file_path_omits_location() {
|
||||
let findings = vec![make_finding(
|
||||
Severity::Medium,
|
||||
"Open Redirect",
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Open,
|
||||
None,
|
||||
)];
|
||||
let section = build_sast_section(&findings);
|
||||
assert!(section.contains("- [medium] Open Redirect\n"));
|
||||
assert!(!section.contains(" in "));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sast_section_counts_critical_and_high() {
|
||||
let findings = vec![
|
||||
make_finding(
|
||||
Severity::Critical,
|
||||
"F1",
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Open,
|
||||
None,
|
||||
),
|
||||
make_finding(
|
||||
Severity::Critical,
|
||||
"F2",
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Open,
|
||||
None,
|
||||
),
|
||||
make_finding(Severity::High, "F3", None, None, FindingStatus::Open, None),
|
||||
make_finding(
|
||||
Severity::Medium,
|
||||
"F4",
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Open,
|
||||
None,
|
||||
),
|
||||
];
|
||||
let section = build_sast_section(&findings);
|
||||
assert!(section.contains("4 open findings (2 critical, 1 high)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sast_section_truncates_at_20() {
|
||||
let findings: Vec<Finding> = (0..25)
|
||||
.map(|i| {
|
||||
make_finding(
|
||||
Severity::Low,
|
||||
&format!("Finding {i}"),
|
||||
None,
|
||||
None,
|
||||
FindingStatus::Open,
|
||||
None,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let section = build_sast_section(&findings);
|
||||
assert!(section.contains("... and 5 more findings"));
|
||||
// Should contain Finding 19 (the 20th) but not Finding 20 (the 21st)
|
||||
assert!(section.contains("Finding 19"));
|
||||
assert!(!section.contains("Finding 20"));
|
||||
}
|
||||
|
||||
// ── build_sbom_section ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn sbom_section_empty() {
|
||||
let section = build_sbom_section(&[]);
|
||||
assert_eq!(section, "No vulnerable dependencies identified.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_section_single_entry() {
|
||||
let entries = vec![make_sbom_entry("lodash", "4.17.20", &["CVE-2021-23337"])];
|
||||
let section = build_sbom_section(&entries);
|
||||
assert!(section.contains("1 dependencies with known vulnerabilities"));
|
||||
assert!(section.contains("- lodash 4.17.20 (npm): CVE-2021-23337"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_section_multiple_cves() {
|
||||
let entries = vec![make_sbom_entry(
|
||||
"openssl",
|
||||
"1.1.1",
|
||||
&["CVE-2022-0001", "CVE-2022-0002"],
|
||||
)];
|
||||
let section = build_sbom_section(&entries);
|
||||
assert!(section.contains("CVE-2022-0001, CVE-2022-0002"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_section_truncates_at_15() {
|
||||
let entries: Vec<SbomEntry> = (0..18)
|
||||
.map(|i| make_sbom_entry(&format!("pkg-{i}"), "1.0.0", &["CVE-2024-0001"]))
|
||||
.collect();
|
||||
let section = build_sbom_section(&entries);
|
||||
assert!(section.contains("... and 3 more vulnerable dependencies"));
|
||||
assert!(section.contains("pkg-14"));
|
||||
assert!(!section.contains("pkg-15"));
|
||||
}
|
||||
|
||||
// ── build_code_section ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn code_section_empty() {
|
||||
let section = build_code_section(&[]);
|
||||
assert_eq!(
|
||||
section,
|
||||
"No code knowledge graph available for this target."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn code_section_single_entry_no_vulns() {
|
||||
let hints = vec![make_code_hint("GET /api/users", "src/routes.rs", vec![])];
|
||||
let section = build_code_section(&hints);
|
||||
assert!(section.contains("1 entry points identified (0 with linked SAST findings)"));
|
||||
assert!(section.contains("- GET /api/users (src/routes.rs)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn code_section_with_linked_vulns() {
|
||||
let hints = vec![make_code_hint(
|
||||
"POST /login",
|
||||
"src/auth.rs",
|
||||
vec!["[critical] semgrep: SQL Injection (line 15)".into()],
|
||||
)];
|
||||
let section = build_code_section(&hints);
|
||||
assert!(section.contains("1 entry points identified (1 with linked SAST findings)"));
|
||||
assert!(section.contains("SAST: [critical] semgrep: SQL Injection (line 15)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn code_section_counts_entries_with_vulns() {
|
||||
let hints = vec![
|
||||
make_code_hint("GET /a", "a.rs", vec!["vuln1".into()]),
|
||||
make_code_hint("GET /b", "b.rs", vec![]),
|
||||
make_code_hint("GET /c", "c.rs", vec!["vuln2".into(), "vuln3".into()]),
|
||||
];
|
||||
let section = build_code_section(&hints);
|
||||
assert!(section.contains("3 entry points identified (2 with linked SAST findings)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn code_section_truncates_at_20() {
|
||||
let hints: Vec<CodeContextHint> = (0..25)
|
||||
.map(|i| make_code_hint(&format!("GET /ep{i}"), &format!("f{i}.rs"), vec![]))
|
||||
.collect();
|
||||
let section = build_code_section(&hints);
|
||||
assert!(section.contains("GET /ep19"));
|
||||
assert!(!section.contains("GET /ep20"));
|
||||
}
|
||||
}
|
||||
43
compliance-agent/src/pentest/report/archive.rs
Normal file
43
compliance-agent/src/pentest/report/archive.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use std::io::{Cursor, Write};
|
||||
|
||||
use zip::write::SimpleFileOptions;
|
||||
use zip::AesMode;
|
||||
|
||||
use super::ReportContext;
|
||||
|
||||
pub(super) fn build_zip(
|
||||
ctx: &ReportContext,
|
||||
password: &str,
|
||||
html: &str,
|
||||
pdf: &[u8],
|
||||
) -> Result<Vec<u8>, zip::result::ZipError> {
|
||||
let buf = Cursor::new(Vec::new());
|
||||
let mut zip = zip::ZipWriter::new(buf);
|
||||
|
||||
let options = SimpleFileOptions::default()
|
||||
.compression_method(zip::CompressionMethod::Deflated)
|
||||
.with_aes_encryption(AesMode::Aes256, password);
|
||||
|
||||
// report.pdf (primary)
|
||||
zip.start_file("report.pdf", options)?;
|
||||
zip.write_all(pdf)?;
|
||||
|
||||
// report.html (fallback)
|
||||
zip.start_file("report.html", options)?;
|
||||
zip.write_all(html.as_bytes())?;
|
||||
|
||||
// findings.json
|
||||
let findings_json =
|
||||
serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
|
||||
zip.start_file("findings.json", options)?;
|
||||
zip.write_all(findings_json.as_bytes())?;
|
||||
|
||||
// attack-chain.json
|
||||
let chain_json =
|
||||
serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());
|
||||
zip.start_file("attack-chain.json", options)?;
|
||||
zip.write_all(chain_json.as_bytes())?;
|
||||
|
||||
let cursor = zip.finish()?;
|
||||
Ok(cursor.into_inner())
|
||||
}
|
||||
@@ -1,193 +1,50 @@
|
||||
use std::io::{Cursor, Write};
|
||||
|
||||
use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
|
||||
use sha2::{Digest, Sha256};
|
||||
use zip::write::SimpleFileOptions;
|
||||
use zip::AesMode;
|
||||
use compliance_core::models::pentest::AttackChainNode;
|
||||
|
||||
/// Report archive with metadata
|
||||
pub struct ReportArchive {
|
||||
/// The password-protected ZIP bytes
|
||||
pub archive: Vec<u8>,
|
||||
/// SHA-256 hex digest of the archive
|
||||
pub sha256: String,
|
||||
}
|
||||
use super::ReportContext;
|
||||
|
||||
/// Report context gathered from the database
|
||||
pub struct ReportContext {
|
||||
pub session: PentestSession,
|
||||
pub target_name: String,
|
||||
pub target_url: String,
|
||||
pub findings: Vec<DastFinding>,
|
||||
pub attack_chain: Vec<AttackChainNode>,
|
||||
pub requester_name: String,
|
||||
pub requester_email: String,
|
||||
}
|
||||
|
||||
/// Generate a password-protected ZIP archive containing the pentest report.
|
||||
///
|
||||
/// The archive contains:
|
||||
/// - `report.pdf` — Professional pentest report (PDF)
|
||||
/// - `report.html` — HTML source (fallback)
|
||||
/// - `findings.json` — Raw findings data
|
||||
/// - `attack-chain.json` — Attack chain timeline
|
||||
///
|
||||
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
|
||||
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
|
||||
pub async fn generate_encrypted_report(
|
||||
ctx: &ReportContext,
|
||||
password: &str,
|
||||
) -> Result<ReportArchive, String> {
|
||||
let html = build_html_report(ctx);
|
||||
|
||||
// Convert HTML to PDF via headless Chrome
|
||||
let pdf_bytes = html_to_pdf(&html).await?;
|
||||
|
||||
let zip_bytes = build_zip(ctx, password, &html, &pdf_bytes)
|
||||
.map_err(|e| format!("Failed to create archive: {e}"))?;
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&zip_bytes);
|
||||
let sha256 = hex::encode(hasher.finalize());
|
||||
|
||||
Ok(ReportArchive { archive: zip_bytes, sha256 })
|
||||
}
|
||||
|
||||
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
|
||||
async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
|
||||
let tmp_dir = std::env::temp_dir();
|
||||
let run_id = uuid::Uuid::new_v4().to_string();
|
||||
let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
|
||||
let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
|
||||
|
||||
// Write HTML to temp file
|
||||
std::fs::write(&html_path, html)
|
||||
.map_err(|e| format!("Failed to write temp HTML: {e}"))?;
|
||||
|
||||
// Find Chrome/Chromium binary
|
||||
let chrome_bin = find_chrome_binary()
|
||||
.ok_or_else(|| "Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports.".to_string())?;
|
||||
|
||||
tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
|
||||
|
||||
let html_url = format!("file://{}", html_path.display());
|
||||
|
||||
let output = tokio::process::Command::new(&chrome_bin)
|
||||
.args([
|
||||
"--headless",
|
||||
"--disable-gpu",
|
||||
"--no-sandbox",
|
||||
"--disable-software-rasterizer",
|
||||
"--run-all-compositor-stages-before-draw",
|
||||
"--disable-dev-shm-usage",
|
||||
&format!("--print-to-pdf={}", pdf_path.display()),
|
||||
"--no-pdf-header-footer",
|
||||
&html_url,
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to run Chrome: {e}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
// Clean up temp files
|
||||
let _ = std::fs::remove_file(&html_path);
|
||||
let _ = std::fs::remove_file(&pdf_path);
|
||||
return Err(format!("Chrome PDF generation failed: {stderr}"));
|
||||
}
|
||||
|
||||
let pdf_bytes = std::fs::read(&pdf_path)
|
||||
.map_err(|e| format!("Failed to read generated PDF: {e}"))?;
|
||||
|
||||
// Clean up temp files
|
||||
let _ = std::fs::remove_file(&html_path);
|
||||
let _ = std::fs::remove_file(&pdf_path);
|
||||
|
||||
if pdf_bytes.is_empty() {
|
||||
return Err("Chrome produced an empty PDF".to_string());
|
||||
}
|
||||
|
||||
tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
|
||||
Ok(pdf_bytes)
|
||||
}
|
||||
|
||||
/// Search for Chrome/Chromium binary on the system.
|
||||
fn find_chrome_binary() -> Option<String> {
|
||||
let candidates = [
|
||||
"google-chrome-stable",
|
||||
"google-chrome",
|
||||
"chromium-browser",
|
||||
"chromium",
|
||||
];
|
||||
for name in &candidates {
|
||||
if let Ok(output) = std::process::Command::new("which").arg(name).output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn build_zip(
|
||||
ctx: &ReportContext,
|
||||
password: &str,
|
||||
html: &str,
|
||||
pdf: &[u8],
|
||||
) -> Result<Vec<u8>, zip::result::ZipError> {
|
||||
let buf = Cursor::new(Vec::new());
|
||||
let mut zip = zip::ZipWriter::new(buf);
|
||||
|
||||
let options = SimpleFileOptions::default()
|
||||
.compression_method(zip::CompressionMethod::Deflated)
|
||||
.with_aes_encryption(AesMode::Aes256, password);
|
||||
|
||||
// report.pdf (primary)
|
||||
zip.start_file("report.pdf", options.clone())?;
|
||||
zip.write_all(pdf)?;
|
||||
|
||||
// report.html (fallback)
|
||||
zip.start_file("report.html", options.clone())?;
|
||||
zip.write_all(html.as_bytes())?;
|
||||
|
||||
// findings.json
|
||||
let findings_json =
|
||||
serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
|
||||
zip.start_file("findings.json", options.clone())?;
|
||||
zip.write_all(findings_json.as_bytes())?;
|
||||
|
||||
// attack-chain.json
|
||||
let chain_json =
|
||||
serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());
|
||||
zip.start_file("attack-chain.json", options)?;
|
||||
zip.write_all(chain_json.as_bytes())?;
|
||||
|
||||
let cursor = zip.finish()?;
|
||||
Ok(cursor.into_inner())
|
||||
}
|
||||
|
||||
fn build_html_report(ctx: &ReportContext) -> String {
|
||||
#[allow(clippy::format_in_format_args)]
|
||||
pub(super) fn build_html_report(ctx: &ReportContext) -> String {
|
||||
let session = &ctx.session;
|
||||
let session_id = session
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
let date_str = session.started_at.format("%B %d, %Y at %H:%M UTC").to_string();
|
||||
let date_str = session
|
||||
.started_at
|
||||
.format("%B %d, %Y at %H:%M UTC")
|
||||
.to_string();
|
||||
let date_short = session.started_at.format("%B %d, %Y").to_string();
|
||||
let completed_str = session
|
||||
.completed_at
|
||||
.map(|d| d.format("%B %d, %Y at %H:%M UTC").to_string())
|
||||
.unwrap_or_else(|| "In Progress".to_string());
|
||||
|
||||
let critical = ctx.findings.iter().filter(|f| f.severity.to_string() == "critical").count();
|
||||
let high = ctx.findings.iter().filter(|f| f.severity.to_string() == "high").count();
|
||||
let medium = ctx.findings.iter().filter(|f| f.severity.to_string() == "medium").count();
|
||||
let low = ctx.findings.iter().filter(|f| f.severity.to_string() == "low").count();
|
||||
let info = ctx.findings.iter().filter(|f| f.severity.to_string() == "info").count();
|
||||
let critical = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == "critical")
|
||||
.count();
|
||||
let high = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == "high")
|
||||
.count();
|
||||
let medium = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == "medium")
|
||||
.count();
|
||||
let low = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == "low")
|
||||
.count();
|
||||
let info = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == "info")
|
||||
.count();
|
||||
let exploitable = ctx.findings.iter().filter(|f| f.exploitable).count();
|
||||
let total = ctx.findings.len();
|
||||
|
||||
@@ -212,10 +69,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
};
|
||||
|
||||
// Risk score 0-100
|
||||
let risk_score: usize = std::cmp::min(
|
||||
100,
|
||||
critical * 25 + high * 15 + medium * 8 + low * 3 + info * 1,
|
||||
);
|
||||
let risk_score: usize =
|
||||
std::cmp::min(100, critical * 25 + high * 15 + medium * 8 + low * 3 + info);
|
||||
|
||||
// Collect unique tool names used
|
||||
let tool_names: Vec<String> = {
|
||||
@@ -247,7 +102,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
if high > 0 {
|
||||
bar.push_str(&format!(
|
||||
r#"<div class="sev-bar-seg sev-bar-high" style="width:{}%"><span>{}</span></div>"#,
|
||||
std::cmp::max(high_pct, 4), high
|
||||
std::cmp::max(high_pct, 4),
|
||||
high
|
||||
));
|
||||
}
|
||||
if medium > 0 {
|
||||
@@ -259,22 +115,38 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
if low > 0 {
|
||||
bar.push_str(&format!(
|
||||
r#"<div class="sev-bar-seg sev-bar-low" style="width:{}%"><span>{}</span></div>"#,
|
||||
std::cmp::max(low_pct, 4), low
|
||||
std::cmp::max(low_pct, 4),
|
||||
low
|
||||
));
|
||||
}
|
||||
if info > 0 {
|
||||
bar.push_str(&format!(
|
||||
r#"<div class="sev-bar-seg sev-bar-info" style="width:{}%"><span>{}</span></div>"#,
|
||||
std::cmp::max(info_pct, 4), info
|
||||
std::cmp::max(info_pct, 4),
|
||||
info
|
||||
));
|
||||
}
|
||||
bar.push_str("</div>");
|
||||
bar.push_str(r#"<div class="sev-bar-legend">"#);
|
||||
if critical > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#); }
|
||||
if high > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#); }
|
||||
if medium > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#); }
|
||||
if low > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#); }
|
||||
if info > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#); }
|
||||
if critical > 0 {
|
||||
bar.push_str(
|
||||
r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#,
|
||||
);
|
||||
}
|
||||
if high > 0 {
|
||||
bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#);
|
||||
}
|
||||
if medium > 0 {
|
||||
bar.push_str(
|
||||
r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#,
|
||||
);
|
||||
}
|
||||
if low > 0 {
|
||||
bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#);
|
||||
}
|
||||
if info > 0 {
|
||||
bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#);
|
||||
}
|
||||
bar.push_str("</div>");
|
||||
bar
|
||||
} else {
|
||||
@@ -322,7 +194,12 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
let param_row = f
|
||||
.parameter
|
||||
.as_deref()
|
||||
.map(|p| format!("<tr><td>Parameter</td><td><code>{}</code></td></tr>", html_escape(p)))
|
||||
.map(|p| {
|
||||
format!(
|
||||
"<tr><td>Parameter</td><td><code>{}</code></td></tr>",
|
||||
html_escape(p)
|
||||
)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let remediation = f
|
||||
.remediation
|
||||
@@ -332,7 +209,9 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
let evidence_html = if f.evidence.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
let mut eh = String::from(r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#);
|
||||
let mut eh = String::from(
|
||||
r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#,
|
||||
);
|
||||
for ev in &f.evidence {
|
||||
let payload_info = ev
|
||||
.payload
|
||||
@@ -346,7 +225,7 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
ev.response_status,
|
||||
ev.response_snippet
|
||||
.as_deref()
|
||||
.map(|s| html_escape(s))
|
||||
.map(html_escape)
|
||||
.unwrap_or_default(),
|
||||
payload_info,
|
||||
));
|
||||
@@ -402,7 +281,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
let mut chain_html = String::new();
|
||||
if !ctx.attack_chain.is_empty() {
|
||||
// Compute phases via BFS from root nodes
|
||||
let mut phase_map: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
|
||||
let mut phase_map: std::collections::HashMap<String, usize> =
|
||||
std::collections::HashMap::new();
|
||||
let mut queue: std::collections::VecDeque<String> = std::collections::VecDeque::new();
|
||||
|
||||
for node in &ctx.attack_chain {
|
||||
@@ -438,7 +318,13 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
|
||||
// Group nodes by phase
|
||||
let max_phase = phase_map.values().copied().max().unwrap_or(0);
|
||||
let phase_labels = ["Reconnaissance", "Enumeration", "Exploitation", "Validation", "Post-Exploitation"];
|
||||
let phase_labels = [
|
||||
"Reconnaissance",
|
||||
"Enumeration",
|
||||
"Exploitation",
|
||||
"Validation",
|
||||
"Post-Exploitation",
|
||||
];
|
||||
|
||||
for phase_idx in 0..=max_phase {
|
||||
let phase_nodes: Vec<&AttackChainNode> = ctx
|
||||
@@ -485,15 +371,28 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
format!(
|
||||
r#"<span class="step-findings">{} finding{}</span>"#,
|
||||
node.findings_produced.len(),
|
||||
if node.findings_produced.len() == 1 { "" } else { "s" },
|
||||
if node.findings_produced.len() == 1 {
|
||||
""
|
||||
} else {
|
||||
"s"
|
||||
},
|
||||
)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
let risk_badge = node.risk_score.map(|r| {
|
||||
let risk_class = if r >= 70 { "risk-high" } else if r >= 40 { "risk-med" } else { "risk-low" };
|
||||
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
|
||||
}).unwrap_or_default();
|
||||
let risk_badge = node
|
||||
.risk_score
|
||||
.map(|r| {
|
||||
let risk_class = if r >= 70 {
|
||||
"risk-high"
|
||||
} else if r >= 40 {
|
||||
"risk-med"
|
||||
} else {
|
||||
"risk-low"
|
||||
};
|
||||
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let reasoning_html = if node.llm_reasoning.is_empty() {
|
||||
String::new()
|
||||
@@ -547,10 +446,20 @@ fn build_html_report(ctx: &ReportContext) -> String {
|
||||
let toc_findings_sub = if !ctx.findings.is_empty() {
|
||||
let mut sub = String::new();
|
||||
let mut fnum = 0usize;
|
||||
for (si, &sev_key) in severity_order.iter().enumerate() {
|
||||
let count = ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key).count();
|
||||
if count == 0 { continue; }
|
||||
for f in ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key) {
|
||||
for &sev_key in severity_order.iter() {
|
||||
let count = ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == sev_key)
|
||||
.count();
|
||||
if count == 0 {
|
||||
continue;
|
||||
}
|
||||
for f in ctx
|
||||
.findings
|
||||
.iter()
|
||||
.filter(|f| f.severity.to_string() == sev_key)
|
||||
{
|
||||
fnum += 1;
|
||||
sub.push_str(&format!(
|
||||
r#"<div class="toc-sub">F-{:03} — {}</div>"#,
|
||||
@@ -1577,19 +1486,49 @@ table.tools-table td:first-child {{
|
||||
|
||||
fn tool_category(tool_name: &str) -> &'static str {
|
||||
let name = tool_name.to_lowercase();
|
||||
if name.contains("nmap") || name.contains("port") { return "Network Reconnaissance"; }
|
||||
if name.contains("nikto") || name.contains("header") { return "Web Server Analysis"; }
|
||||
if name.contains("zap") || name.contains("spider") || name.contains("crawl") { return "Web Application Scanning"; }
|
||||
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { return "SQL Injection Testing"; }
|
||||
if name.contains("xss") || name.contains("cross-site") { return "Cross-Site Scripting Testing"; }
|
||||
if name.contains("dir") || name.contains("brute") || name.contains("fuzz") || name.contains("gobuster") { return "Directory Enumeration"; }
|
||||
if name.contains("ssl") || name.contains("tls") || name.contains("cert") { return "SSL/TLS Analysis"; }
|
||||
if name.contains("api") || name.contains("endpoint") { return "API Security Testing"; }
|
||||
if name.contains("auth") || name.contains("login") || name.contains("credential") { return "Authentication Testing"; }
|
||||
if name.contains("cors") { return "CORS Testing"; }
|
||||
if name.contains("csrf") { return "CSRF Testing"; }
|
||||
if name.contains("nuclei") || name.contains("template") { return "Vulnerability Scanning"; }
|
||||
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { return "Technology Fingerprinting"; }
|
||||
if name.contains("nmap") || name.contains("port") {
|
||||
return "Network Reconnaissance";
|
||||
}
|
||||
if name.contains("nikto") || name.contains("header") {
|
||||
return "Web Server Analysis";
|
||||
}
|
||||
if name.contains("zap") || name.contains("spider") || name.contains("crawl") {
|
||||
return "Web Application Scanning";
|
||||
}
|
||||
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") {
|
||||
return "SQL Injection Testing";
|
||||
}
|
||||
if name.contains("xss") || name.contains("cross-site") {
|
||||
return "Cross-Site Scripting Testing";
|
||||
}
|
||||
if name.contains("dir")
|
||||
|| name.contains("brute")
|
||||
|| name.contains("fuzz")
|
||||
|| name.contains("gobuster")
|
||||
{
|
||||
return "Directory Enumeration";
|
||||
}
|
||||
if name.contains("ssl") || name.contains("tls") || name.contains("cert") {
|
||||
return "SSL/TLS Analysis";
|
||||
}
|
||||
if name.contains("api") || name.contains("endpoint") {
|
||||
return "API Security Testing";
|
||||
}
|
||||
if name.contains("auth") || name.contains("login") || name.contains("credential") {
|
||||
return "Authentication Testing";
|
||||
}
|
||||
if name.contains("cors") {
|
||||
return "CORS Testing";
|
||||
}
|
||||
if name.contains("csrf") {
|
||||
return "CSRF Testing";
|
||||
}
|
||||
if name.contains("nuclei") || name.contains("template") {
|
||||
return "Vulnerability Scanning";
|
||||
}
|
||||
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") {
|
||||
return "Technology Fingerprinting";
|
||||
}
|
||||
"Security Testing"
|
||||
}
|
||||
|
||||
@@ -1599,3 +1538,314 @@ fn html_escape(s: &str) -> String {
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::dast::{DastFinding, DastVulnType};
|
||||
use compliance_core::models::finding::Severity;
|
||||
use compliance_core::models::pentest::{
|
||||
AttackChainNode, AttackNodeStatus, PentestSession, PentestStrategy,
|
||||
};
|
||||
|
||||
// ── html_escape ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn html_escape_handles_ampersand() {
|
||||
assert_eq!(html_escape("a & b"), "a & b");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn html_escape_handles_angle_brackets() {
|
||||
assert_eq!(html_escape("<script>"), "<script>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn html_escape_handles_quotes() {
|
||||
assert_eq!(html_escape(r#"key="val""#), "key="val"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn html_escape_handles_all_special_chars() {
|
||||
assert_eq!(
|
||||
html_escape(r#"<a href="x">&y</a>"#),
|
||||
"<a href="x">&y</a>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn html_escape_no_change_for_plain_text() {
|
||||
assert_eq!(html_escape("hello world"), "hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn html_escape_empty_string() {
|
||||
assert_eq!(html_escape(""), "");
|
||||
}
|
||||
|
||||
// ── tool_category ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn tool_category_nmap() {
|
||||
assert_eq!(tool_category("nmap_scan"), "Network Reconnaissance");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_port_scanner() {
|
||||
assert_eq!(tool_category("port_scanner"), "Network Reconnaissance");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_nikto() {
|
||||
assert_eq!(tool_category("nikto"), "Web Server Analysis");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_header_check() {
|
||||
assert_eq!(
|
||||
tool_category("security_header_check"),
|
||||
"Web Server Analysis"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_zap_spider() {
|
||||
assert_eq!(tool_category("zap_spider"), "Web Application Scanning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_sqlmap() {
|
||||
assert_eq!(tool_category("sqlmap"), "SQL Injection Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_xss_scanner() {
|
||||
assert_eq!(tool_category("xss_scanner"), "Cross-Site Scripting Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_dir_bruteforce() {
|
||||
assert_eq!(tool_category("dir_bruteforce"), "Directory Enumeration");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_gobuster() {
|
||||
assert_eq!(tool_category("gobuster"), "Directory Enumeration");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_ssl_check() {
|
||||
assert_eq!(tool_category("ssl_check"), "SSL/TLS Analysis");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_tls_scan() {
|
||||
assert_eq!(tool_category("tls_scan"), "SSL/TLS Analysis");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_api_test() {
|
||||
assert_eq!(tool_category("api_endpoint_test"), "API Security Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_auth_bypass() {
|
||||
assert_eq!(tool_category("auth_bypass_check"), "Authentication Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_cors() {
|
||||
assert_eq!(tool_category("cors_check"), "CORS Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_csrf() {
|
||||
assert_eq!(tool_category("csrf_scanner"), "CSRF Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_nuclei() {
|
||||
assert_eq!(tool_category("nuclei"), "Vulnerability Scanning");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_whatweb() {
|
||||
assert_eq!(tool_category("whatweb"), "Technology Fingerprinting");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_unknown_defaults_to_security_testing() {
|
||||
assert_eq!(tool_category("custom_tool"), "Security Testing");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_category_is_case_insensitive() {
|
||||
assert_eq!(tool_category("NMAP_Scanner"), "Network Reconnaissance");
|
||||
assert_eq!(tool_category("SQLMap"), "SQL Injection Testing");
|
||||
}
|
||||
|
||||
// ── build_html_report ────────────────────────────────────────────
|
||||
|
||||
fn make_session(strategy: PentestStrategy) -> PentestSession {
|
||||
let mut s = PentestSession::new("target-1".into(), strategy);
|
||||
s.tool_invocations = 5;
|
||||
s.tool_successes = 4;
|
||||
s.findings_count = 2;
|
||||
s.exploitable_count = 1;
|
||||
s
|
||||
}
|
||||
|
||||
fn make_finding(severity: Severity, title: &str, exploitable: bool) -> DastFinding {
|
||||
let mut f = DastFinding::new(
|
||||
"run-1".into(),
|
||||
"target-1".into(),
|
||||
DastVulnType::Xss,
|
||||
title.into(),
|
||||
"description".into(),
|
||||
severity,
|
||||
"https://example.com/test".into(),
|
||||
"GET".into(),
|
||||
);
|
||||
f.exploitable = exploitable;
|
||||
f
|
||||
}
|
||||
|
||||
fn make_attack_node(tool_name: &str) -> AttackChainNode {
|
||||
let mut node = AttackChainNode::new(
|
||||
"session-1".into(),
|
||||
"node-1".into(),
|
||||
tool_name.into(),
|
||||
serde_json::json!({}),
|
||||
"Testing this tool".into(),
|
||||
);
|
||||
node.status = AttackNodeStatus::Completed;
|
||||
node
|
||||
}
|
||||
|
||||
fn make_report_context(
|
||||
findings: Vec<DastFinding>,
|
||||
chain: Vec<AttackChainNode>,
|
||||
) -> ReportContext {
|
||||
ReportContext {
|
||||
session: make_session(PentestStrategy::Comprehensive),
|
||||
target_name: "Test App".into(),
|
||||
target_url: "https://example.com".into(),
|
||||
findings,
|
||||
attack_chain: chain,
|
||||
requester_name: "Alice".into(),
|
||||
requester_email: "alice@example.com".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_contains_target_info() {
|
||||
let ctx = make_report_context(vec![], vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("Test App"));
|
||||
assert!(html.contains("https://example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_contains_requester_info() {
|
||||
let ctx = make_report_context(vec![], vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("Alice"));
|
||||
assert!(html.contains("alice@example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_shows_informational_risk_when_no_findings() {
|
||||
let ctx = make_report_context(vec![], vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("INFORMATIONAL"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_shows_critical_risk_with_critical_finding() {
|
||||
let findings = vec![make_finding(Severity::Critical, "Critical XSS", true)];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("CRITICAL"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_shows_high_risk_without_critical() {
|
||||
let findings = vec![make_finding(Severity::High, "High SQLi", false)];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
// Should show HIGH, not CRITICAL
|
||||
assert!(html.contains("HIGH"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_shows_medium_risk_level() {
|
||||
let findings = vec![make_finding(Severity::Medium, "Medium Issue", false)];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("MEDIUM"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_includes_finding_title() {
|
||||
let findings = vec![make_finding(
|
||||
Severity::High,
|
||||
"Reflected XSS in /search",
|
||||
true,
|
||||
)];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("Reflected XSS in /search"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_shows_exploitable_badge() {
|
||||
let findings = vec![make_finding(Severity::Critical, "SQLi", true)];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
// The report should mark exploitable findings
|
||||
assert!(html.contains("EXPLOITABLE"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_includes_attack_chain_tool_names() {
|
||||
let chain = vec![make_attack_node("nmap_scan"), make_attack_node("sqlmap")];
|
||||
let ctx = make_report_context(vec![], chain);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("nmap_scan"));
|
||||
assert!(html.contains("sqlmap"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_is_valid_html_structure() {
|
||||
let ctx = make_report_context(vec![], vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
assert!(html.contains("<!DOCTYPE html>") || html.contains("<html"));
|
||||
assert!(html.contains("</html>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_strategy_appears() {
|
||||
let ctx = make_report_context(vec![], vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
// PentestStrategy::Comprehensive => "comprehensive"
|
||||
assert!(html.contains("comprehensive") || html.contains("Comprehensive"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_finding_count_is_correct() {
|
||||
let findings = vec![
|
||||
make_finding(Severity::Critical, "F1", true),
|
||||
make_finding(Severity::High, "F2", false),
|
||||
make_finding(Severity::Low, "F3", false),
|
||||
];
|
||||
let ctx = make_report_context(findings, vec![]);
|
||||
let html = build_html_report(&ctx);
|
||||
// The total count "3" should appear somewhere
|
||||
assert!(
|
||||
html.contains(">3<")
|
||||
|| html.contains(">3 ")
|
||||
|| html.contains("3 findings")
|
||||
|| html.contains("3 Total")
|
||||
);
|
||||
}
|
||||
}
|
||||
58
compliance-agent/src/pentest/report/mod.rs
Normal file
58
compliance-agent/src/pentest/report/mod.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
mod archive;
|
||||
mod html;
|
||||
mod pdf;
|
||||
|
||||
use compliance_core::models::dast::DastFinding;
|
||||
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
/// Report archive with metadata
|
||||
pub struct ReportArchive {
|
||||
/// The password-protected ZIP bytes
|
||||
pub archive: Vec<u8>,
|
||||
/// SHA-256 hex digest of the archive
|
||||
pub sha256: String,
|
||||
}
|
||||
|
||||
/// Report context gathered from the database
|
||||
pub struct ReportContext {
|
||||
pub session: PentestSession,
|
||||
pub target_name: String,
|
||||
pub target_url: String,
|
||||
pub findings: Vec<DastFinding>,
|
||||
pub attack_chain: Vec<AttackChainNode>,
|
||||
pub requester_name: String,
|
||||
pub requester_email: String,
|
||||
}
|
||||
|
||||
/// Generate a password-protected ZIP archive containing the pentest report.
|
||||
///
|
||||
/// The archive contains:
|
||||
/// - `report.pdf` — Professional pentest report (PDF)
|
||||
/// - `report.html` — HTML source (fallback)
|
||||
/// - `findings.json` — Raw findings data
|
||||
/// - `attack-chain.json` — Attack chain timeline
|
||||
///
|
||||
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
|
||||
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
|
||||
pub async fn generate_encrypted_report(
|
||||
ctx: &ReportContext,
|
||||
password: &str,
|
||||
) -> Result<ReportArchive, String> {
|
||||
let html = html::build_html_report(ctx);
|
||||
|
||||
// Convert HTML to PDF via headless Chrome
|
||||
let pdf_bytes = pdf::html_to_pdf(&html).await?;
|
||||
|
||||
let zip_bytes = archive::build_zip(ctx, password, &html, &pdf_bytes)
|
||||
.map_err(|e| format!("Failed to create archive: {e}"))?;
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(&zip_bytes);
|
||||
let sha256 = hex::encode(hasher.finalize());
|
||||
|
||||
Ok(ReportArchive {
|
||||
archive: zip_bytes,
|
||||
sha256,
|
||||
})
|
||||
}
|
||||
79
compliance-agent/src/pentest/report/pdf.rs
Normal file
79
compliance-agent/src/pentest/report/pdf.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
|
||||
pub(super) async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
|
||||
let tmp_dir = std::env::temp_dir();
|
||||
let run_id = uuid::Uuid::new_v4().to_string();
|
||||
let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
|
||||
let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
|
||||
|
||||
// Write HTML to temp file
|
||||
std::fs::write(&html_path, html).map_err(|e| format!("Failed to write temp HTML: {e}"))?;
|
||||
|
||||
// Find Chrome/Chromium binary
|
||||
let chrome_bin = find_chrome_binary().ok_or_else(|| {
|
||||
"Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports."
|
||||
.to_string()
|
||||
})?;
|
||||
|
||||
tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
|
||||
|
||||
let html_url = format!("file://{}", html_path.display());
|
||||
|
||||
let output = tokio::process::Command::new(&chrome_bin)
|
||||
.args([
|
||||
"--headless",
|
||||
"--disable-gpu",
|
||||
"--no-sandbox",
|
||||
"--disable-software-rasterizer",
|
||||
"--run-all-compositor-stages-before-draw",
|
||||
"--disable-dev-shm-usage",
|
||||
&format!("--print-to-pdf={}", pdf_path.display()),
|
||||
"--no-pdf-header-footer",
|
||||
&html_url,
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to run Chrome: {e}"))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
// Clean up temp files
|
||||
let _ = std::fs::remove_file(&html_path);
|
||||
let _ = std::fs::remove_file(&pdf_path);
|
||||
return Err(format!("Chrome PDF generation failed: {stderr}"));
|
||||
}
|
||||
|
||||
let pdf_bytes =
|
||||
std::fs::read(&pdf_path).map_err(|e| format!("Failed to read generated PDF: {e}"))?;
|
||||
|
||||
// Clean up temp files
|
||||
let _ = std::fs::remove_file(&html_path);
|
||||
let _ = std::fs::remove_file(&pdf_path);
|
||||
|
||||
if pdf_bytes.is_empty() {
|
||||
return Err("Chrome produced an empty PDF".to_string());
|
||||
}
|
||||
|
||||
tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
|
||||
Ok(pdf_bytes)
|
||||
}
|
||||
|
||||
/// Search for Chrome/Chromium binary on the system.
|
||||
fn find_chrome_binary() -> Option<String> {
|
||||
let candidates = [
|
||||
"google-chrome-stable",
|
||||
"google-chrome",
|
||||
"chromium-browser",
|
||||
"chromium",
|
||||
];
|
||||
for name in &candidates {
|
||||
if let Ok(output) = std::process::Command::new("which").arg(name).output() {
|
||||
if output.status.success() {
|
||||
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !path.is_empty() {
|
||||
return Some(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
@@ -8,3 +8,51 @@ pub fn compute_fingerprint(parts: &[&str]) -> String {
|
||||
}
|
||||
hex::encode(hasher.finalize())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn fingerprint_is_deterministic() {
|
||||
let a = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
|
||||
let b = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_changes_with_different_input() {
|
||||
let a = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
|
||||
let b = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "43"]);
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_is_valid_hex_sha256() {
|
||||
let fp = compute_fingerprint(&["hello"]);
|
||||
assert_eq!(fp.len(), 64, "SHA-256 hex should be 64 chars");
|
||||
assert!(fp.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_empty_parts() {
|
||||
let fp = compute_fingerprint(&[]);
|
||||
// Should still produce a valid hash (of empty input)
|
||||
assert_eq!(fp.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_order_matters() {
|
||||
let a = compute_fingerprint(&["a", "b"]);
|
||||
let b = compute_fingerprint(&["b", "a"]);
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fingerprint_separator_prevents_collision() {
|
||||
// "ab" + "c" vs "a" + "bc" should differ because of the "|" separator
|
||||
let a = compute_fingerprint(&["ab", "c"]);
|
||||
let b = compute_fingerprint(&["a", "bc"]);
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,3 +129,110 @@ struct GitleaksResult {
|
||||
#[serde(rename = "Match")]
|
||||
r#match: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- is_allowlisted tests ---
|
||||
|
||||
#[test]
|
||||
fn allowlisted_env_example_files() {
|
||||
assert!(is_allowlisted(".env.example"));
|
||||
assert!(is_allowlisted("config/.env.sample"));
|
||||
assert!(is_allowlisted("deploy/.ENV.TEMPLATE"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlisted_test_directories() {
|
||||
assert!(is_allowlisted("src/test/config.json"));
|
||||
assert!(is_allowlisted("src/tests/fixtures.rs"));
|
||||
assert!(is_allowlisted("data/fixtures/secret.txt"));
|
||||
assert!(is_allowlisted("pkg/testdata/key.pem"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlisted_mock_files() {
|
||||
assert!(is_allowlisted("src/mock_service.py"));
|
||||
assert!(is_allowlisted("lib/MockAuth.java"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlisted_test_suffixes() {
|
||||
assert!(is_allowlisted("auth_test.go"));
|
||||
assert!(is_allowlisted("auth.test.ts"));
|
||||
assert!(is_allowlisted("auth.test.js"));
|
||||
assert!(is_allowlisted("auth.spec.ts"));
|
||||
assert!(is_allowlisted("auth.spec.js"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_allowlisted_regular_files() {
|
||||
assert!(!is_allowlisted("src/main.rs"));
|
||||
assert!(!is_allowlisted("config/.env"));
|
||||
assert!(!is_allowlisted("lib/auth.ts"));
|
||||
assert!(!is_allowlisted("deploy/secrets.yaml"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_allowlisted_partial_matches() {
|
||||
// "test" as substring in a non-directory context should not match
|
||||
assert!(!is_allowlisted("src/attestation.rs"));
|
||||
assert!(!is_allowlisted("src/contest/data.json"));
|
||||
}
|
||||
|
||||
// --- GitleaksResult deserialization tests ---
|
||||
|
||||
#[test]
|
||||
fn deserialize_gitleaks_result() {
|
||||
let json = r#"{
|
||||
"Description": "AWS Access Key",
|
||||
"RuleID": "aws-access-key",
|
||||
"File": "src/config.rs",
|
||||
"StartLine": 10,
|
||||
"Match": "AKIAIOSFODNN7EXAMPLE"
|
||||
}"#;
|
||||
let result: GitleaksResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.description, "AWS Access Key");
|
||||
assert_eq!(result.rule_id, "aws-access-key");
|
||||
assert_eq!(result.file, "src/config.rs");
|
||||
assert_eq!(result.start_line, 10);
|
||||
assert_eq!(result.r#match, "AKIAIOSFODNN7EXAMPLE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_gitleaks_result_array() {
|
||||
let json = r#"[
|
||||
{
|
||||
"Description": "Generic Secret",
|
||||
"RuleID": "generic-secret",
|
||||
"File": "app.py",
|
||||
"StartLine": 5,
|
||||
"Match": "password=hunter2"
|
||||
}
|
||||
]"#;
|
||||
let results: Vec<GitleaksResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].rule_id, "generic-secret");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_mapping_private_key() {
|
||||
// Verify the severity logic from the scan method
|
||||
let rule_id = "some-private-key-rule";
|
||||
assert!(rule_id.contains("private-key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_mapping_token_password_secret() {
|
||||
for keyword in &["token", "password", "secret"] {
|
||||
let rule_id = format!("some-{}-rule", keyword);
|
||||
assert!(
|
||||
rule_id.contains("token")
|
||||
|| rule_id.contains("password")
|
||||
|| rule_id.contains("secret"),
|
||||
"Expected '{rule_id}' to match token/password/secret"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
106
compliance-agent/src/pipeline/graph_build.rs
Normal file
106
compliance-agent/src/pipeline/graph_build.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use compliance_core::models::Finding;
|
||||
|
||||
use super::orchestrator::{GraphContext, PipelineOrchestrator};
|
||||
use crate::error::AgentError;
|
||||
|
||||
impl PipelineOrchestrator {
|
||||
/// Build the code knowledge graph for a repo and compute impact analyses
|
||||
pub(super) async fn build_code_graph(
|
||||
&self,
|
||||
repo_path: &std::path::Path,
|
||||
repo_id: &str,
|
||||
findings: &[Finding],
|
||||
) -> Result<GraphContext, AgentError> {
|
||||
let graph_build_id = uuid::Uuid::new_v4().to_string();
|
||||
let engine = compliance_graph::GraphEngine::new(50_000);
|
||||
|
||||
let (mut code_graph, build_run) =
|
||||
engine
|
||||
.build_graph(repo_path, repo_id, &graph_build_id)
|
||||
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
|
||||
|
||||
// Apply community detection
|
||||
compliance_graph::graph::community::apply_communities(&mut code_graph);
|
||||
|
||||
// Store graph in MongoDB
|
||||
let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
|
||||
store
|
||||
.delete_repo_graph(repo_id)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
|
||||
store
|
||||
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
|
||||
|
||||
// Compute impact analysis for each finding
|
||||
let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
|
||||
let mut impacts = Vec::new();
|
||||
|
||||
for finding in findings {
|
||||
if let Some(file_path) = &finding.file_path {
|
||||
let impact = analyzer.analyze(
|
||||
repo_id,
|
||||
&finding.fingerprint,
|
||||
&graph_build_id,
|
||||
file_path,
|
||||
finding.line_number,
|
||||
);
|
||||
store
|
||||
.store_impact(&impact)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
|
||||
impacts.push(impact);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GraphContext {
|
||||
node_count: build_run.node_count,
|
||||
edge_count: build_run.edge_count,
|
||||
community_count: build_run.community_count,
|
||||
impacts,
|
||||
})
|
||||
}
|
||||
|
||||
/// Trigger DAST scan if a target is configured for this repo
|
||||
pub(super) async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
|
||||
use futures_util::TryStreamExt;
|
||||
|
||||
let filter = mongodb::bson::doc! { "repo_id": repo_id };
|
||||
let targets: Vec<compliance_core::models::DastTarget> =
|
||||
match self.db.dast_targets().find(filter).await {
|
||||
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if targets.is_empty() {
|
||||
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
for target in targets {
|
||||
let db = self.db.clone();
|
||||
let scan_run_id = scan_run_id.to_string();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = compliance_dast::DastOrchestrator::new(100);
|
||||
match orchestrator.run_scan(&target, Vec::new()).await {
|
||||
Ok((mut scan_run, findings)) => {
|
||||
scan_run.sast_scan_run_id = Some(scan_run_id);
|
||||
if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
|
||||
tracing::error!("Failed to store DAST scan run: {e}");
|
||||
}
|
||||
for finding in &findings {
|
||||
if let Err(e) = db.dast_findings().insert_one(finding).await {
|
||||
tracing::error!("Failed to store DAST finding: {e}");
|
||||
}
|
||||
}
|
||||
tracing::info!("DAST scan complete: {} findings", findings.len());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("DAST scan failed: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
259
compliance-agent/src/pipeline/issue_creation.rs
Normal file
259
compliance-agent/src/pipeline/issue_creation.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use mongodb::bson::doc;
|
||||
|
||||
use compliance_core::models::*;
|
||||
|
||||
use super::orchestrator::{extract_base_url, PipelineOrchestrator};
|
||||
use super::tracker_dispatch::TrackerDispatch;
|
||||
use crate::error::AgentError;
|
||||
use crate::trackers;
|
||||
|
||||
impl PipelineOrchestrator {
|
||||
/// Build an issue tracker client from a repository's tracker configuration.
|
||||
/// Returns `None` if the repo has no tracker configured.
|
||||
pub(super) fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
|
||||
let tracker_type = repo.tracker_type.as_ref()?;
|
||||
// Per-repo token takes precedence, fall back to global config
|
||||
match tracker_type {
|
||||
TrackerType::GitHub => {
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.github_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
match trackers::github::GitHubTracker::new(&secret) {
|
||||
Ok(t) => Some(TrackerDispatch::GitHub(t)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to build GitHub tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
TrackerType::GitLab => {
|
||||
let base_url = self
|
||||
.config
|
||||
.gitlab_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://gitlab.com".to_string());
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.gitlab_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::GitLab(
|
||||
trackers::gitlab::GitLabTracker::new(base_url, secret),
|
||||
))
|
||||
}
|
||||
TrackerType::Gitea => {
|
||||
let token = repo.tracker_token.clone()?;
|
||||
let base_url = extract_base_url(&repo.git_url)?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
|
||||
base_url, secret,
|
||||
)))
|
||||
}
|
||||
TrackerType::Jira => {
|
||||
let base_url = self.config.jira_url.clone()?;
|
||||
let email = self.config.jira_email.clone()?;
|
||||
let project_key = self.config.jira_project_key.clone()?;
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.jira_api_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
|
||||
base_url,
|
||||
email,
|
||||
secret,
|
||||
project_key,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tracker issues for new findings (severity >= Medium).
|
||||
/// Checks for duplicates via fingerprint search before creating.
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub(super) async fn create_tracker_issues(
|
||||
&self,
|
||||
repo: &TrackedRepository,
|
||||
repo_id: &str,
|
||||
new_findings: &[Finding],
|
||||
) -> Result<(), AgentError> {
|
||||
let tracker = match self.build_tracker(repo) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::info!("[{repo_id}] No issue tracker configured, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let owner = match repo.tracker_owner.as_deref() {
|
||||
Some(o) => o,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let tracker_repo_name = match repo.tracker_repo.as_deref() {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Only create issues for medium+ severity findings
|
||||
let actionable: Vec<&Finding> = new_findings
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
matches!(
|
||||
f.severity,
|
||||
Severity::Medium | Severity::High | Severity::Critical
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if actionable.is_empty() {
|
||||
tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[{repo_id}] Creating issues for {} findings via {}",
|
||||
actionable.len(),
|
||||
tracker.name()
|
||||
);
|
||||
|
||||
let mut created = 0u32;
|
||||
for finding in actionable {
|
||||
let title = format!(
|
||||
"[{}] {}: {}",
|
||||
finding.severity, finding.scanner, finding.title
|
||||
);
|
||||
|
||||
// Check if an issue already exists by fingerprint first, then by title
|
||||
let mut found_existing = false;
|
||||
for search_term in [&finding.fingerprint, &title] {
|
||||
match tracker
|
||||
.find_existing_issue(owner, tracker_repo_name, search_term)
|
||||
.await
|
||||
{
|
||||
Ok(Some(existing)) => {
|
||||
tracing::debug!(
|
||||
"[{repo_id}] Issue already exists for '{}': {}",
|
||||
search_term,
|
||||
existing.external_url
|
||||
);
|
||||
found_existing = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if found_existing {
|
||||
continue;
|
||||
}
|
||||
let body = format_issue_body(finding);
|
||||
let labels = vec![
|
||||
format!("severity:{}", finding.severity),
|
||||
format!("scanner:{}", finding.scanner),
|
||||
"compliance-scanner".to_string(),
|
||||
];
|
||||
|
||||
match tracker
|
||||
.create_issue(owner, tracker_repo_name, &title, &body, &labels)
|
||||
.await
|
||||
{
|
||||
Ok(mut issue) => {
|
||||
issue.finding_id = finding
|
||||
.id
|
||||
.as_ref()
|
||||
.map(|id| id.to_hex())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Update the finding with the issue URL
|
||||
if let Some(finding_id) = &finding.id {
|
||||
let _ = self
|
||||
.db
|
||||
.findings()
|
||||
.update_one(
|
||||
doc! { "_id": finding_id },
|
||||
doc! { "$set": { "tracker_issue_url": &issue.external_url } },
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Store the tracker issue record
|
||||
if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
|
||||
tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
|
||||
}
|
||||
|
||||
created += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"[{repo_id}] Failed to create issue for {}: {e}",
|
||||
finding.fingerprint
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("[{repo_id}] Created {created} tracker issues");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a finding into a markdown issue body for the tracker.
|
||||
pub(super) fn format_issue_body(finding: &Finding) -> String {
|
||||
let mut body = String::new();
|
||||
|
||||
body.push_str(&format!("## {} Finding\n\n", finding.severity));
|
||||
body.push_str(&format!("**Scanner:** {}\n", finding.scanner));
|
||||
body.push_str(&format!("**Severity:** {}\n", finding.severity));
|
||||
|
||||
if let Some(rule) = &finding.rule_id {
|
||||
body.push_str(&format!("**Rule:** {}\n", rule));
|
||||
}
|
||||
if let Some(cwe) = &finding.cwe {
|
||||
body.push_str(&format!("**CWE:** {}\n", cwe));
|
||||
}
|
||||
|
||||
body.push_str(&format!("\n### Description\n\n{}\n", finding.description));
|
||||
|
||||
if let Some(file_path) = &finding.file_path {
|
||||
body.push_str(&format!("\n### Location\n\n**File:** `{}`", file_path));
|
||||
if let Some(line) = finding.line_number {
|
||||
body.push_str(&format!(" (line {})", line));
|
||||
}
|
||||
body.push('\n');
|
||||
}
|
||||
|
||||
if let Some(snippet) = &finding.code_snippet {
|
||||
body.push_str(&format!("\n### Code\n\n```\n{}\n```\n", snippet));
|
||||
}
|
||||
|
||||
if let Some(remediation) = &finding.remediation {
|
||||
body.push_str(&format!("\n### Remediation\n\n{}\n", remediation));
|
||||
}
|
||||
|
||||
if let Some(fix) = &finding.suggested_fix {
|
||||
body.push_str(&format!("\n### Suggested Fix\n\n```\n{}\n```\n", fix));
|
||||
}
|
||||
|
||||
body.push_str(&format!(
|
||||
"\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
|
||||
finding.fingerprint
|
||||
));
|
||||
|
||||
body
|
||||
}
|
||||
@@ -1,366 +0,0 @@
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use compliance_core::models::{Finding, ScanType, Severity};
|
||||
use compliance_core::traits::{ScanOutput, Scanner};
|
||||
use compliance_core::CoreError;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::pipeline::dedup;
|
||||
|
||||
/// Timeout for each individual lint command
|
||||
const LINT_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
pub struct LintScanner;
|
||||
|
||||
impl Scanner for LintScanner {
|
||||
fn name(&self) -> &str {
|
||||
"lint"
|
||||
}
|
||||
|
||||
fn scan_type(&self) -> ScanType {
|
||||
ScanType::Lint
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
|
||||
let mut all_findings = Vec::new();
|
||||
|
||||
// Detect which languages are present and run appropriate linters
|
||||
if has_rust_project(repo_path) {
|
||||
match run_clippy(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("Clippy failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
if has_js_project(repo_path) {
|
||||
match run_eslint(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("ESLint failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
if has_python_project(repo_path) {
|
||||
match run_ruff(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("Ruff failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ScanOutput {
|
||||
findings: all_findings,
|
||||
sbom_entries: Vec::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn has_rust_project(repo_path: &Path) -> bool {
|
||||
repo_path.join("Cargo.toml").exists()
|
||||
}
|
||||
|
||||
fn has_js_project(repo_path: &Path) -> bool {
|
||||
// Only run if eslint is actually installed in the project
|
||||
repo_path.join("package.json").exists() && repo_path.join("node_modules/.bin/eslint").exists()
|
||||
}
|
||||
|
||||
fn has_python_project(repo_path: &Path) -> bool {
|
||||
repo_path.join("pyproject.toml").exists()
|
||||
|| repo_path.join("setup.py").exists()
|
||||
|| repo_path.join("requirements.txt").exists()
|
||||
}
|
||||
|
||||
/// Run a command with a timeout, returning its output or an error
|
||||
async fn run_with_timeout(
|
||||
child: tokio::process::Child,
|
||||
scanner_name: &str,
|
||||
) -> Result<std::process::Output, CoreError> {
|
||||
let result = tokio::time::timeout(LINT_TIMEOUT, child.wait_with_output()).await;
|
||||
match result {
|
||||
Ok(Ok(output)) => Ok(output),
|
||||
Ok(Err(e)) => Err(CoreError::Scanner {
|
||||
scanner: scanner_name.to_string(),
|
||||
source: Box::new(e),
|
||||
}),
|
||||
Err(_) => {
|
||||
// Process is dropped here which sends SIGKILL on Unix
|
||||
Err(CoreError::Scanner {
|
||||
scanner: scanner_name.to_string(),
|
||||
source: Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::TimedOut,
|
||||
format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Clippy ──────────────────────────────────────────────
|
||||
|
||||
async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
let child = Command::new("cargo")
|
||||
.args([
|
||||
"clippy",
|
||||
"--message-format=json",
|
||||
"--quiet",
|
||||
"--",
|
||||
"-W",
|
||||
"clippy::all",
|
||||
])
|
||||
.current_dir(repo_path)
|
||||
.env("RUSTC_WRAPPER", "")
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "clippy".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "clippy").await?;
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut findings = Vec::new();
|
||||
|
||||
for line in stdout.lines() {
|
||||
let msg: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match msg.get("message") {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if level != "warning" && level != "error" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = message
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let code = message
|
||||
.get("code")
|
||||
.and_then(|v| v.get("code"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
if text.starts_with("aborting due to") || code.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (file_path, line_number) = extract_primary_span(message);
|
||||
|
||||
let severity = if level == "error" {
|
||||
Severity::High
|
||||
} else {
|
||||
Severity::Low
|
||||
};
|
||||
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"clippy",
|
||||
&code,
|
||||
&file_path,
|
||||
&line_number.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"clippy".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[clippy] {text}"),
|
||||
text,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(code);
|
||||
if !file_path.is_empty() {
|
||||
finding.file_path = Some(file_path);
|
||||
}
|
||||
if line_number > 0 {
|
||||
finding.line_number = Some(line_number);
|
||||
}
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
|
||||
let spans = match message.get("spans").and_then(|v| v.as_array()) {
|
||||
Some(s) => s,
|
||||
None => return (String::new(), 0),
|
||||
};
|
||||
|
||||
for span in spans {
|
||||
if span.get("is_primary").and_then(|v| v.as_bool()) == Some(true) {
|
||||
let file = span
|
||||
.get("file_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
return (file, line);
|
||||
}
|
||||
}
|
||||
|
||||
(String::new(), 0)
|
||||
}
|
||||
|
||||
// ── ESLint ──────────────────────────────────────────────
|
||||
|
||||
async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
// Use the project-local eslint binary directly, not npx (which can hang downloading)
|
||||
let eslint_bin = repo_path.join("node_modules/.bin/eslint");
|
||||
let child = Command::new(eslint_bin)
|
||||
.args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
|
||||
.current_dir(repo_path)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "eslint".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "eslint").await?;
|
||||
|
||||
if output.stdout.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let results: Vec<EslintFileResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
for file_result in results {
|
||||
for msg in file_result.messages {
|
||||
let severity = match msg.severity {
|
||||
2 => Severity::Medium,
|
||||
_ => Severity::Low,
|
||||
};
|
||||
|
||||
let rule_id = msg.rule_id.unwrap_or_default();
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"eslint",
|
||||
&rule_id,
|
||||
&file_result.file_path,
|
||||
&msg.line.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"eslint".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[eslint] {}", msg.message),
|
||||
msg.message,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(rule_id);
|
||||
finding.file_path = Some(file_result.file_path.clone());
|
||||
finding.line_number = Some(msg.line);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct EslintFileResult {
|
||||
#[serde(rename = "filePath")]
|
||||
file_path: String,
|
||||
messages: Vec<EslintMessage>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct EslintMessage {
|
||||
#[serde(rename = "ruleId")]
|
||||
rule_id: Option<String>,
|
||||
severity: u8,
|
||||
message: String,
|
||||
line: u32,
|
||||
}
|
||||
|
||||
// ── Ruff ────────────────────────────────────────────────
|
||||
|
||||
async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
let child = Command::new("ruff")
|
||||
.args(["check", ".", "--output-format", "json", "--exit-zero"])
|
||||
.current_dir(repo_path)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "ruff".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "ruff").await?;
|
||||
|
||||
if output.stdout.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
|
||||
|
||||
let findings = results
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
|
||||
Severity::Medium
|
||||
} else {
|
||||
Severity::Low
|
||||
};
|
||||
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"ruff",
|
||||
&r.code,
|
||||
&r.filename,
|
||||
&r.location.row.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"ruff".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[ruff] {}: {}", r.code, r.message),
|
||||
r.message,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(r.code);
|
||||
finding.file_path = Some(r.filename);
|
||||
finding.line_number = Some(r.location.row);
|
||||
finding
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RuffResult {
|
||||
code: String,
|
||||
message: String,
|
||||
filename: String,
|
||||
location: RuffLocation,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RuffLocation {
|
||||
row: u32,
|
||||
}
|
||||
251
compliance-agent/src/pipeline/lint/clippy.rs
Normal file
251
compliance-agent/src/pipeline/lint/clippy.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::{Finding, ScanType, Severity};
|
||||
use compliance_core::CoreError;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::pipeline::dedup;
|
||||
|
||||
use super::run_with_timeout;
|
||||
|
||||
pub(super) async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
let child = Command::new("cargo")
|
||||
.args([
|
||||
"clippy",
|
||||
"--message-format=json",
|
||||
"--quiet",
|
||||
"--",
|
||||
"-W",
|
||||
"clippy::all",
|
||||
])
|
||||
.current_dir(repo_path)
|
||||
.env("RUSTC_WRAPPER", "")
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "clippy".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "clippy").await?;
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut findings = Vec::new();
|
||||
|
||||
for line in stdout.lines() {
|
||||
let msg: serde_json::Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = match msg.get("message") {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
if level != "warning" && level != "error" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = message
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let code = message
|
||||
.get("code")
|
||||
.and_then(|v| v.get("code"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
if text.starts_with("aborting due to") || code.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (file_path, line_number) = extract_primary_span(message);
|
||||
|
||||
let severity = if level == "error" {
|
||||
Severity::High
|
||||
} else {
|
||||
Severity::Low
|
||||
};
|
||||
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"clippy",
|
||||
&code,
|
||||
&file_path,
|
||||
&line_number.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"clippy".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[clippy] {text}"),
|
||||
text,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(code);
|
||||
if !file_path.is_empty() {
|
||||
finding.file_path = Some(file_path);
|
||||
}
|
||||
if line_number > 0 {
|
||||
finding.line_number = Some(line_number);
|
||||
}
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
|
||||
let spans = match message.get("spans").and_then(|v| v.as_array()) {
|
||||
Some(s) => s,
|
||||
None => return (String::new(), 0),
|
||||
};
|
||||
|
||||
for span in spans {
|
||||
if span.get("is_primary").and_then(|v| v.as_bool()) == Some(true) {
|
||||
let file = span
|
||||
.get("file_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
return (file, line);
|
||||
}
|
||||
}
|
||||
|
||||
(String::new(), 0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extract_primary_span_with_primary() {
|
||||
let msg = serde_json::json!({
|
||||
"spans": [
|
||||
{
|
||||
"file_name": "src/lib.rs",
|
||||
"line_start": 42,
|
||||
"is_primary": true
|
||||
}
|
||||
]
|
||||
});
|
||||
let (file, line) = extract_primary_span(&msg);
|
||||
assert_eq!(file, "src/lib.rs");
|
||||
assert_eq!(line, 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_primary_span_no_primary() {
|
||||
let msg = serde_json::json!({
|
||||
"spans": [
|
||||
{
|
||||
"file_name": "src/lib.rs",
|
||||
"line_start": 42,
|
||||
"is_primary": false
|
||||
}
|
||||
]
|
||||
});
|
||||
let (file, line) = extract_primary_span(&msg);
|
||||
assert_eq!(file, "");
|
||||
assert_eq!(line, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_primary_span_multiple_spans() {
|
||||
let msg = serde_json::json!({
|
||||
"spans": [
|
||||
{
|
||||
"file_name": "src/other.rs",
|
||||
"line_start": 10,
|
||||
"is_primary": false
|
||||
},
|
||||
{
|
||||
"file_name": "src/main.rs",
|
||||
"line_start": 99,
|
||||
"is_primary": true
|
||||
}
|
||||
]
|
||||
});
|
||||
let (file, line) = extract_primary_span(&msg);
|
||||
assert_eq!(file, "src/main.rs");
|
||||
assert_eq!(line, 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_primary_span_no_spans() {
|
||||
let msg = serde_json::json!({});
|
||||
let (file, line) = extract_primary_span(&msg);
|
||||
assert_eq!(file, "");
|
||||
assert_eq!(line, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_primary_span_empty_spans() {
|
||||
let msg = serde_json::json!({ "spans": [] });
|
||||
let (file, line) = extract_primary_span(&msg);
|
||||
assert_eq!(file, "");
|
||||
assert_eq!(line, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_clippy_compiler_message_line() {
|
||||
let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable","code":{"code":"unused_variables"},"spans":[{"file_name":"src/main.rs","line_start":5,"is_primary":true}]}}"#;
|
||||
let msg: serde_json::Value = serde_json::from_str(line).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
msg.get("reason").and_then(|v| v.as_str()),
|
||||
Some("compiler-message")
|
||||
);
|
||||
let message = msg.get("message").unwrap();
|
||||
assert_eq!(
|
||||
message.get("level").and_then(|v| v.as_str()),
|
||||
Some("warning")
|
||||
);
|
||||
assert_eq!(
|
||||
message.get("message").and_then(|v| v.as_str()),
|
||||
Some("unused variable")
|
||||
);
|
||||
assert_eq!(
|
||||
message
|
||||
.get("code")
|
||||
.and_then(|v| v.get("code"))
|
||||
.and_then(|v| v.as_str()),
|
||||
Some("unused_variables")
|
||||
);
|
||||
|
||||
let (file, line_num) = extract_primary_span(message);
|
||||
assert_eq!(file, "src/main.rs");
|
||||
assert_eq!(line_num, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skip_non_compiler_message() {
|
||||
let line = r#"{"reason":"build-script-executed","package_id":"foo 0.1.0"}"#;
|
||||
let msg: serde_json::Value = serde_json::from_str(line).unwrap();
|
||||
assert_ne!(
|
||||
msg.get("reason").and_then(|v| v.as_str()),
|
||||
Some("compiler-message")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skip_aborting_message() {
|
||||
let text = "aborting due to 3 previous errors";
|
||||
assert!(text.starts_with("aborting due to"));
|
||||
}
|
||||
}
|
||||
183
compliance-agent/src/pipeline/lint/eslint.rs
Normal file
183
compliance-agent/src/pipeline/lint/eslint.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::{Finding, ScanType, Severity};
|
||||
use compliance_core::CoreError;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::pipeline::dedup;
|
||||
|
||||
use super::run_with_timeout;
|
||||
|
||||
pub(super) async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
// Use the project-local eslint binary directly, not npx (which can hang downloading)
|
||||
let eslint_bin = repo_path.join("node_modules/.bin/eslint");
|
||||
let child = Command::new(eslint_bin)
|
||||
.args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
|
||||
.current_dir(repo_path)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "eslint".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "eslint").await?;
|
||||
|
||||
if output.stdout.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let results: Vec<EslintFileResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
for file_result in results {
|
||||
for msg in file_result.messages {
|
||||
let severity = match msg.severity {
|
||||
2 => Severity::Medium,
|
||||
_ => Severity::Low,
|
||||
};
|
||||
|
||||
let rule_id = msg.rule_id.unwrap_or_default();
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"eslint",
|
||||
&rule_id,
|
||||
&file_result.file_path,
|
||||
&msg.line.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"eslint".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[eslint] {}", msg.message),
|
||||
msg.message,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(rule_id);
|
||||
finding.file_path = Some(file_result.file_path.clone());
|
||||
finding.line_number = Some(msg.line);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct EslintFileResult {
|
||||
#[serde(rename = "filePath")]
|
||||
file_path: String,
|
||||
messages: Vec<EslintMessage>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct EslintMessage {
|
||||
#[serde(rename = "ruleId")]
|
||||
rule_id: Option<String>,
|
||||
severity: u8,
|
||||
message: String,
|
||||
line: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialize_eslint_output() {
|
||||
let json = r#"[
|
||||
{
|
||||
"filePath": "/home/user/project/src/app.js",
|
||||
"messages": [
|
||||
{
|
||||
"ruleId": "no-unused-vars",
|
||||
"severity": 2,
|
||||
"message": "'x' is defined but never used.",
|
||||
"line": 10
|
||||
},
|
||||
{
|
||||
"ruleId": "semi",
|
||||
"severity": 1,
|
||||
"message": "Missing semicolon.",
|
||||
"line": 15
|
||||
}
|
||||
]
|
||||
}
|
||||
]"#;
|
||||
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].file_path, "/home/user/project/src/app.js");
|
||||
assert_eq!(results[0].messages.len(), 2);
|
||||
|
||||
assert_eq!(
|
||||
results[0].messages[0].rule_id,
|
||||
Some("no-unused-vars".to_string())
|
||||
);
|
||||
assert_eq!(results[0].messages[0].severity, 2);
|
||||
assert_eq!(results[0].messages[0].line, 10);
|
||||
|
||||
assert_eq!(results[0].messages[1].severity, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_eslint_null_rule_id() {
|
||||
let json = r#"[
|
||||
{
|
||||
"filePath": "src/index.js",
|
||||
"messages": [
|
||||
{
|
||||
"ruleId": null,
|
||||
"severity": 2,
|
||||
"message": "Parsing error: Unexpected token",
|
||||
"line": 1
|
||||
}
|
||||
]
|
||||
}
|
||||
]"#;
|
||||
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results[0].messages[0].rule_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_eslint_empty_messages() {
|
||||
let json = r#"[{"filePath": "src/clean.js", "messages": []}]"#;
|
||||
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results[0].messages.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_eslint_empty_array() {
|
||||
let json = "[]";
|
||||
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn eslint_severity_mapping() {
|
||||
// severity 2 = error -> Medium, anything else -> Low
|
||||
assert_eq!(
|
||||
match 2u8 {
|
||||
2 => "Medium",
|
||||
_ => "Low",
|
||||
},
|
||||
"Medium"
|
||||
);
|
||||
assert_eq!(
|
||||
match 1u8 {
|
||||
2 => "Medium",
|
||||
_ => "Low",
|
||||
},
|
||||
"Low"
|
||||
);
|
||||
assert_eq!(
|
||||
match 0u8 {
|
||||
2 => "Medium",
|
||||
_ => "Low",
|
||||
},
|
||||
"Low"
|
||||
);
|
||||
}
|
||||
}
|
||||
97
compliance-agent/src/pipeline/lint/mod.rs
Normal file
97
compliance-agent/src/pipeline/lint/mod.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
mod clippy;
|
||||
mod eslint;
|
||||
mod ruff;
|
||||
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
|
||||
use compliance_core::models::ScanType;
|
||||
use compliance_core::traits::{ScanOutput, Scanner};
|
||||
use compliance_core::CoreError;
|
||||
|
||||
/// Timeout for each individual lint command
|
||||
pub(crate) const LINT_TIMEOUT: Duration = Duration::from_secs(120);
|
||||
|
||||
pub struct LintScanner;
|
||||
|
||||
impl Scanner for LintScanner {
|
||||
fn name(&self) -> &str {
|
||||
"lint"
|
||||
}
|
||||
|
||||
fn scan_type(&self) -> ScanType {
|
||||
ScanType::Lint
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
|
||||
let mut all_findings = Vec::new();
|
||||
|
||||
// Detect which languages are present and run appropriate linters
|
||||
if has_rust_project(repo_path) {
|
||||
match clippy::run_clippy(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("Clippy failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
if has_js_project(repo_path) {
|
||||
match eslint::run_eslint(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("ESLint failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
if has_python_project(repo_path) {
|
||||
match ruff::run_ruff(repo_path, repo_id).await {
|
||||
Ok(findings) => all_findings.extend(findings),
|
||||
Err(e) => tracing::warn!("Ruff failed: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ScanOutput {
|
||||
findings: all_findings,
|
||||
sbom_entries: Vec::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn has_rust_project(repo_path: &Path) -> bool {
|
||||
repo_path.join("Cargo.toml").exists()
|
||||
}
|
||||
|
||||
fn has_js_project(repo_path: &Path) -> bool {
|
||||
// Only run if eslint is actually installed in the project
|
||||
repo_path.join("package.json").exists() && repo_path.join("node_modules/.bin/eslint").exists()
|
||||
}
|
||||
|
||||
fn has_python_project(repo_path: &Path) -> bool {
|
||||
repo_path.join("pyproject.toml").exists()
|
||||
|| repo_path.join("setup.py").exists()
|
||||
|| repo_path.join("requirements.txt").exists()
|
||||
}
|
||||
|
||||
/// Run a command with a timeout, returning its output or an error
|
||||
pub(crate) async fn run_with_timeout(
|
||||
child: tokio::process::Child,
|
||||
scanner_name: &str,
|
||||
) -> Result<std::process::Output, CoreError> {
|
||||
let result = tokio::time::timeout(LINT_TIMEOUT, child.wait_with_output()).await;
|
||||
match result {
|
||||
Ok(Ok(output)) => Ok(output),
|
||||
Ok(Err(e)) => Err(CoreError::Scanner {
|
||||
scanner: scanner_name.to_string(),
|
||||
source: Box::new(e),
|
||||
}),
|
||||
Err(_) => {
|
||||
// Process is dropped here which sends SIGKILL on Unix
|
||||
Err(CoreError::Scanner {
|
||||
scanner: scanner_name.to_string(),
|
||||
source: Box::new(std::io::Error::new(
|
||||
std::io::ErrorKind::TimedOut,
|
||||
format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
150
compliance-agent/src/pipeline/lint/ruff.rs
Normal file
150
compliance-agent/src/pipeline/lint/ruff.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::{Finding, ScanType, Severity};
|
||||
use compliance_core::CoreError;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::pipeline::dedup;
|
||||
|
||||
use super::run_with_timeout;
|
||||
|
||||
pub(super) async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
|
||||
let child = Command::new("ruff")
|
||||
.args(["check", ".", "--output-format", "json", "--exit-zero"])
|
||||
.current_dir(repo_path)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "ruff".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let output = run_with_timeout(child, "ruff").await?;
|
||||
|
||||
if output.stdout.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
|
||||
|
||||
let findings = results
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
|
||||
Severity::Medium
|
||||
} else {
|
||||
Severity::Low
|
||||
};
|
||||
|
||||
let fingerprint = dedup::compute_fingerprint(&[
|
||||
repo_id,
|
||||
"ruff",
|
||||
&r.code,
|
||||
&r.filename,
|
||||
&r.location.row.to_string(),
|
||||
]);
|
||||
|
||||
let mut finding = Finding::new(
|
||||
repo_id.to_string(),
|
||||
fingerprint,
|
||||
"ruff".to_string(),
|
||||
ScanType::Lint,
|
||||
format!("[ruff] {}: {}", r.code, r.message),
|
||||
r.message,
|
||||
severity,
|
||||
);
|
||||
finding.rule_id = Some(r.code);
|
||||
finding.file_path = Some(r.filename);
|
||||
finding.line_number = Some(r.location.row);
|
||||
finding
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(findings)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RuffResult {
|
||||
code: String,
|
||||
message: String,
|
||||
filename: String,
|
||||
location: RuffLocation,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RuffLocation {
|
||||
row: u32,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialize_ruff_output() {
|
||||
let json = r#"[
|
||||
{
|
||||
"code": "E501",
|
||||
"message": "Line too long (120 > 79 characters)",
|
||||
"filename": "src/main.py",
|
||||
"location": {"row": 42}
|
||||
},
|
||||
{
|
||||
"code": "F401",
|
||||
"message": "`os` imported but unused",
|
||||
"filename": "src/utils.py",
|
||||
"location": {"row": 1}
|
||||
}
|
||||
]"#;
|
||||
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
|
||||
assert_eq!(results[0].code, "E501");
|
||||
assert_eq!(results[0].filename, "src/main.py");
|
||||
assert_eq!(results[0].location.row, 42);
|
||||
|
||||
assert_eq!(results[1].code, "F401");
|
||||
assert_eq!(results[1].location.row, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_ruff_empty() {
|
||||
let json = "[]";
|
||||
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ruff_severity_e_and_f_are_medium() {
|
||||
for code in &["E501", "E302", "F401", "F811"] {
|
||||
let is_medium = code.starts_with('E') || code.starts_with('F');
|
||||
assert!(is_medium, "Expected {code} to be Medium severity");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ruff_severity_others_are_low() {
|
||||
for code in &["W291", "I001", "D100", "C901", "N801"] {
|
||||
let is_medium = code.starts_with('E') || code.starts_with('F');
|
||||
assert!(!is_medium, "Expected {code} to be Low severity");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_ruff_with_extra_fields() {
|
||||
// Ruff output may contain additional fields we don't use
|
||||
let json = r#"[{
|
||||
"code": "W291",
|
||||
"message": "Trailing whitespace",
|
||||
"filename": "app.py",
|
||||
"location": {"row": 3, "column": 10},
|
||||
"end_location": {"row": 3, "column": 11},
|
||||
"fix": null,
|
||||
"noqa_row": 3
|
||||
}]"#;
|
||||
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].code, "W291");
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,12 @@ pub mod cve;
|
||||
pub mod dedup;
|
||||
pub mod git;
|
||||
pub mod gitleaks;
|
||||
mod graph_build;
|
||||
mod issue_creation;
|
||||
pub mod lint;
|
||||
pub mod orchestrator;
|
||||
pub mod patterns;
|
||||
mod pr_review;
|
||||
pub mod sbom;
|
||||
pub mod semgrep;
|
||||
mod tracker_dispatch;
|
||||
|
||||
@@ -4,7 +4,6 @@ use mongodb::bson::doc;
|
||||
use tracing::Instrument;
|
||||
|
||||
use compliance_core::models::*;
|
||||
use compliance_core::traits::issue_tracker::IssueTracker;
|
||||
use compliance_core::traits::Scanner;
|
||||
use compliance_core::AgentConfig;
|
||||
|
||||
@@ -19,84 +18,6 @@ use crate::pipeline::lint::LintScanner;
|
||||
use crate::pipeline::patterns::{GdprPatternScanner, OAuthPatternScanner};
|
||||
use crate::pipeline::sbom::SbomScanner;
|
||||
use crate::pipeline::semgrep::SemgrepScanner;
|
||||
use crate::trackers;
|
||||
|
||||
/// Enum dispatch for issue trackers (async traits aren't dyn-compatible).
|
||||
enum TrackerDispatch {
|
||||
GitHub(trackers::github::GitHubTracker),
|
||||
GitLab(trackers::gitlab::GitLabTracker),
|
||||
Gitea(trackers::gitea::GiteaTracker),
|
||||
Jira(trackers::jira::JiraTracker),
|
||||
}
|
||||
|
||||
impl TrackerDispatch {
|
||||
fn name(&self) -> &str {
|
||||
match self {
|
||||
Self::GitHub(t) => t.name(),
|
||||
Self::GitLab(t) => t.name(),
|
||||
Self::Gitea(t) => t.name(),
|
||||
Self::Jira(t) => t.name(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_issue(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
title: &str,
|
||||
body: &str,
|
||||
labels: &[String],
|
||||
) -> Result<TrackerIssue, compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_existing_issue(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
fingerprint: &str,
|
||||
) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_pr_review(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
pr_number: u64,
|
||||
body: &str,
|
||||
comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
|
||||
) -> Result<(), compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::GitLab(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::Gitea(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::Jira(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context from graph analysis passed to LLM triage for enhanced filtering
|
||||
#[derive(Debug)]
|
||||
@@ -109,10 +30,10 @@ pub struct GraphContext {
|
||||
}
|
||||
|
||||
pub struct PipelineOrchestrator {
|
||||
config: AgentConfig,
|
||||
db: Database,
|
||||
llm: Arc<LlmClient>,
|
||||
http: reqwest::Client,
|
||||
pub(super) config: AgentConfig,
|
||||
pub(super) db: Database,
|
||||
pub(super) llm: Arc<LlmClient>,
|
||||
pub(super) http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl PipelineOrchestrator {
|
||||
@@ -460,446 +381,7 @@ impl PipelineOrchestrator {
|
||||
Ok(new_count)
|
||||
}
|
||||
|
||||
/// Build the code knowledge graph for a repo and compute impact analyses
|
||||
async fn build_code_graph(
|
||||
&self,
|
||||
repo_path: &std::path::Path,
|
||||
repo_id: &str,
|
||||
findings: &[Finding],
|
||||
) -> Result<GraphContext, AgentError> {
|
||||
let graph_build_id = uuid::Uuid::new_v4().to_string();
|
||||
let engine = compliance_graph::GraphEngine::new(50_000);
|
||||
|
||||
let (mut code_graph, build_run) =
|
||||
engine
|
||||
.build_graph(repo_path, repo_id, &graph_build_id)
|
||||
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
|
||||
|
||||
// Apply community detection
|
||||
compliance_graph::graph::community::apply_communities(&mut code_graph);
|
||||
|
||||
// Store graph in MongoDB
|
||||
let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
|
||||
store
|
||||
.delete_repo_graph(repo_id)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
|
||||
store
|
||||
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
|
||||
|
||||
// Compute impact analysis for each finding
|
||||
let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
|
||||
let mut impacts = Vec::new();
|
||||
|
||||
for finding in findings {
|
||||
if let Some(file_path) = &finding.file_path {
|
||||
let impact = analyzer.analyze(
|
||||
repo_id,
|
||||
&finding.fingerprint,
|
||||
&graph_build_id,
|
||||
file_path,
|
||||
finding.line_number,
|
||||
);
|
||||
store
|
||||
.store_impact(&impact)
|
||||
.await
|
||||
.map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
|
||||
impacts.push(impact);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GraphContext {
|
||||
node_count: build_run.node_count,
|
||||
edge_count: build_run.edge_count,
|
||||
community_count: build_run.community_count,
|
||||
impacts,
|
||||
})
|
||||
}
|
||||
|
||||
/// Trigger DAST scan if a target is configured for this repo
|
||||
async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
|
||||
use futures_util::TryStreamExt;
|
||||
|
||||
let filter = mongodb::bson::doc! { "repo_id": repo_id };
|
||||
let targets: Vec<compliance_core::models::DastTarget> =
|
||||
match self.db.dast_targets().find(filter).await {
|
||||
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if targets.is_empty() {
|
||||
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
|
||||
return;
|
||||
}
|
||||
|
||||
for target in targets {
|
||||
let db = self.db.clone();
|
||||
let scan_run_id = scan_run_id.to_string();
|
||||
tokio::spawn(async move {
|
||||
let orchestrator = compliance_dast::DastOrchestrator::new(100);
|
||||
match orchestrator.run_scan(&target, Vec::new()).await {
|
||||
Ok((mut scan_run, findings)) => {
|
||||
scan_run.sast_scan_run_id = Some(scan_run_id);
|
||||
if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
|
||||
tracing::error!("Failed to store DAST scan run: {e}");
|
||||
}
|
||||
for finding in &findings {
|
||||
if let Err(e) = db.dast_findings().insert_one(finding).await {
|
||||
tracing::error!("Failed to store DAST finding: {e}");
|
||||
}
|
||||
}
|
||||
tracing::info!("DAST scan complete: {} findings", findings.len());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("DAST scan failed: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an issue tracker client from a repository's tracker configuration.
|
||||
/// Returns `None` if the repo has no tracker configured.
|
||||
fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
|
||||
let tracker_type = repo.tracker_type.as_ref()?;
|
||||
// Per-repo token takes precedence, fall back to global config
|
||||
match tracker_type {
|
||||
TrackerType::GitHub => {
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.github_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
match trackers::github::GitHubTracker::new(&secret) {
|
||||
Ok(t) => Some(TrackerDispatch::GitHub(t)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to build GitHub tracker: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
TrackerType::GitLab => {
|
||||
let base_url = self
|
||||
.config
|
||||
.gitlab_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "https://gitlab.com".to_string());
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.gitlab_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::GitLab(
|
||||
trackers::gitlab::GitLabTracker::new(base_url, secret),
|
||||
))
|
||||
}
|
||||
TrackerType::Gitea => {
|
||||
let token = repo.tracker_token.clone()?;
|
||||
let base_url = extract_base_url(&repo.git_url)?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
|
||||
base_url, secret,
|
||||
)))
|
||||
}
|
||||
TrackerType::Jira => {
|
||||
let base_url = self.config.jira_url.clone()?;
|
||||
let email = self.config.jira_email.clone()?;
|
||||
let project_key = self.config.jira_project_key.clone()?;
|
||||
let token = repo.tracker_token.clone().or_else(|| {
|
||||
self.config.jira_api_token.as_ref().map(|t| {
|
||||
use secrecy::ExposeSecret;
|
||||
t.expose_secret().to_string()
|
||||
})
|
||||
})?;
|
||||
let secret = secrecy::SecretString::from(token);
|
||||
Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
|
||||
base_url,
|
||||
email,
|
||||
secret,
|
||||
project_key,
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create tracker issues for new findings (severity >= Medium).
|
||||
/// Checks for duplicates via fingerprint search before creating.
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
async fn create_tracker_issues(
|
||||
&self,
|
||||
repo: &TrackedRepository,
|
||||
repo_id: &str,
|
||||
new_findings: &[Finding],
|
||||
) -> Result<(), AgentError> {
|
||||
let tracker = match self.build_tracker(repo) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::info!("[{repo_id}] No issue tracker configured, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let owner = match repo.tracker_owner.as_deref() {
|
||||
Some(o) => o,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let tracker_repo_name = match repo.tracker_repo.as_deref() {
|
||||
Some(r) => r,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Only create issues for medium+ severity findings
|
||||
let actionable: Vec<&Finding> = new_findings
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
matches!(
|
||||
f.severity,
|
||||
Severity::Medium | Severity::High | Severity::Critical
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if actionable.is_empty() {
|
||||
tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[{repo_id}] Creating issues for {} findings via {}",
|
||||
actionable.len(),
|
||||
tracker.name()
|
||||
);
|
||||
|
||||
let mut created = 0u32;
|
||||
for finding in actionable {
|
||||
let title = format!(
|
||||
"[{}] {}: {}",
|
||||
finding.severity, finding.scanner, finding.title
|
||||
);
|
||||
|
||||
// Check if an issue already exists by fingerprint first, then by title
|
||||
let mut found_existing = false;
|
||||
for search_term in [&finding.fingerprint, &title] {
|
||||
match tracker
|
||||
.find_existing_issue(owner, tracker_repo_name, search_term)
|
||||
.await
|
||||
{
|
||||
Ok(Some(existing)) => {
|
||||
tracing::debug!(
|
||||
"[{repo_id}] Issue already exists for '{}': {}",
|
||||
search_term,
|
||||
existing.external_url
|
||||
);
|
||||
found_existing = true;
|
||||
break;
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if found_existing {
|
||||
continue;
|
||||
}
|
||||
let body = format_issue_body(finding);
|
||||
let labels = vec![
|
||||
format!("severity:{}", finding.severity),
|
||||
format!("scanner:{}", finding.scanner),
|
||||
"compliance-scanner".to_string(),
|
||||
];
|
||||
|
||||
match tracker
|
||||
.create_issue(owner, tracker_repo_name, &title, &body, &labels)
|
||||
.await
|
||||
{
|
||||
Ok(mut issue) => {
|
||||
issue.finding_id = finding
|
||||
.id
|
||||
.as_ref()
|
||||
.map(|id| id.to_hex())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Update the finding with the issue URL
|
||||
if let Some(finding_id) = &finding.id {
|
||||
let _ = self
|
||||
.db
|
||||
.findings()
|
||||
.update_one(
|
||||
doc! { "_id": finding_id },
|
||||
doc! { "$set": { "tracker_issue_url": &issue.external_url } },
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Store the tracker issue record
|
||||
if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
|
||||
tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
|
||||
}
|
||||
|
||||
created += 1;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"[{repo_id}] Failed to create issue for {}: {e}",
|
||||
finding.fingerprint
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("[{repo_id}] Created {created} tracker issues");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run an incremental scan on a PR diff and post review comments.
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
|
||||
pub async fn run_pr_review(
|
||||
&self,
|
||||
repo: &TrackedRepository,
|
||||
repo_id: &str,
|
||||
pr_number: u64,
|
||||
base_sha: &str,
|
||||
head_sha: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
let tracker = match self.build_tracker(repo) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let owner = repo.tracker_owner.as_deref().unwrap_or("");
|
||||
let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
|
||||
if owner.is_empty() || tracker_repo_name.is_empty() {
|
||||
tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Clone/fetch the repo
|
||||
let creds = GitOps::make_repo_credentials(&self.config, repo);
|
||||
let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
|
||||
let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
|
||||
|
||||
// Get diff between base and head
|
||||
let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
|
||||
if diff_files.is_empty() {
|
||||
tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Run semgrep on the full repo but we'll filter findings to changed files
|
||||
let changed_paths: std::collections::HashSet<String> =
|
||||
diff_files.iter().map(|f| f.path.clone()).collect();
|
||||
|
||||
let mut pr_findings: Vec<Finding> = Vec::new();
|
||||
|
||||
// SAST scan (semgrep)
|
||||
match SemgrepScanner.scan(&repo_path, repo_id).await {
|
||||
Ok(output) => {
|
||||
for f in output.findings {
|
||||
if let Some(fp) = &f.file_path {
|
||||
if changed_paths.contains(fp.as_str()) {
|
||||
pr_findings.push(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
|
||||
}
|
||||
|
||||
// LLM code review on the diff
|
||||
let reviewer = CodeReviewScanner::new(self.llm.clone());
|
||||
let review_output = reviewer
|
||||
.review_diff(&repo_path, repo_id, base_sha, head_sha)
|
||||
.await;
|
||||
pr_findings.extend(review_output.findings);
|
||||
|
||||
if pr_findings.is_empty() {
|
||||
// Post a clean review
|
||||
if let Err(e) = tracker
|
||||
.create_pr_review(
|
||||
owner,
|
||||
tracker_repo_name,
|
||||
pr_number,
|
||||
"Compliance scan: no issues found in this PR.",
|
||||
Vec::new(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Build review comments from findings
|
||||
let mut review_comments = Vec::new();
|
||||
for finding in &pr_findings {
|
||||
if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
|
||||
let comment_body = format!(
|
||||
"**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
|
||||
finding.severity,
|
||||
finding.title,
|
||||
finding.description,
|
||||
finding.scanner,
|
||||
finding
|
||||
.cwe
|
||||
.as_deref()
|
||||
.map(|c| format!("CWE: {c}"))
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
|
||||
path: path.clone(),
|
||||
line,
|
||||
body: comment_body,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let summary = format!(
|
||||
"Compliance scan found **{}** issue(s) in this PR:\n\n{}",
|
||||
pr_findings.len(),
|
||||
pr_findings
|
||||
.iter()
|
||||
.map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
);
|
||||
|
||||
if let Err(e) = tracker
|
||||
.create_pr_review(
|
||||
owner,
|
||||
tracker_repo_name,
|
||||
pr_number,
|
||||
&summary,
|
||||
review_comments,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
|
||||
} else {
|
||||
tracing::info!(
|
||||
"[{repo_id}] Posted PR review on #{pr_number} with {} findings",
|
||||
pr_findings.len()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_phase(&self, scan_run_id: &str, phase: &str) {
|
||||
pub(super) async fn update_phase(&self, scan_run_id: &str, phase: &str) {
|
||||
if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(scan_run_id) {
|
||||
let _ = self
|
||||
.db
|
||||
@@ -917,9 +399,9 @@ impl PipelineOrchestrator {
|
||||
}
|
||||
|
||||
/// Extract the scheme + host from a git URL.
|
||||
/// e.g. "https://gitea.example.com/owner/repo.git" → "https://gitea.example.com"
|
||||
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" → "https://gitea.example.com"
|
||||
fn extract_base_url(git_url: &str) -> Option<String> {
|
||||
/// e.g. "https://gitea.example.com/owner/repo.git" -> "https://gitea.example.com"
|
||||
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" -> "https://gitea.example.com"
|
||||
pub(super) fn extract_base_url(git_url: &str) -> Option<String> {
|
||||
if let Some(rest) = git_url.strip_prefix("https://") {
|
||||
let host = rest.split('/').next()?;
|
||||
Some(format!("https://{host}"))
|
||||
@@ -927,7 +409,7 @@ fn extract_base_url(git_url: &str) -> Option<String> {
|
||||
let host = rest.split('/').next()?;
|
||||
Some(format!("http://{host}"))
|
||||
} else if let Some(rest) = git_url.strip_prefix("ssh://") {
|
||||
// ssh://git@host:port/path → extract host
|
||||
// ssh://git@host:port/path -> extract host
|
||||
let after_at = rest.find('@').map(|i| &rest[i + 1..]).unwrap_or(rest);
|
||||
let host = after_at.split(&[':', '/'][..]).next()?;
|
||||
Some(format!("https://{host}"))
|
||||
@@ -940,48 +422,3 @@ fn extract_base_url(git_url: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a finding into a markdown issue body for the tracker.
|
||||
fn format_issue_body(finding: &Finding) -> String {
|
||||
let mut body = String::new();
|
||||
|
||||
body.push_str(&format!("## {} Finding\n\n", finding.severity));
|
||||
body.push_str(&format!("**Scanner:** {}\n", finding.scanner));
|
||||
body.push_str(&format!("**Severity:** {}\n", finding.severity));
|
||||
|
||||
if let Some(rule) = &finding.rule_id {
|
||||
body.push_str(&format!("**Rule:** {}\n", rule));
|
||||
}
|
||||
if let Some(cwe) = &finding.cwe {
|
||||
body.push_str(&format!("**CWE:** {}\n", cwe));
|
||||
}
|
||||
|
||||
body.push_str(&format!("\n### Description\n\n{}\n", finding.description));
|
||||
|
||||
if let Some(file_path) = &finding.file_path {
|
||||
body.push_str(&format!("\n### Location\n\n**File:** `{}`", file_path));
|
||||
if let Some(line) = finding.line_number {
|
||||
body.push_str(&format!(" (line {})", line));
|
||||
}
|
||||
body.push('\n');
|
||||
}
|
||||
|
||||
if let Some(snippet) = &finding.code_snippet {
|
||||
body.push_str(&format!("\n### Code\n\n```\n{}\n```\n", snippet));
|
||||
}
|
||||
|
||||
if let Some(remediation) = &finding.remediation {
|
||||
body.push_str(&format!("\n### Remediation\n\n{}\n", remediation));
|
||||
}
|
||||
|
||||
if let Some(fix) = &finding.suggested_fix {
|
||||
body.push_str(&format!("\n### Suggested Fix\n\n```\n{}\n```\n", fix));
|
||||
}
|
||||
|
||||
body.push_str(&format!(
|
||||
"\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
|
||||
finding.fingerprint
|
||||
));
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
@@ -256,3 +256,159 @@ fn walkdir(path: &Path) -> Result<Vec<walkdir::DirEntry>, CoreError> {
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- compile_regex tests ---
|
||||
|
||||
#[test]
|
||||
fn compile_regex_valid_pattern() {
|
||||
let re = compile_regex(r"\bfoo\b");
|
||||
assert!(re.is_match("hello foo bar"));
|
||||
assert!(!re.is_match("foobar"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compile_regex_invalid_pattern_returns_fallback() {
|
||||
// An invalid regex should return the fallback "^$" that only matches empty strings
|
||||
let re = compile_regex(r"[invalid");
|
||||
assert!(re.is_match(""));
|
||||
assert!(!re.is_match("anything"));
|
||||
}
|
||||
|
||||
// --- GDPR pattern tests ---
|
||||
|
||||
#[test]
|
||||
fn gdpr_pii_logging_matches() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
let pattern = &scanner.patterns[0]; // gdpr-pii-logging
|
||||
// Regex: (log|print|console\.|logger\.|tracing::)\s*[\.(].*\b(pii_keyword)\b
|
||||
assert!(pattern.pattern.is_match("console.log(email)"));
|
||||
assert!(pattern.pattern.is_match("console.log(user.ssn)"));
|
||||
assert!(pattern.pattern.is_match("print(phone_number)"));
|
||||
assert!(pattern.pattern.is_match("tracing::(ip_addr)"));
|
||||
assert!(pattern.pattern.is_match("log.debug(credit_card)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdpr_pii_logging_no_false_positive() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
let pattern = &scanner.patterns[0];
|
||||
// Regular logging without PII fields should not match
|
||||
assert!(!pattern
|
||||
.pattern
|
||||
.is_match("logger.info(\"request completed\")"));
|
||||
assert!(!pattern.pattern.is_match("let email = user.email;"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdpr_no_consent_matches() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
let pattern = &scanner.patterns[1]; // gdpr-no-consent
|
||||
assert!(pattern.pattern.is_match("collect personal data"));
|
||||
assert!(pattern.pattern.is_match("store user_data in db"));
|
||||
assert!(pattern.pattern.is_match("save pii to disk"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdpr_user_model_matches() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
let pattern = &scanner.patterns[2]; // gdpr-no-delete-endpoint
|
||||
assert!(pattern.pattern.is_match("struct User {"));
|
||||
assert!(pattern.pattern.is_match("class User(Model):"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gdpr_hardcoded_retention_matches() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
let pattern = &scanner.patterns[3]; // gdpr-hardcoded-retention
|
||||
assert!(pattern.pattern.is_match("retention = 30"));
|
||||
assert!(pattern.pattern.is_match("ttl: 3600"));
|
||||
assert!(pattern.pattern.is_match("expire = 86400"));
|
||||
}
|
||||
|
||||
// --- OAuth pattern tests ---
|
||||
|
||||
#[test]
|
||||
fn oauth_implicit_grant_matches() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[0]; // oauth-implicit-grant
|
||||
assert!(pattern.pattern.is_match("response_type = \"token\""));
|
||||
assert!(pattern.pattern.is_match("grant_type: implicit"));
|
||||
assert!(pattern.pattern.is_match("response_type='token'"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_implicit_grant_no_false_positive() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[0];
|
||||
assert!(!pattern.pattern.is_match("response_type = \"code\""));
|
||||
assert!(!pattern.pattern.is_match("grant_type: authorization_code"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_authorization_code_matches() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[1]; // oauth-missing-pkce
|
||||
assert!(pattern.pattern.is_match("uses authorization_code flow"));
|
||||
assert!(pattern.pattern.is_match("authorization code grant"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_token_localstorage_matches() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
|
||||
assert!(pattern
|
||||
.pattern
|
||||
.is_match("localStorage.setItem('access_token', tok)"));
|
||||
assert!(pattern
|
||||
.pattern
|
||||
.is_match("localStorage.getItem(\"refresh_token\")"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_token_localstorage_no_false_positive() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[2];
|
||||
assert!(!pattern
|
||||
.pattern
|
||||
.is_match("localStorage.setItem('theme', 'dark')"));
|
||||
assert!(!pattern
|
||||
.pattern
|
||||
.is_match("sessionStorage.setItem('token', t)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_token_url_matches() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[3]; // oauth-token-url
|
||||
assert!(pattern.pattern.is_match("access_token = build_url(query)"));
|
||||
assert!(pattern.pattern.is_match("bearer = url.param"));
|
||||
}
|
||||
|
||||
// --- Pattern rule file extension filtering ---
|
||||
|
||||
#[test]
|
||||
fn gdpr_patterns_cover_common_languages() {
|
||||
let scanner = GdprPatternScanner::new();
|
||||
for pattern in &scanner.patterns {
|
||||
assert!(
|
||||
pattern.file_extensions.contains(&"rs".to_string()),
|
||||
"Pattern {} should cover .rs files",
|
||||
pattern.id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn oauth_localstorage_only_js_ts() {
|
||||
let scanner = OAuthPatternScanner::new();
|
||||
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
|
||||
assert!(pattern.file_extensions.contains(&"js".to_string()));
|
||||
assert!(pattern.file_extensions.contains(&"ts".to_string()));
|
||||
assert!(!pattern.file_extensions.contains(&"rs".to_string()));
|
||||
assert!(!pattern.file_extensions.contains(&"py".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
146
compliance-agent/src/pipeline/pr_review.rs
Normal file
146
compliance-agent/src/pipeline/pr_review.rs
Normal file
@@ -0,0 +1,146 @@
|
||||
use compliance_core::models::*;
|
||||
|
||||
use super::orchestrator::PipelineOrchestrator;
|
||||
use crate::error::AgentError;
|
||||
use crate::pipeline::code_review::CodeReviewScanner;
|
||||
use crate::pipeline::git::GitOps;
|
||||
use crate::pipeline::semgrep::SemgrepScanner;
|
||||
|
||||
use compliance_core::traits::Scanner;
|
||||
|
||||
impl PipelineOrchestrator {
|
||||
/// Run an incremental scan on a PR diff and post review comments.
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
|
||||
pub async fn run_pr_review(
|
||||
&self,
|
||||
repo: &TrackedRepository,
|
||||
repo_id: &str,
|
||||
pr_number: u64,
|
||||
base_sha: &str,
|
||||
head_sha: &str,
|
||||
) -> Result<(), AgentError> {
|
||||
let tracker = match self.build_tracker(repo) {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let owner = repo.tracker_owner.as_deref().unwrap_or("");
|
||||
let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
|
||||
if owner.is_empty() || tracker_repo_name.is_empty() {
|
||||
tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Clone/fetch the repo
|
||||
let creds = GitOps::make_repo_credentials(&self.config, repo);
|
||||
let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
|
||||
let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
|
||||
|
||||
// Get diff between base and head
|
||||
let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
|
||||
if diff_files.is_empty() {
|
||||
tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Run semgrep on the full repo but we'll filter findings to changed files
|
||||
let changed_paths: std::collections::HashSet<String> =
|
||||
diff_files.iter().map(|f| f.path.clone()).collect();
|
||||
|
||||
let mut pr_findings: Vec<Finding> = Vec::new();
|
||||
|
||||
// SAST scan (semgrep)
|
||||
match SemgrepScanner.scan(&repo_path, repo_id).await {
|
||||
Ok(output) => {
|
||||
for f in output.findings {
|
||||
if let Some(fp) = &f.file_path {
|
||||
if changed_paths.contains(fp.as_str()) {
|
||||
pr_findings.push(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
|
||||
}
|
||||
|
||||
// LLM code review on the diff
|
||||
let reviewer = CodeReviewScanner::new(self.llm.clone());
|
||||
let review_output = reviewer
|
||||
.review_diff(&repo_path, repo_id, base_sha, head_sha)
|
||||
.await;
|
||||
pr_findings.extend(review_output.findings);
|
||||
|
||||
if pr_findings.is_empty() {
|
||||
// Post a clean review
|
||||
if let Err(e) = tracker
|
||||
.create_pr_review(
|
||||
owner,
|
||||
tracker_repo_name,
|
||||
pr_number,
|
||||
"Compliance scan: no issues found in this PR.",
|
||||
Vec::new(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Build review comments from findings
|
||||
let mut review_comments = Vec::new();
|
||||
for finding in &pr_findings {
|
||||
if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
|
||||
let comment_body = format!(
|
||||
"**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
|
||||
finding.severity,
|
||||
finding.title,
|
||||
finding.description,
|
||||
finding.scanner,
|
||||
finding
|
||||
.cwe
|
||||
.as_deref()
|
||||
.map(|c| format!("CWE: {c}"))
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
|
||||
path: path.clone(),
|
||||
line,
|
||||
body: comment_body,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let summary = format!(
|
||||
"Compliance scan found **{}** issue(s) in this PR:\n\n{}",
|
||||
pr_findings.len(),
|
||||
pr_findings
|
||||
.iter()
|
||||
.map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
);
|
||||
|
||||
if let Err(e) = tracker
|
||||
.create_pr_review(
|
||||
owner,
|
||||
tracker_repo_name,
|
||||
pr_number,
|
||||
&summary,
|
||||
review_comments,
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
|
||||
} else {
|
||||
tracing::info!(
|
||||
"[{repo_id}] Posted PR review on #{pr_number} with {} findings",
|
||||
pr_findings.len()
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
72
compliance-agent/src/pipeline/sbom/cargo_audit.rs
Normal file
72
compliance-agent/src/pipeline/sbom/cargo_audit.rs
Normal file
@@ -0,0 +1,72 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::CoreError;
|
||||
|
||||
pub(super) struct AuditVuln {
|
||||
pub package: String,
|
||||
pub id: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(super) async fn run_cargo_audit(
|
||||
repo_path: &Path,
|
||||
_repo_id: &str,
|
||||
) -> Result<Vec<AuditVuln>, CoreError> {
|
||||
let cargo_lock = repo_path.join("Cargo.lock");
|
||||
if !cargo_lock.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let output = tokio::process::Command::new("cargo")
|
||||
.args(["audit", "--json"])
|
||||
.current_dir(repo_path)
|
||||
.env("RUSTC_WRAPPER", "")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "cargo-audit".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let result: CargoAuditOutput =
|
||||
serde_json::from_slice(&output.stdout).unwrap_or_else(|_| CargoAuditOutput {
|
||||
vulnerabilities: CargoAuditVulns { list: Vec::new() },
|
||||
});
|
||||
|
||||
let vulns = result
|
||||
.vulnerabilities
|
||||
.list
|
||||
.into_iter()
|
||||
.map(|v| AuditVuln {
|
||||
package: v.advisory.package,
|
||||
id: v.advisory.id,
|
||||
url: v.advisory.url,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(vulns)
|
||||
}
|
||||
|
||||
// Cargo audit types
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditOutput {
|
||||
vulnerabilities: CargoAuditVulns,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditVulns {
|
||||
list: Vec<CargoAuditEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditEntry {
|
||||
advisory: CargoAuditAdvisory,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditAdvisory {
|
||||
id: String,
|
||||
package: String,
|
||||
url: String,
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
mod cargo_audit;
|
||||
mod syft;
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::{SbomEntry, ScanType, VulnRef};
|
||||
@@ -23,7 +26,7 @@ impl Scanner for SbomScanner {
|
||||
generate_lockfiles(repo_path).await;
|
||||
|
||||
// Run syft for SBOM generation
|
||||
match run_syft(repo_path, repo_id).await {
|
||||
match syft::run_syft(repo_path, repo_id).await {
|
||||
Ok(syft_entries) => entries.extend(syft_entries),
|
||||
Err(e) => tracing::warn!("syft failed: {e}"),
|
||||
}
|
||||
@@ -32,7 +35,7 @@ impl Scanner for SbomScanner {
|
||||
enrich_cargo_licenses(repo_path, &mut entries).await;
|
||||
|
||||
// Run cargo-audit for Rust-specific vulns
|
||||
match run_cargo_audit(repo_path, repo_id).await {
|
||||
match cargo_audit::run_cargo_audit(repo_path, repo_id).await {
|
||||
Ok(vulns) => merge_audit_vulns(&mut entries, vulns),
|
||||
Err(e) => tracing::warn!("cargo-audit skipped: {e}"),
|
||||
}
|
||||
@@ -186,95 +189,7 @@ async fn enrich_cargo_licenses(repo_path: &Path, entries: &mut [SbomEntry]) {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
|
||||
let output = tokio::process::Command::new("syft")
|
||||
.arg(repo_path)
|
||||
.args(["-o", "cyclonedx-json"])
|
||||
// Enable remote license lookups for all ecosystems
|
||||
.env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_JAVA_USE_NETWORK", "true")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "syft".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(CoreError::Scanner {
|
||||
scanner: "syft".to_string(),
|
||||
source: format!("syft exited with {}: {stderr}", output.status).into(),
|
||||
});
|
||||
}
|
||||
|
||||
let cdx: CycloneDxBom = serde_json::from_slice(&output.stdout)?;
|
||||
let entries = cdx
|
||||
.components
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let package_manager = c
|
||||
.purl
|
||||
.as_deref()
|
||||
.and_then(extract_ecosystem_from_purl)
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let mut entry = SbomEntry::new(
|
||||
repo_id.to_string(),
|
||||
c.name,
|
||||
c.version.unwrap_or_else(|| "unknown".to_string()),
|
||||
package_manager,
|
||||
);
|
||||
entry.purl = c.purl;
|
||||
entry.license = c.licenses.and_then(|ls| extract_license(&ls));
|
||||
entry
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn run_cargo_audit(repo_path: &Path, _repo_id: &str) -> Result<Vec<AuditVuln>, CoreError> {
|
||||
let cargo_lock = repo_path.join("Cargo.lock");
|
||||
if !cargo_lock.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let output = tokio::process::Command::new("cargo")
|
||||
.args(["audit", "--json"])
|
||||
.current_dir(repo_path)
|
||||
.env("RUSTC_WRAPPER", "")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "cargo-audit".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
let result: CargoAuditOutput =
|
||||
serde_json::from_slice(&output.stdout).unwrap_or_else(|_| CargoAuditOutput {
|
||||
vulnerabilities: CargoAuditVulns { list: Vec::new() },
|
||||
});
|
||||
|
||||
let vulns = result
|
||||
.vulnerabilities
|
||||
.list
|
||||
.into_iter()
|
||||
.map(|v| AuditVuln {
|
||||
package: v.advisory.package,
|
||||
id: v.advisory.id,
|
||||
url: v.advisory.url,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(vulns)
|
||||
}
|
||||
|
||||
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
|
||||
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<cargo_audit::AuditVuln>) {
|
||||
for vuln in vulns {
|
||||
if let Some(entry) = entries.iter_mut().find(|e| e.name == vuln.package) {
|
||||
entry.known_vulnerabilities.push(VulnRef {
|
||||
@@ -287,65 +202,6 @@ fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
|
||||
}
|
||||
}
|
||||
|
||||
// CycloneDX JSON types
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CycloneDxBom {
|
||||
components: Option<Vec<CdxComponent>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxComponent {
|
||||
name: String,
|
||||
version: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[allow(dead_code)]
|
||||
component_type: Option<String>,
|
||||
purl: Option<String>,
|
||||
licenses: Option<Vec<CdxLicenseWrapper>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxLicenseWrapper {
|
||||
license: Option<CdxLicense>,
|
||||
/// SPDX license expression (e.g. "MIT OR Apache-2.0")
|
||||
expression: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxLicense {
|
||||
id: Option<String>,
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
// Cargo audit types
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditOutput {
|
||||
vulnerabilities: CargoAuditVulns,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditVulns {
|
||||
list: Vec<CargoAuditEntry>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditEntry {
|
||||
advisory: CargoAuditAdvisory,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoAuditAdvisory {
|
||||
id: String,
|
||||
package: String,
|
||||
url: String,
|
||||
}
|
||||
|
||||
struct AuditVuln {
|
||||
package: String,
|
||||
id: String,
|
||||
url: String,
|
||||
}
|
||||
|
||||
// Cargo metadata types
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CargoMetadata {
|
||||
@@ -358,49 +214,3 @@ struct CargoPackage {
|
||||
version: String,
|
||||
license: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract the best license string from CycloneDX license entries.
|
||||
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
|
||||
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
|
||||
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
|
||||
for entry in entries {
|
||||
if let Some(ref expr) = entry.expression {
|
||||
if !expr.is_empty() {
|
||||
return Some(expr.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Second pass: collect license.id or license.name from all entries
|
||||
let parts: Vec<String> = entries
|
||||
.iter()
|
||||
.filter_map(|e| {
|
||||
e.license.as_ref().and_then(|lic| {
|
||||
lic.id
|
||||
.clone()
|
||||
.or_else(|| lic.name.clone())
|
||||
.filter(|s| !s.is_empty())
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(parts.join(" OR "))
|
||||
}
|
||||
|
||||
/// Extract the ecosystem/package-manager from a PURL string.
|
||||
/// e.g. "pkg:npm/lodash@4.17.21" → "npm", "pkg:cargo/serde@1.0" → "cargo"
|
||||
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
|
||||
let rest = purl.strip_prefix("pkg:")?;
|
||||
let ecosystem = rest.split('/').next()?;
|
||||
if ecosystem.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Normalise common PURL types to user-friendly names
|
||||
let normalised = match ecosystem {
|
||||
"golang" => "go",
|
||||
"pypi" => "pip",
|
||||
_ => ecosystem,
|
||||
};
|
||||
Some(normalised.to_string())
|
||||
}
|
||||
355
compliance-agent/src/pipeline/sbom/syft.rs
Normal file
355
compliance-agent/src/pipeline/sbom/syft.rs
Normal file
@@ -0,0 +1,355 @@
|
||||
use std::path::Path;
|
||||
|
||||
use compliance_core::models::SbomEntry;
|
||||
use compliance_core::CoreError;
|
||||
|
||||
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
|
||||
pub(super) async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
|
||||
let output = tokio::process::Command::new("syft")
|
||||
.arg(repo_path)
|
||||
.args(["-o", "cyclonedx-json"])
|
||||
// Enable remote license lookups for all ecosystems
|
||||
.env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
|
||||
.env("SYFT_JAVA_USE_NETWORK", "true")
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| CoreError::Scanner {
|
||||
scanner: "syft".to_string(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(CoreError::Scanner {
|
||||
scanner: "syft".to_string(),
|
||||
source: format!("syft exited with {}: {stderr}", output.status).into(),
|
||||
});
|
||||
}
|
||||
|
||||
let cdx: CycloneDxBom = serde_json::from_slice(&output.stdout)?;
|
||||
let entries = cdx
|
||||
.components
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let package_manager = c
|
||||
.purl
|
||||
.as_deref()
|
||||
.and_then(extract_ecosystem_from_purl)
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let mut entry = SbomEntry::new(
|
||||
repo_id.to_string(),
|
||||
c.name,
|
||||
c.version.unwrap_or_else(|| "unknown".to_string()),
|
||||
package_manager,
|
||||
);
|
||||
entry.purl = c.purl;
|
||||
entry.license = c.licenses.and_then(|ls| extract_license(&ls));
|
||||
entry
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
// CycloneDX JSON types
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CycloneDxBom {
|
||||
components: Option<Vec<CdxComponent>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxComponent {
|
||||
name: String,
|
||||
version: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
#[allow(dead_code)]
|
||||
component_type: Option<String>,
|
||||
purl: Option<String>,
|
||||
licenses: Option<Vec<CdxLicenseWrapper>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxLicenseWrapper {
|
||||
license: Option<CdxLicense>,
|
||||
/// SPDX license expression (e.g. "MIT OR Apache-2.0")
|
||||
expression: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CdxLicense {
|
||||
id: Option<String>,
|
||||
name: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract the best license string from CycloneDX license entries.
|
||||
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
|
||||
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
|
||||
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
|
||||
for entry in entries {
|
||||
if let Some(ref expr) = entry.expression {
|
||||
if !expr.is_empty() {
|
||||
return Some(expr.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Second pass: collect license.id or license.name from all entries
|
||||
let parts: Vec<String> = entries
|
||||
.iter()
|
||||
.filter_map(|e| {
|
||||
e.license.as_ref().and_then(|lic| {
|
||||
lic.id
|
||||
.clone()
|
||||
.or_else(|| lic.name.clone())
|
||||
.filter(|s| !s.is_empty())
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
if parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(parts.join(" OR "))
|
||||
}
|
||||
|
||||
/// Extract the ecosystem/package-manager from a PURL string.
|
||||
/// e.g. "pkg:npm/lodash@4.17.21" -> "npm", "pkg:cargo/serde@1.0" -> "cargo"
|
||||
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
|
||||
let rest = purl.strip_prefix("pkg:")?;
|
||||
let ecosystem = rest.split('/').next()?;
|
||||
if ecosystem.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Normalise common PURL types to user-friendly names
|
||||
let normalised = match ecosystem {
|
||||
"golang" => "go",
|
||||
"pypi" => "pip",
|
||||
_ => ecosystem,
|
||||
};
|
||||
Some(normalised.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// --- extract_ecosystem_from_purl tests ---
|
||||
|
||||
#[test]
|
||||
fn purl_npm() {
|
||||
assert_eq!(
|
||||
extract_ecosystem_from_purl("pkg:npm/lodash@4.17.21"),
|
||||
Some("npm".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_cargo() {
|
||||
assert_eq!(
|
||||
extract_ecosystem_from_purl("pkg:cargo/serde@1.0.197"),
|
||||
Some("cargo".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_golang_normalised() {
|
||||
assert_eq!(
|
||||
extract_ecosystem_from_purl("pkg:golang/github.com/gin-gonic/gin@1.9.1"),
|
||||
Some("go".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_pypi_normalised() {
|
||||
assert_eq!(
|
||||
extract_ecosystem_from_purl("pkg:pypi/requests@2.31.0"),
|
||||
Some("pip".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_maven() {
|
||||
assert_eq!(
|
||||
extract_ecosystem_from_purl("pkg:maven/org.apache.commons/commons-lang3@3.14.0"),
|
||||
Some("maven".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_missing_prefix() {
|
||||
assert_eq!(extract_ecosystem_from_purl("npm/lodash@4.17.21"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_empty_ecosystem() {
|
||||
assert_eq!(extract_ecosystem_from_purl("pkg:/lodash@4.17.21"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_empty_string() {
|
||||
assert_eq!(extract_ecosystem_from_purl(""), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn purl_just_prefix() {
|
||||
assert_eq!(extract_ecosystem_from_purl("pkg:"), None);
|
||||
}
|
||||
|
||||
// --- extract_license tests ---
|
||||
|
||||
#[test]
|
||||
fn license_from_expression() {
|
||||
let entries = vec![CdxLicenseWrapper {
|
||||
license: None,
|
||||
expression: Some("MIT OR Apache-2.0".to_string()),
|
||||
}];
|
||||
assert_eq!(
|
||||
extract_license(&entries),
|
||||
Some("MIT OR Apache-2.0".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_from_id() {
|
||||
let entries = vec![CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: Some("MIT".to_string()),
|
||||
name: None,
|
||||
}),
|
||||
expression: None,
|
||||
}];
|
||||
assert_eq!(extract_license(&entries), Some("MIT".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_from_name_fallback() {
|
||||
let entries = vec![CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: None,
|
||||
name: Some("MIT License".to_string()),
|
||||
}),
|
||||
expression: None,
|
||||
}];
|
||||
assert_eq!(extract_license(&entries), Some("MIT License".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_expression_preferred_over_id() {
|
||||
let entries = vec![
|
||||
CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: Some("MIT".to_string()),
|
||||
name: None,
|
||||
}),
|
||||
expression: None,
|
||||
},
|
||||
CdxLicenseWrapper {
|
||||
license: None,
|
||||
expression: Some("MIT AND Apache-2.0".to_string()),
|
||||
},
|
||||
];
|
||||
// Expression should be preferred (first pass finds it)
|
||||
assert_eq!(
|
||||
extract_license(&entries),
|
||||
Some("MIT AND Apache-2.0".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_multiple_ids_joined() {
|
||||
let entries = vec![
|
||||
CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: Some("MIT".to_string()),
|
||||
name: None,
|
||||
}),
|
||||
expression: None,
|
||||
},
|
||||
CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: Some("Apache-2.0".to_string()),
|
||||
name: None,
|
||||
}),
|
||||
expression: None,
|
||||
},
|
||||
];
|
||||
assert_eq!(
|
||||
extract_license(&entries),
|
||||
Some("MIT OR Apache-2.0".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_empty_entries() {
|
||||
let entries: Vec<CdxLicenseWrapper> = vec![];
|
||||
assert_eq!(extract_license(&entries), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_all_empty_strings() {
|
||||
let entries = vec![CdxLicenseWrapper {
|
||||
license: Some(CdxLicense {
|
||||
id: Some(String::new()),
|
||||
name: Some(String::new()),
|
||||
}),
|
||||
expression: Some(String::new()),
|
||||
}];
|
||||
assert_eq!(extract_license(&entries), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn license_none_fields() {
|
||||
let entries = vec![CdxLicenseWrapper {
|
||||
license: None,
|
||||
expression: None,
|
||||
}];
|
||||
assert_eq!(extract_license(&entries), None);
|
||||
}
|
||||
|
||||
// --- CycloneDX deserialization tests ---
|
||||
|
||||
#[test]
|
||||
fn deserialize_cyclonedx_bom() {
|
||||
let json = r#"{
|
||||
"components": [
|
||||
{
|
||||
"name": "serde",
|
||||
"version": "1.0.197",
|
||||
"type": "library",
|
||||
"purl": "pkg:cargo/serde@1.0.197",
|
||||
"licenses": [
|
||||
{"expression": "MIT OR Apache-2.0"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
|
||||
let components = bom.components.unwrap();
|
||||
assert_eq!(components.len(), 1);
|
||||
assert_eq!(components[0].name, "serde");
|
||||
assert_eq!(components[0].version, Some("1.0.197".to_string()));
|
||||
assert_eq!(
|
||||
components[0].purl,
|
||||
Some("pkg:cargo/serde@1.0.197".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_cyclonedx_no_components() {
|
||||
let json = r#"{}"#;
|
||||
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
|
||||
assert!(bom.components.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_cyclonedx_minimal_component() {
|
||||
let json = r#"{"components": [{"name": "foo"}]}"#;
|
||||
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
|
||||
let c = &bom.components.unwrap()[0];
|
||||
assert_eq!(c.name, "foo");
|
||||
assert!(c.version.is_none());
|
||||
assert!(c.purl.is_none());
|
||||
assert!(c.licenses.is_none());
|
||||
}
|
||||
}
|
||||
@@ -108,3 +108,124 @@ struct SemgrepExtra {
|
||||
#[serde(default)]
|
||||
metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn deserialize_semgrep_output() {
|
||||
let json = r#"{
|
||||
"results": [
|
||||
{
|
||||
"check_id": "python.lang.security.audit.exec-detected",
|
||||
"path": "src/main.py",
|
||||
"start": {"line": 15},
|
||||
"extra": {
|
||||
"message": "Detected use of exec()",
|
||||
"severity": "ERROR",
|
||||
"lines": "exec(user_input)",
|
||||
"metadata": {"cwe": "CWE-78"}
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(output.results.len(), 1);
|
||||
|
||||
let r = &output.results[0];
|
||||
assert_eq!(r.check_id, "python.lang.security.audit.exec-detected");
|
||||
assert_eq!(r.path, "src/main.py");
|
||||
assert_eq!(r.start.line, 15);
|
||||
assert_eq!(r.extra.message, "Detected use of exec()");
|
||||
assert_eq!(r.extra.severity, "ERROR");
|
||||
assert_eq!(r.extra.lines, "exec(user_input)");
|
||||
assert_eq!(
|
||||
r.extra
|
||||
.metadata
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get("cwe")
|
||||
.unwrap()
|
||||
.as_str(),
|
||||
Some("CWE-78")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_semgrep_empty_results() {
|
||||
let json = r#"{"results": []}"#;
|
||||
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
|
||||
assert!(output.results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_semgrep_no_metadata() {
|
||||
let json = r#"{
|
||||
"results": [
|
||||
{
|
||||
"check_id": "rule-1",
|
||||
"path": "app.py",
|
||||
"start": {"line": 1},
|
||||
"extra": {
|
||||
"message": "found something",
|
||||
"severity": "WARNING",
|
||||
"lines": "import os"
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
|
||||
assert!(output.results[0].extra.metadata.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn semgrep_severity_mapping() {
|
||||
let cases = vec![
|
||||
("ERROR", "High"),
|
||||
("WARNING", "Medium"),
|
||||
("INFO", "Low"),
|
||||
("UNKNOWN", "Info"),
|
||||
];
|
||||
for (input, expected) in cases {
|
||||
let result = match input {
|
||||
"ERROR" => "High",
|
||||
"WARNING" => "Medium",
|
||||
"INFO" => "Low",
|
||||
_ => "Info",
|
||||
};
|
||||
assert_eq!(result, expected, "Severity for '{input}'");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_semgrep_multiple_results() {
|
||||
let json = r#"{
|
||||
"results": [
|
||||
{
|
||||
"check_id": "rule-a",
|
||||
"path": "a.py",
|
||||
"start": {"line": 1},
|
||||
"extra": {
|
||||
"message": "msg a",
|
||||
"severity": "ERROR",
|
||||
"lines": "line a"
|
||||
}
|
||||
},
|
||||
{
|
||||
"check_id": "rule-b",
|
||||
"path": "b.py",
|
||||
"start": {"line": 99},
|
||||
"extra": {
|
||||
"message": "msg b",
|
||||
"severity": "INFO",
|
||||
"lines": "line b"
|
||||
}
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(output.results.len(), 2);
|
||||
assert_eq!(output.results[1].start.line, 99);
|
||||
}
|
||||
}
|
||||
|
||||
81
compliance-agent/src/pipeline/tracker_dispatch.rs
Normal file
81
compliance-agent/src/pipeline/tracker_dispatch.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
use compliance_core::models::TrackerIssue;
|
||||
use compliance_core::traits::issue_tracker::IssueTracker;
|
||||
|
||||
use crate::trackers;
|
||||
|
||||
/// Enum dispatch for issue trackers (async traits aren't dyn-compatible).
|
||||
pub(crate) enum TrackerDispatch {
|
||||
GitHub(trackers::github::GitHubTracker),
|
||||
GitLab(trackers::gitlab::GitLabTracker),
|
||||
Gitea(trackers::gitea::GiteaTracker),
|
||||
Jira(trackers::jira::JiraTracker),
|
||||
}
|
||||
|
||||
impl TrackerDispatch {
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
match self {
|
||||
Self::GitHub(t) => t.name(),
|
||||
Self::GitLab(t) => t.name(),
|
||||
Self::Gitea(t) => t.name(),
|
||||
Self::Jira(t) => t.name(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn create_issue(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
title: &str,
|
||||
body: &str,
|
||||
labels: &[String],
|
||||
) -> Result<TrackerIssue, compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn find_existing_issue(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
fingerprint: &str,
|
||||
) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn create_pr_review(
|
||||
&self,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
pr_number: u64,
|
||||
body: &str,
|
||||
comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
|
||||
) -> Result<(), compliance_core::error::CoreError> {
|
||||
match self {
|
||||
Self::GitHub(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::GitLab(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::Gitea(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
Self::Jira(t) => {
|
||||
t.create_pr_review(owner, repo, pr_number, body, comments)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
3
compliance-agent/tests/common/mod.rs
Normal file
3
compliance-agent/tests/common/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// Shared test helpers for compliance-agent integration tests.
|
||||
//
|
||||
// Add database mocks, fixtures, and test utilities here.
|
||||
4
compliance-agent/tests/integration/mod.rs
Normal file
4
compliance-agent/tests/integration/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
// Integration tests for the compliance-agent crate.
|
||||
//
|
||||
// Add tests that exercise the full pipeline, API handlers,
|
||||
// and cross-module interactions here.
|
||||
@@ -250,7 +250,11 @@ pub enum PentestEvent {
|
||||
findings_count: u32,
|
||||
},
|
||||
/// A new finding was discovered
|
||||
Finding { finding_id: String, title: String, severity: String },
|
||||
Finding {
|
||||
finding_id: String,
|
||||
title: String,
|
||||
severity: String,
|
||||
},
|
||||
/// Assistant message (streaming text)
|
||||
Message { content: String },
|
||||
/// Session completed
|
||||
|
||||
475
compliance-core/tests/models.rs
Normal file
475
compliance-core/tests/models.rs
Normal file
@@ -0,0 +1,475 @@
|
||||
use compliance_core::models::*;
|
||||
|
||||
// ─── Severity ───
|
||||
|
||||
#[test]
|
||||
fn severity_display_all_variants() {
|
||||
assert_eq!(Severity::Info.to_string(), "info");
|
||||
assert_eq!(Severity::Low.to_string(), "low");
|
||||
assert_eq!(Severity::Medium.to_string(), "medium");
|
||||
assert_eq!(Severity::High.to_string(), "high");
|
||||
assert_eq!(Severity::Critical.to_string(), "critical");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_ordering() {
|
||||
assert!(Severity::Info < Severity::Low);
|
||||
assert!(Severity::Low < Severity::Medium);
|
||||
assert!(Severity::Medium < Severity::High);
|
||||
assert!(Severity::High < Severity::Critical);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_serde_roundtrip() {
|
||||
for sev in [
|
||||
Severity::Info,
|
||||
Severity::Low,
|
||||
Severity::Medium,
|
||||
Severity::High,
|
||||
Severity::Critical,
|
||||
] {
|
||||
let json = serde_json::to_string(&sev).unwrap();
|
||||
let back: Severity = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(sev, back);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn severity_deserialize_lowercase() {
|
||||
let s: Severity = serde_json::from_str(r#""critical""#).unwrap();
|
||||
assert_eq!(s, Severity::Critical);
|
||||
}
|
||||
|
||||
// ─── FindingStatus ───
|
||||
|
||||
#[test]
|
||||
fn finding_status_display_all_variants() {
|
||||
assert_eq!(FindingStatus::Open.to_string(), "open");
|
||||
assert_eq!(FindingStatus::Triaged.to_string(), "triaged");
|
||||
assert_eq!(FindingStatus::FalsePositive.to_string(), "false_positive");
|
||||
assert_eq!(FindingStatus::Resolved.to_string(), "resolved");
|
||||
assert_eq!(FindingStatus::Ignored.to_string(), "ignored");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finding_status_serde_roundtrip() {
|
||||
for status in [
|
||||
FindingStatus::Open,
|
||||
FindingStatus::Triaged,
|
||||
FindingStatus::FalsePositive,
|
||||
FindingStatus::Resolved,
|
||||
FindingStatus::Ignored,
|
||||
] {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
let back: FindingStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(status, back);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Finding ───
|
||||
|
||||
#[test]
|
||||
fn finding_new_defaults() {
|
||||
let f = Finding::new(
|
||||
"repo1".into(),
|
||||
"fp123".into(),
|
||||
"semgrep".into(),
|
||||
ScanType::Sast,
|
||||
"Test title".into(),
|
||||
"Test desc".into(),
|
||||
Severity::High,
|
||||
);
|
||||
assert_eq!(f.repo_id, "repo1");
|
||||
assert_eq!(f.fingerprint, "fp123");
|
||||
assert_eq!(f.scanner, "semgrep");
|
||||
assert_eq!(f.scan_type, ScanType::Sast);
|
||||
assert_eq!(f.severity, Severity::High);
|
||||
assert_eq!(f.status, FindingStatus::Open);
|
||||
assert!(f.id.is_none());
|
||||
assert!(f.rule_id.is_none());
|
||||
assert!(f.confidence.is_none());
|
||||
assert!(f.file_path.is_none());
|
||||
assert!(f.remediation.is_none());
|
||||
assert!(f.suggested_fix.is_none());
|
||||
assert!(f.triage_action.is_none());
|
||||
assert!(f.developer_feedback.is_none());
|
||||
}
|
||||
|
||||
// ─── ScanType ───
|
||||
|
||||
#[test]
|
||||
fn scan_type_display_all_variants() {
|
||||
let cases = vec![
|
||||
(ScanType::Sast, "sast"),
|
||||
(ScanType::Sbom, "sbom"),
|
||||
(ScanType::Cve, "cve"),
|
||||
(ScanType::Gdpr, "gdpr"),
|
||||
(ScanType::OAuth, "oauth"),
|
||||
(ScanType::Graph, "graph"),
|
||||
(ScanType::Dast, "dast"),
|
||||
(ScanType::SecretDetection, "secret_detection"),
|
||||
(ScanType::Lint, "lint"),
|
||||
(ScanType::CodeReview, "code_review"),
|
||||
];
|
||||
for (variant, expected) in cases {
|
||||
assert_eq!(variant.to_string(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_type_serde_roundtrip() {
|
||||
for st in [
|
||||
ScanType::Sast,
|
||||
ScanType::SecretDetection,
|
||||
ScanType::CodeReview,
|
||||
] {
|
||||
let json = serde_json::to_string(&st).unwrap();
|
||||
let back: ScanType = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(st, back);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── ScanRun ───
|
||||
|
||||
#[test]
|
||||
fn scan_run_new_defaults() {
|
||||
let sr = ScanRun::new("repo1".into(), ScanTrigger::Manual);
|
||||
assert_eq!(sr.repo_id, "repo1");
|
||||
assert_eq!(sr.trigger, ScanTrigger::Manual);
|
||||
assert_eq!(sr.status, ScanRunStatus::Running);
|
||||
assert_eq!(sr.current_phase, ScanPhase::ChangeDetection);
|
||||
assert!(sr.phases_completed.is_empty());
|
||||
assert_eq!(sr.new_findings_count, 0);
|
||||
assert!(sr.error_message.is_none());
|
||||
assert!(sr.completed_at.is_none());
|
||||
}
|
||||
|
||||
// ─── PentestStatus ───
|
||||
|
||||
#[test]
|
||||
fn pentest_status_display() {
|
||||
assert_eq!(pentest::PentestStatus::Running.to_string(), "running");
|
||||
assert_eq!(pentest::PentestStatus::Paused.to_string(), "paused");
|
||||
assert_eq!(pentest::PentestStatus::Completed.to_string(), "completed");
|
||||
assert_eq!(pentest::PentestStatus::Failed.to_string(), "failed");
|
||||
}
|
||||
|
||||
// ─── PentestStrategy ───
|
||||
|
||||
#[test]
|
||||
fn pentest_strategy_display() {
|
||||
assert_eq!(pentest::PentestStrategy::Quick.to_string(), "quick");
|
||||
assert_eq!(
|
||||
pentest::PentestStrategy::Comprehensive.to_string(),
|
||||
"comprehensive"
|
||||
);
|
||||
assert_eq!(pentest::PentestStrategy::Targeted.to_string(), "targeted");
|
||||
assert_eq!(
|
||||
pentest::PentestStrategy::Aggressive.to_string(),
|
||||
"aggressive"
|
||||
);
|
||||
assert_eq!(pentest::PentestStrategy::Stealth.to_string(), "stealth");
|
||||
}
|
||||
|
||||
// ─── PentestSession ───
|
||||
|
||||
#[test]
|
||||
fn pentest_session_new_defaults() {
|
||||
let s = pentest::PentestSession::new("target1".into(), pentest::PentestStrategy::Quick);
|
||||
assert_eq!(s.target_id, "target1");
|
||||
assert_eq!(s.status, pentest::PentestStatus::Running);
|
||||
assert_eq!(s.strategy, pentest::PentestStrategy::Quick);
|
||||
assert_eq!(s.tool_invocations, 0);
|
||||
assert_eq!(s.tool_successes, 0);
|
||||
assert_eq!(s.findings_count, 0);
|
||||
assert!(s.completed_at.is_none());
|
||||
assert!(s.repo_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_session_success_rate_zero_invocations() {
|
||||
let s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Comprehensive);
|
||||
assert_eq!(s.success_rate(), 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_session_success_rate_calculation() {
|
||||
let mut s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Comprehensive);
|
||||
s.tool_invocations = 10;
|
||||
s.tool_successes = 7;
|
||||
assert!((s.success_rate() - 70.0).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_session_success_rate_all_success() {
|
||||
let mut s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
|
||||
s.tool_invocations = 5;
|
||||
s.tool_successes = 5;
|
||||
assert_eq!(s.success_rate(), 100.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_session_success_rate_none_success() {
|
||||
let mut s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
|
||||
s.tool_invocations = 3;
|
||||
s.tool_successes = 0;
|
||||
assert_eq!(s.success_rate(), 0.0);
|
||||
}
|
||||
|
||||
// ─── PentestMessage factories ───
|
||||
|
||||
#[test]
|
||||
fn pentest_message_user() {
|
||||
let m = pentest::PentestMessage::user("sess1".into(), "hello".into());
|
||||
assert_eq!(m.role, "user");
|
||||
assert_eq!(m.session_id, "sess1");
|
||||
assert_eq!(m.content, "hello");
|
||||
assert!(m.attack_node_id.is_none());
|
||||
assert!(m.tool_calls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_message_assistant() {
|
||||
let m = pentest::PentestMessage::assistant("sess1".into(), "response".into());
|
||||
assert_eq!(m.role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_message_tool_result() {
|
||||
let m = pentest::PentestMessage::tool_result("sess1".into(), "output".into(), "node1".into());
|
||||
assert_eq!(m.role, "tool_result");
|
||||
assert_eq!(m.attack_node_id, Some("node1".to_string()));
|
||||
}
|
||||
|
||||
// ─── AttackChainNode ───
|
||||
|
||||
#[test]
|
||||
fn attack_chain_node_new_defaults() {
|
||||
let n = pentest::AttackChainNode::new(
|
||||
"sess1".into(),
|
||||
"node1".into(),
|
||||
"recon".into(),
|
||||
serde_json::json!({"target": "example.com"}),
|
||||
"Starting recon".into(),
|
||||
);
|
||||
assert_eq!(n.session_id, "sess1");
|
||||
assert_eq!(n.node_id, "node1");
|
||||
assert_eq!(n.tool_name, "recon");
|
||||
assert_eq!(n.status, pentest::AttackNodeStatus::Pending);
|
||||
assert!(n.parent_node_ids.is_empty());
|
||||
assert!(n.findings_produced.is_empty());
|
||||
assert!(n.risk_score.is_none());
|
||||
assert!(n.started_at.is_none());
|
||||
}
|
||||
|
||||
// ─── DastTarget ───
|
||||
|
||||
#[test]
|
||||
fn dast_target_new_defaults() {
|
||||
let t = dast::DastTarget::new(
|
||||
"My App".into(),
|
||||
"https://example.com".into(),
|
||||
dast::DastTargetType::WebApp,
|
||||
);
|
||||
assert_eq!(t.name, "My App");
|
||||
assert_eq!(t.base_url, "https://example.com");
|
||||
assert_eq!(t.target_type, dast::DastTargetType::WebApp);
|
||||
assert_eq!(t.max_crawl_depth, 5);
|
||||
assert_eq!(t.rate_limit, 10);
|
||||
assert!(!t.allow_destructive);
|
||||
assert!(t.excluded_paths.is_empty());
|
||||
assert!(t.auth_config.is_none());
|
||||
assert!(t.repo_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dast_target_type_display() {
|
||||
assert_eq!(dast::DastTargetType::WebApp.to_string(), "webapp");
|
||||
assert_eq!(dast::DastTargetType::RestApi.to_string(), "rest_api");
|
||||
assert_eq!(dast::DastTargetType::GraphQl.to_string(), "graphql");
|
||||
}
|
||||
|
||||
// ─── DastScanRun ───
|
||||
|
||||
#[test]
|
||||
fn dast_scan_run_new_defaults() {
|
||||
let sr = dast::DastScanRun::new("target1".into());
|
||||
assert_eq!(sr.status, dast::DastScanStatus::Running);
|
||||
assert_eq!(sr.current_phase, dast::DastScanPhase::Reconnaissance);
|
||||
assert!(sr.phases_completed.is_empty());
|
||||
assert_eq!(sr.endpoints_discovered, 0);
|
||||
assert_eq!(sr.findings_count, 0);
|
||||
assert!(!sr.exploitable_count > 0);
|
||||
assert!(sr.completed_at.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dast_scan_phase_display() {
|
||||
assert_eq!(
|
||||
dast::DastScanPhase::Reconnaissance.to_string(),
|
||||
"reconnaissance"
|
||||
);
|
||||
assert_eq!(dast::DastScanPhase::Crawling.to_string(), "crawling");
|
||||
assert_eq!(dast::DastScanPhase::Completed.to_string(), "completed");
|
||||
}
|
||||
|
||||
// ─── DastVulnType ───
|
||||
|
||||
#[test]
|
||||
fn dast_vuln_type_display_all_variants() {
|
||||
let cases = vec![
|
||||
(dast::DastVulnType::SqlInjection, "sql_injection"),
|
||||
(dast::DastVulnType::Xss, "xss"),
|
||||
(dast::DastVulnType::AuthBypass, "auth_bypass"),
|
||||
(dast::DastVulnType::Ssrf, "ssrf"),
|
||||
(dast::DastVulnType::Idor, "idor"),
|
||||
(dast::DastVulnType::Other, "other"),
|
||||
];
|
||||
for (variant, expected) in cases {
|
||||
assert_eq!(variant.to_string(), expected);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── DastFinding ───
|
||||
|
||||
#[test]
|
||||
fn dast_finding_new_defaults() {
|
||||
let f = dast::DastFinding::new(
|
||||
"run1".into(),
|
||||
"target1".into(),
|
||||
dast::DastVulnType::Xss,
|
||||
"XSS in search".into(),
|
||||
"Reflected XSS".into(),
|
||||
Severity::High,
|
||||
"https://example.com/search".into(),
|
||||
"GET".into(),
|
||||
);
|
||||
assert_eq!(f.vuln_type, dast::DastVulnType::Xss);
|
||||
assert_eq!(f.severity, Severity::High);
|
||||
assert!(!f.exploitable);
|
||||
assert!(f.evidence.is_empty());
|
||||
assert!(f.session_id.is_none());
|
||||
assert!(f.linked_sast_finding_id.is_none());
|
||||
}
|
||||
|
||||
// ─── SbomEntry ───
|
||||
|
||||
#[test]
|
||||
fn sbom_entry_new_defaults() {
|
||||
let e = SbomEntry::new(
|
||||
"repo1".into(),
|
||||
"lodash".into(),
|
||||
"4.17.21".into(),
|
||||
"npm".into(),
|
||||
);
|
||||
assert_eq!(e.name, "lodash");
|
||||
assert_eq!(e.version, "4.17.21");
|
||||
assert_eq!(e.package_manager, "npm");
|
||||
assert!(e.license.is_none());
|
||||
assert!(e.purl.is_none());
|
||||
assert!(e.known_vulnerabilities.is_empty());
|
||||
}
|
||||
|
||||
// ─── TrackedRepository ───
|
||||
|
||||
#[test]
|
||||
fn tracked_repository_new_defaults() {
|
||||
let r = TrackedRepository::new("My Repo".into(), "https://github.com/org/repo.git".into());
|
||||
assert_eq!(r.name, "My Repo");
|
||||
assert_eq!(r.git_url, "https://github.com/org/repo.git");
|
||||
assert_eq!(r.default_branch, "main");
|
||||
assert!(!r.webhook_enabled);
|
||||
assert!(r.webhook_secret.is_some());
|
||||
// Webhook secret should be 32 hex chars (UUID without dashes)
|
||||
assert_eq!(r.webhook_secret.as_ref().unwrap().len(), 32);
|
||||
assert!(r.tracker_type.is_none());
|
||||
assert_eq!(r.findings_count, 0);
|
||||
}
|
||||
|
||||
// ─── ScanTrigger ───
|
||||
|
||||
#[test]
|
||||
fn scan_trigger_serde_roundtrip() {
|
||||
for trigger in [
|
||||
ScanTrigger::Scheduled,
|
||||
ScanTrigger::Webhook,
|
||||
ScanTrigger::Manual,
|
||||
] {
|
||||
let json = serde_json::to_string(&trigger).unwrap();
|
||||
let back: ScanTrigger = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(trigger, back);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── PentestEvent serde (tagged enum) ───
|
||||
|
||||
#[test]
|
||||
fn pentest_event_serde_thinking() {
|
||||
let event = pentest::PentestEvent::Thinking {
|
||||
reasoning: "analyzing target".into(),
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""type":"thinking""#));
|
||||
assert!(json.contains("analyzing target"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_event_serde_finding() {
|
||||
let event = pentest::PentestEvent::Finding {
|
||||
finding_id: "f1".into(),
|
||||
title: "XSS".into(),
|
||||
severity: "high".into(),
|
||||
};
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
let back: pentest::PentestEvent = serde_json::from_str(&json).unwrap();
|
||||
match back {
|
||||
pentest::PentestEvent::Finding {
|
||||
finding_id,
|
||||
title,
|
||||
severity,
|
||||
} => {
|
||||
assert_eq!(finding_id, "f1");
|
||||
assert_eq!(title, "XSS");
|
||||
assert_eq!(severity, "high");
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Serde helpers (BSON datetime) ───
|
||||
|
||||
#[test]
|
||||
fn bson_datetime_roundtrip_via_finding() {
|
||||
let f = Finding::new(
|
||||
"repo1".into(),
|
||||
"fp".into(),
|
||||
"test".into(),
|
||||
ScanType::Sast,
|
||||
"t".into(),
|
||||
"d".into(),
|
||||
Severity::Low,
|
||||
);
|
||||
// Serialize to BSON and back
|
||||
let bson_doc = bson::to_document(&f).unwrap();
|
||||
let back: Finding = bson::from_document(bson_doc).unwrap();
|
||||
// Timestamps should survive (within 1 second tolerance due to ms precision)
|
||||
assert!((back.created_at - f.created_at).num_milliseconds().abs() < 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opt_bson_datetime_roundtrip_with_none() {
|
||||
let s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
|
||||
assert!(s.completed_at.is_none());
|
||||
let bson_doc = bson::to_document(&s).unwrap();
|
||||
let back: pentest::PentestSession = bson::from_document(bson_doc).unwrap();
|
||||
assert!(back.completed_at.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opt_bson_datetime_roundtrip_with_some() {
|
||||
let mut s = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
|
||||
s.completed_at = Some(chrono::Utc::now());
|
||||
let bson_doc = bson::to_document(&s).unwrap();
|
||||
let back: pentest::PentestSession = bson::from_document(bson_doc).unwrap();
|
||||
assert!(back.completed_at.is_some());
|
||||
}
|
||||
283
compliance-dashboard/src/components/attack_chain/helpers.rs
Normal file
283
compliance-dashboard/src/components/attack_chain/helpers.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
/// Get category CSS class from tool name
|
||||
pub(crate) fn tool_category(name: &str) -> &'static str {
|
||||
let lower = name.to_lowercase();
|
||||
if lower.contains("recon") {
|
||||
return "recon";
|
||||
}
|
||||
if lower.contains("openapi") || lower.contains("api") || lower.contains("swagger") {
|
||||
return "api";
|
||||
}
|
||||
if lower.contains("header") {
|
||||
return "headers";
|
||||
}
|
||||
if lower.contains("csp") {
|
||||
return "csp";
|
||||
}
|
||||
if lower.contains("cookie") {
|
||||
return "cookies";
|
||||
}
|
||||
if lower.contains("log") || lower.contains("console") {
|
||||
return "logs";
|
||||
}
|
||||
if lower.contains("rate") || lower.contains("limit") {
|
||||
return "ratelimit";
|
||||
}
|
||||
if lower.contains("cors") {
|
||||
return "cors";
|
||||
}
|
||||
if lower.contains("tls") || lower.contains("ssl") {
|
||||
return "tls";
|
||||
}
|
||||
if lower.contains("redirect") {
|
||||
return "redirect";
|
||||
}
|
||||
if lower.contains("dns")
|
||||
|| lower.contains("dmarc")
|
||||
|| lower.contains("email")
|
||||
|| lower.contains("spf")
|
||||
{
|
||||
return "email";
|
||||
}
|
||||
if lower.contains("auth")
|
||||
|| lower.contains("jwt")
|
||||
|| lower.contains("token")
|
||||
|| lower.contains("session")
|
||||
{
|
||||
return "auth";
|
||||
}
|
||||
if lower.contains("xss") {
|
||||
return "xss";
|
||||
}
|
||||
if lower.contains("sql") || lower.contains("sqli") {
|
||||
return "sqli";
|
||||
}
|
||||
if lower.contains("ssrf") {
|
||||
return "ssrf";
|
||||
}
|
||||
if lower.contains("idor") {
|
||||
return "idor";
|
||||
}
|
||||
if lower.contains("fuzz") {
|
||||
return "fuzzer";
|
||||
}
|
||||
if lower.contains("cve") || lower.contains("exploit") {
|
||||
return "cve";
|
||||
}
|
||||
"default"
|
||||
}
|
||||
|
||||
/// Get emoji icon from tool category
|
||||
pub(crate) fn tool_emoji(cat: &str) -> &'static str {
|
||||
match cat {
|
||||
"recon" => "\u{1F50D}",
|
||||
"api" => "\u{1F517}",
|
||||
"headers" => "\u{1F6E1}",
|
||||
"csp" => "\u{1F6A7}",
|
||||
"cookies" => "\u{1F36A}",
|
||||
"logs" => "\u{1F4DD}",
|
||||
"ratelimit" => "\u{23F1}",
|
||||
"cors" => "\u{1F30D}",
|
||||
"tls" => "\u{1F510}",
|
||||
"redirect" => "\u{21AA}",
|
||||
"email" => "\u{1F4E7}",
|
||||
"auth" => "\u{1F512}",
|
||||
"xss" => "\u{26A1}",
|
||||
"sqli" => "\u{1F489}",
|
||||
"ssrf" => "\u{1F310}",
|
||||
"idor" => "\u{1F511}",
|
||||
"fuzzer" => "\u{1F9EA}",
|
||||
"cve" => "\u{1F4A3}",
|
||||
_ => "\u{1F527}",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute display label for category
|
||||
pub(crate) fn cat_label(cat: &str) -> &'static str {
|
||||
match cat {
|
||||
"recon" => "Recon",
|
||||
"api" => "API",
|
||||
"headers" => "Headers",
|
||||
"csp" => "CSP",
|
||||
"cookies" => "Cookies",
|
||||
"logs" => "Logs",
|
||||
"ratelimit" => "Rate Limit",
|
||||
"cors" => "CORS",
|
||||
"tls" => "TLS",
|
||||
"redirect" => "Redirect",
|
||||
"email" => "Email/DNS",
|
||||
"auth" => "Auth",
|
||||
"xss" => "XSS",
|
||||
"sqli" => "SQLi",
|
||||
"ssrf" => "SSRF",
|
||||
"idor" => "IDOR",
|
||||
"fuzzer" => "Fuzzer",
|
||||
"cve" => "CVE",
|
||||
_ => "Other",
|
||||
}
|
||||
}
|
||||
|
||||
/// Phase name heuristic based on depth
|
||||
pub(crate) fn phase_name(depth: usize) -> &'static str {
|
||||
match depth {
|
||||
0 => "Reconnaissance",
|
||||
1 => "Analysis",
|
||||
2 => "Boundary Testing",
|
||||
3 => "Injection & Exploitation",
|
||||
4 => "Authentication Testing",
|
||||
5 => "Validation",
|
||||
6 => "Deep Scan",
|
||||
_ => "Final",
|
||||
}
|
||||
}
|
||||
|
||||
/// Short label for phase rail
|
||||
pub(crate) fn phase_short_name(depth: usize) -> &'static str {
|
||||
match depth {
|
||||
0 => "Recon",
|
||||
1 => "Analysis",
|
||||
2 => "Boundary",
|
||||
3 => "Exploit",
|
||||
4 => "Auth",
|
||||
5 => "Validate",
|
||||
6 => "Deep",
|
||||
_ => "Final",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute BFS phases from attack chain nodes
|
||||
pub(crate) fn compute_phases(steps: &[serde_json::Value]) -> Vec<Vec<usize>> {
|
||||
let node_ids: Vec<String> = steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.get("node_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let id_to_idx: HashMap<String, usize> = node_ids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, id)| (id.clone(), i))
|
||||
.collect();
|
||||
|
||||
// Compute depth via BFS
|
||||
let mut depths = vec![usize::MAX; steps.len()];
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
// Root nodes: those with no parents or parents not in the set
|
||||
for (i, step) in steps.iter().enumerate() {
|
||||
let parents = step
|
||||
.get("parent_node_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|p| p.as_str())
|
||||
.filter(|p| id_to_idx.contains_key(*p))
|
||||
.count()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
if parents == 0 {
|
||||
depths[i] = 0;
|
||||
queue.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// BFS to compute min depth
|
||||
while let Some(idx) = queue.pop_front() {
|
||||
let current_depth = depths[idx];
|
||||
let node_id = &node_ids[idx];
|
||||
// Find children: nodes that list this node as a parent
|
||||
for (j, step) in steps.iter().enumerate() {
|
||||
if depths[j] <= current_depth + 1 {
|
||||
continue;
|
||||
}
|
||||
let is_child = step
|
||||
.get("parent_node_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().any(|p| p.as_str() == Some(node_id.as_str())))
|
||||
.unwrap_or(false);
|
||||
if is_child {
|
||||
depths[j] = current_depth + 1;
|
||||
queue.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle unreachable nodes
|
||||
for d in depths.iter_mut() {
|
||||
if *d == usize::MAX {
|
||||
*d = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Group by depth
|
||||
let max_depth = depths.iter().copied().max().unwrap_or(0);
|
||||
let mut phases: Vec<Vec<usize>> = Vec::new();
|
||||
for d in 0..=max_depth {
|
||||
let indices: Vec<usize> = depths
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &dep)| dep == d)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
if !indices.is_empty() {
|
||||
phases.push(indices);
|
||||
}
|
||||
}
|
||||
phases
|
||||
}
|
||||
|
||||
/// Format BSON datetime to readable string
|
||||
pub(crate) fn format_bson_time(val: &serde_json::Value) -> String {
|
||||
// Handle BSON {"$date":{"$numberLong":"..."}}
|
||||
if let Some(date_obj) = val.get("$date") {
|
||||
if let Some(ms_str) = date_obj.get("$numberLong").and_then(|v| v.as_str()) {
|
||||
if let Ok(ms) = ms_str.parse::<i64>() {
|
||||
let secs = ms / 1000;
|
||||
let h = (secs / 3600) % 24;
|
||||
let m = (secs / 60) % 60;
|
||||
let s = secs % 60;
|
||||
return format!("{h:02}:{m:02}:{s:02}");
|
||||
}
|
||||
}
|
||||
// Handle {"$date": "2025-..."}
|
||||
if let Some(s) = date_obj.as_str() {
|
||||
return s.to_string();
|
||||
}
|
||||
}
|
||||
// Handle plain string
|
||||
if let Some(s) = val.as_str() {
|
||||
return s.to_string();
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
/// Compute duration string from started_at and completed_at
|
||||
pub(crate) fn compute_duration(step: &serde_json::Value) -> String {
|
||||
let extract_ms = |val: &serde_json::Value| -> Option<i64> {
|
||||
val.get("$date")?
|
||||
.get("$numberLong")?
|
||||
.as_str()?
|
||||
.parse::<i64>()
|
||||
.ok()
|
||||
};
|
||||
|
||||
let started = step.get("started_at").and_then(extract_ms);
|
||||
let completed = step.get("completed_at").and_then(extract_ms);
|
||||
|
||||
match (started, completed) {
|
||||
(Some(s), Some(c)) => {
|
||||
let diff_ms = c - s;
|
||||
if diff_ms < 1000 {
|
||||
format!("{}ms", diff_ms)
|
||||
} else {
|
||||
format!("{:.1}s", diff_ms as f64 / 1000.0)
|
||||
}
|
||||
}
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
4
compliance-dashboard/src/components/attack_chain/mod.rs
Normal file
4
compliance-dashboard/src/components/attack_chain/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
pub mod helpers;
|
||||
mod view;
|
||||
|
||||
pub use view::AttackChainView;
|
||||
363
compliance-dashboard/src/components/attack_chain/view.rs
Normal file
363
compliance-dashboard/src/components/attack_chain/view.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
use dioxus::prelude::*;
|
||||
|
||||
use super::helpers::*;
|
||||
|
||||
/// (phase_index, steps, findings_count, has_failed, has_running, all_done)
|
||||
type PhaseData<'a> = (usize, Vec<&'a serde_json::Value>, usize, bool, bool, bool);
|
||||
|
||||
#[component]
|
||||
pub fn AttackChainView(
|
||||
steps: Vec<serde_json::Value>,
|
||||
is_running: bool,
|
||||
session_findings: usize,
|
||||
session_tool_invocations: usize,
|
||||
session_success_rate: f64,
|
||||
) -> Element {
|
||||
let phases = compute_phases(&steps);
|
||||
|
||||
// Compute KPIs — prefer session-level stats, fall back to node-level
|
||||
let total_tools = steps.len();
|
||||
let node_findings: usize = steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.get("findings_produced")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.sum();
|
||||
// Use session-level findings count if nodes don't have findings linked
|
||||
let total_findings = if node_findings > 0 {
|
||||
node_findings
|
||||
} else {
|
||||
session_findings
|
||||
};
|
||||
|
||||
let completed_count = steps
|
||||
.iter()
|
||||
.filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("completed"))
|
||||
.count();
|
||||
let failed_count = steps
|
||||
.iter()
|
||||
.filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"))
|
||||
.count();
|
||||
let finished = completed_count + failed_count;
|
||||
let success_pct = if finished == 0 {
|
||||
100
|
||||
} else {
|
||||
(completed_count * 100) / finished
|
||||
};
|
||||
let max_risk: u8 = steps
|
||||
.iter()
|
||||
.filter_map(|s| s.get("risk_score").and_then(|v| v.as_u64()))
|
||||
.map(|v| v as u8)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let progress_pct = if total_tools == 0 {
|
||||
0
|
||||
} else {
|
||||
((completed_count + failed_count) * 100) / total_tools
|
||||
};
|
||||
|
||||
// Build phase data for rail and accordion
|
||||
let phase_data: Vec<PhaseData<'_>> = phases
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(pi, indices)| {
|
||||
let phase_steps: Vec<&serde_json::Value> = indices.iter().map(|&i| &steps[i]).collect();
|
||||
let phase_findings: usize = phase_steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.get("findings_produced")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.sum();
|
||||
let has_failed = phase_steps
|
||||
.iter()
|
||||
.any(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"));
|
||||
let has_running = phase_steps
|
||||
.iter()
|
||||
.any(|s| s.get("status").and_then(|v| v.as_str()) == Some("running"));
|
||||
let all_done = phase_steps.iter().all(|s| {
|
||||
let st = s.get("status").and_then(|v| v.as_str()).unwrap_or("");
|
||||
st == "completed" || st == "failed" || st == "skipped"
|
||||
});
|
||||
(
|
||||
pi,
|
||||
phase_steps,
|
||||
phase_findings,
|
||||
has_failed,
|
||||
has_running,
|
||||
all_done,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut active_rail = use_signal(|| 0usize);
|
||||
|
||||
rsx! {
|
||||
// KPI bar
|
||||
div { class: "ac-kpi-bar",
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--text-primary);", "{total_tools}" }
|
||||
div { class: "ac-kpi-label", "Tools Run" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--danger, #dc2626);", "{total_findings}" }
|
||||
div { class: "ac-kpi-label", "Findings" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--success, #16a34a);", "{success_pct}%" }
|
||||
div { class: "ac-kpi-label", "Success Rate" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--warning, #d97706);", "{max_risk}" }
|
||||
div { class: "ac-kpi-label", "Max Risk" }
|
||||
}
|
||||
}
|
||||
|
||||
// Phase rail
|
||||
div { class: "ac-phase-rail",
|
||||
for (pi, (_phase_idx, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
|
||||
{
|
||||
if pi > 0 {
|
||||
let prev_done = phase_data.get(pi - 1).map(|p| p.5).unwrap_or(false);
|
||||
let bar_class = if prev_done && *all_done {
|
||||
"done"
|
||||
} else if prev_done {
|
||||
"running"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
rsx! {
|
||||
div { class: "ac-rail-bar",
|
||||
div { class: "ac-rail-bar-inner {bar_class}" }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {}
|
||||
}
|
||||
}
|
||||
{
|
||||
let dot_class = if *has_running {
|
||||
"running"
|
||||
} else if *has_failed && *all_done {
|
||||
"mixed"
|
||||
} else if *all_done {
|
||||
"done"
|
||||
} else {
|
||||
"pending"
|
||||
};
|
||||
let is_active = *active_rail.read() == pi;
|
||||
let active_cls = if is_active { " active" } else { "" };
|
||||
let findings_cls = if *phase_findings > 0 { "has" } else { "none" };
|
||||
let findings_text = if *phase_findings > 0 {
|
||||
format!("{phase_findings}")
|
||||
} else {
|
||||
"\u{2014}".to_string()
|
||||
};
|
||||
let short = phase_short_name(pi);
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-rail-node{active_cls}",
|
||||
onclick: move |_| {
|
||||
active_rail.set(pi);
|
||||
let js = format!(
|
||||
"document.getElementById('ac-phase-{pi}')?.scrollIntoView({{behavior:'smooth',block:'nearest'}});document.getElementById('ac-phase-{pi}')?.classList.add('open');"
|
||||
);
|
||||
document::eval(&js);
|
||||
},
|
||||
div { class: "ac-rail-dot {dot_class}" }
|
||||
div { class: "ac-rail-label", "{short}" }
|
||||
div { class: "ac-rail-findings {findings_cls}", "{findings_text}" }
|
||||
div { class: "ac-rail-heatmap",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
|
||||
let hm_cls = match st {
|
||||
"completed" => "ok",
|
||||
"failed" => "fail",
|
||||
"running" => "run",
|
||||
_ => "wait",
|
||||
};
|
||||
rsx! { div { class: "ac-hm-cell {hm_cls}" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Progress bar
|
||||
div { class: "ac-progress-track",
|
||||
div { class: "ac-progress-fill", style: "width: {progress_pct}%;" }
|
||||
}
|
||||
|
||||
// Expand all
|
||||
div { class: "ac-controls",
|
||||
button {
|
||||
class: "ac-btn-toggle",
|
||||
onclick: move |_| {
|
||||
document::eval(
|
||||
"document.querySelectorAll('.ac-phase').forEach(p => p.classList.toggle('open', !document.querySelector('.ac-phase.open') || !document.querySelectorAll('.ac-phase:not(.open)').length === 0));(function(){var ps=document.querySelectorAll('.ac-phase');var allOpen=Array.from(ps).every(p=>p.classList.contains('open'));ps.forEach(p=>{if(allOpen)p.classList.remove('open');else p.classList.add('open');});})();"
|
||||
);
|
||||
},
|
||||
"Expand all"
|
||||
}
|
||||
}
|
||||
|
||||
// Phase accordion
|
||||
div { class: "ac-phases",
|
||||
for (pi, (_, phase_steps, phase_findings, _has_failed, has_running, _all_done)) in phase_data.iter().enumerate() {
|
||||
{
|
||||
let open_cls = if pi == 0 { " open" } else { "" };
|
||||
let phase_label = phase_name(pi);
|
||||
let tool_count = phase_steps.len();
|
||||
let meta_text = if *has_running {
|
||||
"in progress".to_string()
|
||||
} else {
|
||||
format!("{phase_findings} findings")
|
||||
};
|
||||
let meta_cls = if *has_running { "running-ct" } else { "findings-ct" };
|
||||
let phase_num_label = format!("PHASE {}", pi + 1);
|
||||
let phase_el_id = format!("ac-phase-{pi}");
|
||||
let phase_el_id2 = phase_el_id.clone();
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-phase{open_cls}",
|
||||
id: "{phase_el_id}",
|
||||
div {
|
||||
class: "ac-phase-header",
|
||||
onclick: move |_| {
|
||||
let js = format!("document.getElementById('{phase_el_id2}').classList.toggle('open');");
|
||||
document::eval(&js);
|
||||
},
|
||||
span { class: "ac-phase-num", "{phase_num_label}" }
|
||||
span { class: "ac-phase-title", "{phase_label}" }
|
||||
div { class: "ac-phase-dots",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
|
||||
rsx! { div { class: "ac-phase-dot {st}" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
div { class: "ac-phase-meta",
|
||||
span { "{tool_count} tools" }
|
||||
span { class: "{meta_cls}", "{meta_text}" }
|
||||
}
|
||||
span { class: "ac-phase-chevron", "\u{25B8}" }
|
||||
}
|
||||
div { class: "ac-phase-body",
|
||||
div { class: "ac-phase-body-inner",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let tool_name_val = step.get("tool_name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
|
||||
let status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
|
||||
let cat = tool_category(&tool_name_val);
|
||||
let emoji = tool_emoji(cat);
|
||||
let label = cat_label(cat);
|
||||
let findings_n = step.get("findings_produced").and_then(|v| v.as_array()).map(|a| a.len()).unwrap_or(0);
|
||||
let risk = step.get("risk_score").and_then(|v| v.as_u64()).map(|v| v as u8);
|
||||
let reasoning = step.get("llm_reasoning").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let duration = compute_duration(step);
|
||||
let started = step.get("started_at").map(format_bson_time).unwrap_or_default();
|
||||
|
||||
let is_pending = status == "pending";
|
||||
let pending_cls = if is_pending { " is-pending" } else { "" };
|
||||
|
||||
let duration_cls = if status == "running" { "ac-tool-duration running-text" } else { "ac-tool-duration" };
|
||||
let duration_text = if status == "running" {
|
||||
"running\u{2026}".to_string()
|
||||
} else if duration.is_empty() {
|
||||
"\u{2014}".to_string()
|
||||
} else {
|
||||
duration
|
||||
};
|
||||
|
||||
let pill_cls = if findings_n > 0 { "ac-findings-pill has" } else { "ac-findings-pill zero" };
|
||||
let pill_text = if findings_n > 0 { format!("{findings_n}") } else { "\u{2014}".to_string() };
|
||||
|
||||
let (risk_cls, risk_text) = match risk {
|
||||
Some(r) if r >= 75 => ("ac-risk-val high", format!("{r}")),
|
||||
Some(r) if r >= 40 => ("ac-risk-val medium", format!("{r}")),
|
||||
Some(r) => ("ac-risk-val low", format!("{r}")),
|
||||
None => ("ac-risk-val none", "\u{2014}".to_string()),
|
||||
};
|
||||
|
||||
let node_id = step.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let detail_id = format!("ac-detail-{node_id}");
|
||||
let row_id = format!("ac-row-{node_id}");
|
||||
let detail_id_clone = detail_id.clone();
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-tool-row{pending_cls}",
|
||||
id: "{row_id}",
|
||||
onclick: move |_| {
|
||||
if is_pending { return; }
|
||||
let js = format!(
|
||||
"(function(){{var r=document.getElementById('{row_id}');var d=document.getElementById('{detail_id}');if(r.classList.contains('expanded')){{r.classList.remove('expanded');d.classList.remove('open');}}else{{r.classList.add('expanded');d.classList.add('open');}}}})()"
|
||||
);
|
||||
document::eval(&js);
|
||||
},
|
||||
div { class: "ac-status-bar {status}" }
|
||||
div { class: "ac-tool-icon", "{emoji}" }
|
||||
div { class: "ac-tool-info",
|
||||
div { class: "ac-tool-name", "{tool_name_val}" }
|
||||
span { class: "ac-cat-chip {cat}", "{label}" }
|
||||
}
|
||||
div { class: "{duration_cls}", "{duration_text}" }
|
||||
div { span { class: "{pill_cls}", "{pill_text}" } }
|
||||
div { class: "{risk_cls}", "{risk_text}" }
|
||||
}
|
||||
div {
|
||||
class: "ac-tool-detail",
|
||||
id: "{detail_id_clone}",
|
||||
if !reasoning.is_empty() || !started.is_empty() {
|
||||
div { class: "ac-tool-detail-inner",
|
||||
if !reasoning.is_empty() {
|
||||
div { class: "ac-reasoning-block", "{reasoning}" }
|
||||
}
|
||||
if !started.is_empty() {
|
||||
div { class: "ac-detail-grid",
|
||||
span { class: "ac-detail-label", "Started" }
|
||||
span { class: "ac-detail-value", "{started}" }
|
||||
if !duration_text.is_empty() && status != "running" && duration_text != "\u{2014}" {
|
||||
span { class: "ac-detail-label", "Duration" }
|
||||
span { class: "ac-detail-value", "{duration_text}" }
|
||||
}
|
||||
span { class: "ac-detail-label", "Status" }
|
||||
if status == "completed" {
|
||||
span { class: "ac-detail-value", style: "color: var(--success, #16a34a);", "Completed" }
|
||||
} else if status == "failed" {
|
||||
span { class: "ac-detail-value", style: "color: var(--danger, #dc2626);", "Failed" }
|
||||
} else if status == "running" {
|
||||
span { class: "ac-detail-value", style: "color: var(--warning, #d97706);", "Running" }
|
||||
} else {
|
||||
span { class: "ac-detail-value", "{status}" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod app_shell;
|
||||
pub mod attack_chain;
|
||||
pub mod code_inspector;
|
||||
pub mod code_snippet;
|
||||
pub mod file_tree;
|
||||
|
||||
@@ -101,11 +101,18 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
|
||||
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
|
||||
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
|
||||
for t in targets {
|
||||
let t_id = t.get("_id").and_then(|v| v.get("$oid")).and_then(|v| v.as_str()).unwrap_or("");
|
||||
let t_id = t
|
||||
.get("_id")
|
||||
.and_then(|v| v.get("$oid"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
if t_id == tid {
|
||||
if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
|
||||
body.data.as_object_mut().map(|obj| {
|
||||
obj.insert("target_name".to_string(), serde_json::Value::String(name.to_string()))
|
||||
obj.insert(
|
||||
"target_name".to_string(),
|
||||
serde_json::Value::String(name.to_string()),
|
||||
)
|
||||
});
|
||||
}
|
||||
break;
|
||||
@@ -155,9 +162,7 @@ pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError
|
||||
}
|
||||
|
||||
#[server]
|
||||
pub async fn fetch_attack_chain(
|
||||
session_id: String,
|
||||
) -> Result<AttackChainResponse, ServerFnError> {
|
||||
pub async fn fetch_attack_chain(session_id: String) -> Result<AttackChainResponse, ServerFnError> {
|
||||
let state: super::server_state::ServerState =
|
||||
dioxus_fullstack::FullstackContext::extract().await?;
|
||||
let url = format!(
|
||||
|
||||
@@ -116,7 +116,10 @@ pub fn PentestDashboardPage() -> Element {
|
||||
_ => serde_json::Value::Null,
|
||||
}
|
||||
};
|
||||
let severity_critical = sev_dist.get("critical").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let severity_critical = sev_dist
|
||||
.get("critical")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0);
|
||||
let severity_high = sev_dist.get("high").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let severity_medium = sev_dist.get("medium").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
let severity_low = sev_dist.get("low").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
|
||||
use dioxus::prelude::*;
|
||||
use dioxus_free_icons::icons::bs_icons::*;
|
||||
use dioxus_free_icons::Icon;
|
||||
|
||||
use crate::app::Route;
|
||||
use crate::components::attack_chain::AttackChainView;
|
||||
use crate::components::severity_badge::SeverityBadge;
|
||||
use crate::infrastructure::pentest::{
|
||||
export_pentest_report, fetch_attack_chain, fetch_pentest_findings, fetch_pentest_session,
|
||||
@@ -115,9 +114,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
let list = &data.data;
|
||||
let c = list
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.get("severity").and_then(|v| v.as_str()) == Some("critical")
|
||||
})
|
||||
.filter(|f| f.get("severity").and_then(|v| v.as_str()) == Some("critical"))
|
||||
.count();
|
||||
let h = list
|
||||
.iter()
|
||||
@@ -125,9 +122,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
.count();
|
||||
let m = list
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.get("severity").and_then(|v| v.as_str()) == Some("medium")
|
||||
})
|
||||
.filter(|f| f.get("severity").and_then(|v| v.as_str()) == Some("medium"))
|
||||
.count();
|
||||
let l = list
|
||||
.iter()
|
||||
@@ -140,7 +135,9 @@ pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
let e = list
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
f.get("exploitable").and_then(|v| v.as_bool()).unwrap_or(false)
|
||||
f.get("exploitable")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.count();
|
||||
(c, h, m, l, i, e)
|
||||
@@ -171,14 +168,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
let sid = sid_for_export.clone();
|
||||
spawn(async move {
|
||||
// TODO: get real user info from auth context
|
||||
match export_pentest_report(
|
||||
sid.clone(),
|
||||
pw,
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
match export_pentest_report(sid.clone(), pw, String::new(), String::new()).await {
|
||||
Ok(resp) => {
|
||||
export_sha256.set(Some(resp.sha256.clone()));
|
||||
// Trigger download via JS
|
||||
@@ -556,586 +546,3 @@ pub fn PentestSessionPage(session_id: String) -> Element {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// Attack Chain Visualization Component
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/// Get category CSS class from tool name
|
||||
fn tool_category(name: &str) -> &'static str {
|
||||
let lower = name.to_lowercase();
|
||||
if lower.contains("recon") { return "recon"; }
|
||||
if lower.contains("openapi") || lower.contains("api") || lower.contains("swagger") { return "api"; }
|
||||
if lower.contains("header") { return "headers"; }
|
||||
if lower.contains("csp") { return "csp"; }
|
||||
if lower.contains("cookie") { return "cookies"; }
|
||||
if lower.contains("log") || lower.contains("console") { return "logs"; }
|
||||
if lower.contains("rate") || lower.contains("limit") { return "ratelimit"; }
|
||||
if lower.contains("cors") { return "cors"; }
|
||||
if lower.contains("tls") || lower.contains("ssl") { return "tls"; }
|
||||
if lower.contains("redirect") { return "redirect"; }
|
||||
if lower.contains("dns") || lower.contains("dmarc") || lower.contains("email") || lower.contains("spf") { return "email"; }
|
||||
if lower.contains("auth") || lower.contains("jwt") || lower.contains("token") || lower.contains("session") { return "auth"; }
|
||||
if lower.contains("xss") { return "xss"; }
|
||||
if lower.contains("sql") || lower.contains("sqli") { return "sqli"; }
|
||||
if lower.contains("ssrf") { return "ssrf"; }
|
||||
if lower.contains("idor") { return "idor"; }
|
||||
if lower.contains("fuzz") { return "fuzzer"; }
|
||||
if lower.contains("cve") || lower.contains("exploit") { return "cve"; }
|
||||
"default"
|
||||
}
|
||||
|
||||
/// Get emoji icon from tool category
|
||||
fn tool_emoji(cat: &str) -> &'static str {
|
||||
match cat {
|
||||
"recon" => "\u{1F50D}",
|
||||
"api" => "\u{1F517}",
|
||||
"headers" => "\u{1F6E1}",
|
||||
"csp" => "\u{1F6A7}",
|
||||
"cookies" => "\u{1F36A}",
|
||||
"logs" => "\u{1F4DD}",
|
||||
"ratelimit" => "\u{23F1}",
|
||||
"cors" => "\u{1F30D}",
|
||||
"tls" => "\u{1F510}",
|
||||
"redirect" => "\u{21AA}",
|
||||
"email" => "\u{1F4E7}",
|
||||
"auth" => "\u{1F512}",
|
||||
"xss" => "\u{26A1}",
|
||||
"sqli" => "\u{1F489}",
|
||||
"ssrf" => "\u{1F310}",
|
||||
"idor" => "\u{1F511}",
|
||||
"fuzzer" => "\u{1F9EA}",
|
||||
"cve" => "\u{1F4A3}",
|
||||
_ => "\u{1F527}",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute display label for category
|
||||
fn cat_label(cat: &str) -> &'static str {
|
||||
match cat {
|
||||
"recon" => "Recon",
|
||||
"api" => "API",
|
||||
"headers" => "Headers",
|
||||
"csp" => "CSP",
|
||||
"cookies" => "Cookies",
|
||||
"logs" => "Logs",
|
||||
"ratelimit" => "Rate Limit",
|
||||
"cors" => "CORS",
|
||||
"tls" => "TLS",
|
||||
"redirect" => "Redirect",
|
||||
"email" => "Email/DNS",
|
||||
"auth" => "Auth",
|
||||
"xss" => "XSS",
|
||||
"sqli" => "SQLi",
|
||||
"ssrf" => "SSRF",
|
||||
"idor" => "IDOR",
|
||||
"fuzzer" => "Fuzzer",
|
||||
"cve" => "CVE",
|
||||
_ => "Other",
|
||||
}
|
||||
}
|
||||
|
||||
/// Phase name heuristic based on depth
|
||||
fn phase_name(depth: usize) -> &'static str {
|
||||
match depth {
|
||||
0 => "Reconnaissance",
|
||||
1 => "Analysis",
|
||||
2 => "Boundary Testing",
|
||||
3 => "Injection & Exploitation",
|
||||
4 => "Authentication Testing",
|
||||
5 => "Validation",
|
||||
6 => "Deep Scan",
|
||||
_ => "Final",
|
||||
}
|
||||
}
|
||||
|
||||
/// Short label for phase rail
|
||||
fn phase_short_name(depth: usize) -> &'static str {
|
||||
match depth {
|
||||
0 => "Recon",
|
||||
1 => "Analysis",
|
||||
2 => "Boundary",
|
||||
3 => "Exploit",
|
||||
4 => "Auth",
|
||||
5 => "Validate",
|
||||
6 => "Deep",
|
||||
_ => "Final",
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute BFS phases from attack chain nodes
|
||||
fn compute_phases(steps: &[serde_json::Value]) -> Vec<Vec<usize>> {
|
||||
let node_ids: Vec<String> = steps
|
||||
.iter()
|
||||
.map(|s| s.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string())
|
||||
.collect();
|
||||
|
||||
let id_to_idx: HashMap<String, usize> = node_ids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, id)| (id.clone(), i))
|
||||
.collect();
|
||||
|
||||
// Compute depth via BFS
|
||||
let mut depths = vec![usize::MAX; steps.len()];
|
||||
let mut queue = VecDeque::new();
|
||||
|
||||
// Root nodes: those with no parents or parents not in the set
|
||||
for (i, step) in steps.iter().enumerate() {
|
||||
let parents = step
|
||||
.get("parent_node_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|p| p.as_str())
|
||||
.filter(|p| id_to_idx.contains_key(*p))
|
||||
.count()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
if parents == 0 {
|
||||
depths[i] = 0;
|
||||
queue.push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// BFS to compute min depth
|
||||
while let Some(idx) = queue.pop_front() {
|
||||
let current_depth = depths[idx];
|
||||
let node_id = &node_ids[idx];
|
||||
// Find children: nodes that list this node as a parent
|
||||
for (j, step) in steps.iter().enumerate() {
|
||||
if depths[j] <= current_depth + 1 {
|
||||
continue;
|
||||
}
|
||||
let is_child = step
|
||||
.get("parent_node_ids")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| arr.iter().any(|p| p.as_str() == Some(node_id.as_str())))
|
||||
.unwrap_or(false);
|
||||
if is_child {
|
||||
depths[j] = current_depth + 1;
|
||||
queue.push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle unreachable nodes
|
||||
for d in depths.iter_mut() {
|
||||
if *d == usize::MAX {
|
||||
*d = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Group by depth
|
||||
let max_depth = depths.iter().copied().max().unwrap_or(0);
|
||||
let mut phases: Vec<Vec<usize>> = Vec::new();
|
||||
for d in 0..=max_depth {
|
||||
let indices: Vec<usize> = depths
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &dep)| dep == d)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
if !indices.is_empty() {
|
||||
phases.push(indices);
|
||||
}
|
||||
}
|
||||
phases
|
||||
}
|
||||
|
||||
/// Format BSON datetime to readable string
|
||||
fn format_bson_time(val: &serde_json::Value) -> String {
|
||||
// Handle BSON {"$date":{"$numberLong":"..."}}
|
||||
if let Some(date_obj) = val.get("$date") {
|
||||
if let Some(ms_str) = date_obj.get("$numberLong").and_then(|v| v.as_str()) {
|
||||
if let Ok(ms) = ms_str.parse::<i64>() {
|
||||
let secs = ms / 1000;
|
||||
let h = (secs / 3600) % 24;
|
||||
let m = (secs / 60) % 60;
|
||||
let s = secs % 60;
|
||||
return format!("{h:02}:{m:02}:{s:02}");
|
||||
}
|
||||
}
|
||||
// Handle {"$date": "2025-..."}
|
||||
if let Some(s) = date_obj.as_str() {
|
||||
return s.to_string();
|
||||
}
|
||||
}
|
||||
// Handle plain string
|
||||
if let Some(s) = val.as_str() {
|
||||
return s.to_string();
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
/// Compute duration string from started_at and completed_at
|
||||
fn compute_duration(step: &serde_json::Value) -> String {
|
||||
let extract_ms = |val: &serde_json::Value| -> Option<i64> {
|
||||
val.get("$date")?
|
||||
.get("$numberLong")?
|
||||
.as_str()?
|
||||
.parse::<i64>()
|
||||
.ok()
|
||||
};
|
||||
|
||||
let started = step.get("started_at").and_then(extract_ms);
|
||||
let completed = step.get("completed_at").and_then(extract_ms);
|
||||
|
||||
match (started, completed) {
|
||||
(Some(s), Some(c)) => {
|
||||
let diff_ms = c - s;
|
||||
if diff_ms < 1000 {
|
||||
format!("{}ms", diff_ms)
|
||||
} else {
|
||||
format!("{:.1}s", diff_ms as f64 / 1000.0)
|
||||
}
|
||||
}
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[component]
|
||||
fn AttackChainView(
|
||||
steps: Vec<serde_json::Value>,
|
||||
is_running: bool,
|
||||
session_findings: usize,
|
||||
session_tool_invocations: usize,
|
||||
session_success_rate: f64,
|
||||
) -> Element {
|
||||
let phases = compute_phases(&steps);
|
||||
|
||||
// Compute KPIs — prefer session-level stats, fall back to node-level
|
||||
let total_tools = steps.len();
|
||||
let node_findings: usize = steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.get("findings_produced")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.sum();
|
||||
// Use session-level findings count if nodes don't have findings linked
|
||||
let total_findings = if node_findings > 0 { node_findings } else { session_findings };
|
||||
|
||||
let completed_count = steps
|
||||
.iter()
|
||||
.filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("completed"))
|
||||
.count();
|
||||
let failed_count = steps
|
||||
.iter()
|
||||
.filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"))
|
||||
.count();
|
||||
let finished = completed_count + failed_count;
|
||||
let success_pct = if finished == 0 {
|
||||
100
|
||||
} else {
|
||||
(completed_count * 100) / finished
|
||||
};
|
||||
let max_risk: u8 = steps
|
||||
.iter()
|
||||
.filter_map(|s| s.get("risk_score").and_then(|v| v.as_u64()))
|
||||
.map(|v| v as u8)
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let progress_pct = if total_tools == 0 {
|
||||
0
|
||||
} else {
|
||||
((completed_count + failed_count) * 100) / total_tools
|
||||
};
|
||||
|
||||
// Build phase data for rail and accordion
|
||||
let phase_data: Vec<(usize, Vec<&serde_json::Value>, usize, bool, bool, bool)> = phases
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(pi, indices)| {
|
||||
let phase_steps: Vec<&serde_json::Value> = indices.iter().map(|&i| &steps[i]).collect();
|
||||
let phase_findings: usize = phase_steps
|
||||
.iter()
|
||||
.map(|s| {
|
||||
s.get("findings_produced")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or(0)
|
||||
})
|
||||
.sum();
|
||||
let has_failed = phase_steps
|
||||
.iter()
|
||||
.any(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"));
|
||||
let has_running = phase_steps
|
||||
.iter()
|
||||
.any(|s| s.get("status").and_then(|v| v.as_str()) == Some("running"));
|
||||
let all_done = phase_steps.iter().all(|s| {
|
||||
let st = s.get("status").and_then(|v| v.as_str()).unwrap_or("");
|
||||
st == "completed" || st == "failed" || st == "skipped"
|
||||
});
|
||||
(pi, phase_steps, phase_findings, has_failed, has_running, all_done)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut active_rail = use_signal(|| 0usize);
|
||||
|
||||
rsx! {
|
||||
// KPI bar
|
||||
div { class: "ac-kpi-bar",
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--text-primary);", "{total_tools}" }
|
||||
div { class: "ac-kpi-label", "Tools Run" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--danger, #dc2626);", "{total_findings}" }
|
||||
div { class: "ac-kpi-label", "Findings" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--success, #16a34a);", "{success_pct}%" }
|
||||
div { class: "ac-kpi-label", "Success Rate" }
|
||||
}
|
||||
div { class: "ac-kpi-card",
|
||||
div { class: "ac-kpi-value", style: "color: var(--warning, #d97706);", "{max_risk}" }
|
||||
div { class: "ac-kpi-label", "Max Risk" }
|
||||
}
|
||||
}
|
||||
|
||||
// Phase rail
|
||||
div { class: "ac-phase-rail",
|
||||
for (pi, (_phase_idx, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
|
||||
{
|
||||
if pi > 0 {
|
||||
let prev_done = phase_data.get(pi - 1).map(|p| p.5).unwrap_or(false);
|
||||
let bar_class = if prev_done && *all_done {
|
||||
"done"
|
||||
} else if prev_done {
|
||||
"running"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
rsx! {
|
||||
div { class: "ac-rail-bar",
|
||||
div { class: "ac-rail-bar-inner {bar_class}" }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rsx! {}
|
||||
}
|
||||
}
|
||||
{
|
||||
let dot_class = if *has_running {
|
||||
"running"
|
||||
} else if *has_failed && *all_done {
|
||||
"mixed"
|
||||
} else if *all_done {
|
||||
"done"
|
||||
} else {
|
||||
"pending"
|
||||
};
|
||||
let is_active = *active_rail.read() == pi;
|
||||
let active_cls = if is_active { " active" } else { "" };
|
||||
let findings_cls = if *phase_findings > 0 { "has" } else { "none" };
|
||||
let findings_text = if *phase_findings > 0 {
|
||||
format!("{phase_findings}")
|
||||
} else {
|
||||
"\u{2014}".to_string()
|
||||
};
|
||||
let short = phase_short_name(pi);
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-rail-node{active_cls}",
|
||||
onclick: move |_| {
|
||||
active_rail.set(pi);
|
||||
let js = format!(
|
||||
"document.getElementById('ac-phase-{pi}')?.scrollIntoView({{behavior:'smooth',block:'nearest'}});document.getElementById('ac-phase-{pi}')?.classList.add('open');"
|
||||
);
|
||||
document::eval(&js);
|
||||
},
|
||||
div { class: "ac-rail-dot {dot_class}" }
|
||||
div { class: "ac-rail-label", "{short}" }
|
||||
div { class: "ac-rail-findings {findings_cls}", "{findings_text}" }
|
||||
div { class: "ac-rail-heatmap",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
|
||||
let hm_cls = match st {
|
||||
"completed" => "ok",
|
||||
"failed" => "fail",
|
||||
"running" => "run",
|
||||
_ => "wait",
|
||||
};
|
||||
rsx! { div { class: "ac-hm-cell {hm_cls}" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Progress bar
|
||||
div { class: "ac-progress-track",
|
||||
div { class: "ac-progress-fill", style: "width: {progress_pct}%;" }
|
||||
}
|
||||
|
||||
// Expand all
|
||||
div { class: "ac-controls",
|
||||
button {
|
||||
class: "ac-btn-toggle",
|
||||
onclick: move |_| {
|
||||
document::eval(
|
||||
"document.querySelectorAll('.ac-phase').forEach(p => p.classList.toggle('open', !document.querySelector('.ac-phase.open') || !document.querySelectorAll('.ac-phase:not(.open)').length === 0));(function(){var ps=document.querySelectorAll('.ac-phase');var allOpen=Array.from(ps).every(p=>p.classList.contains('open'));ps.forEach(p=>{if(allOpen)p.classList.remove('open');else p.classList.add('open');});})();"
|
||||
);
|
||||
},
|
||||
"Expand all"
|
||||
}
|
||||
}
|
||||
|
||||
// Phase accordion
|
||||
div { class: "ac-phases",
|
||||
for (pi, (_, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
|
||||
{
|
||||
let open_cls = if pi == 0 { " open" } else { "" };
|
||||
let phase_label = phase_name(pi);
|
||||
let tool_count = phase_steps.len();
|
||||
let meta_text = if *has_running {
|
||||
"in progress".to_string()
|
||||
} else {
|
||||
format!("{phase_findings} findings")
|
||||
};
|
||||
let meta_cls = if *has_running { "running-ct" } else { "findings-ct" };
|
||||
let phase_num_label = format!("PHASE {}", pi + 1);
|
||||
let phase_el_id = format!("ac-phase-{pi}");
|
||||
let phase_el_id2 = phase_el_id.clone();
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-phase{open_cls}",
|
||||
id: "{phase_el_id}",
|
||||
div {
|
||||
class: "ac-phase-header",
|
||||
onclick: move |_| {
|
||||
let js = format!("document.getElementById('{phase_el_id2}').classList.toggle('open');");
|
||||
document::eval(&js);
|
||||
},
|
||||
span { class: "ac-phase-num", "{phase_num_label}" }
|
||||
span { class: "ac-phase-title", "{phase_label}" }
|
||||
div { class: "ac-phase-dots",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
|
||||
rsx! { div { class: "ac-phase-dot {st}" } }
|
||||
}
|
||||
}
|
||||
}
|
||||
div { class: "ac-phase-meta",
|
||||
span { "{tool_count} tools" }
|
||||
span { class: "{meta_cls}", "{meta_text}" }
|
||||
}
|
||||
span { class: "ac-phase-chevron", "\u{25B8}" }
|
||||
}
|
||||
div { class: "ac-phase-body",
|
||||
div { class: "ac-phase-body-inner",
|
||||
for step in phase_steps.iter() {
|
||||
{
|
||||
let tool_name_val = step.get("tool_name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
|
||||
let status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
|
||||
let cat = tool_category(&tool_name_val);
|
||||
let emoji = tool_emoji(cat);
|
||||
let label = cat_label(cat);
|
||||
let findings_n = step.get("findings_produced").and_then(|v| v.as_array()).map(|a| a.len()).unwrap_or(0);
|
||||
let risk = step.get("risk_score").and_then(|v| v.as_u64()).map(|v| v as u8);
|
||||
let reasoning = step.get("llm_reasoning").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let duration = compute_duration(step);
|
||||
let started = step.get("started_at").map(format_bson_time).unwrap_or_default();
|
||||
|
||||
let is_pending = status == "pending";
|
||||
let pending_cls = if is_pending { " is-pending" } else { "" };
|
||||
|
||||
let duration_cls = if status == "running" { "ac-tool-duration running-text" } else { "ac-tool-duration" };
|
||||
let duration_text = if status == "running" {
|
||||
"running\u{2026}".to_string()
|
||||
} else if duration.is_empty() {
|
||||
"\u{2014}".to_string()
|
||||
} else {
|
||||
duration
|
||||
};
|
||||
|
||||
let pill_cls = if findings_n > 0 { "ac-findings-pill has" } else { "ac-findings-pill zero" };
|
||||
let pill_text = if findings_n > 0 { format!("{findings_n}") } else { "\u{2014}".to_string() };
|
||||
|
||||
let (risk_cls, risk_text) = match risk {
|
||||
Some(r) if r >= 75 => ("ac-risk-val high", format!("{r}")),
|
||||
Some(r) if r >= 40 => ("ac-risk-val medium", format!("{r}")),
|
||||
Some(r) => ("ac-risk-val low", format!("{r}")),
|
||||
None => ("ac-risk-val none", "\u{2014}".to_string()),
|
||||
};
|
||||
|
||||
let node_id = step.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let detail_id = format!("ac-detail-{node_id}");
|
||||
let row_id = format!("ac-row-{node_id}");
|
||||
let detail_id_clone = detail_id.clone();
|
||||
|
||||
rsx! {
|
||||
div {
|
||||
class: "ac-tool-row{pending_cls}",
|
||||
id: "{row_id}",
|
||||
onclick: move |_| {
|
||||
if is_pending { return; }
|
||||
let js = format!(
|
||||
"(function(){{var r=document.getElementById('{row_id}');var d=document.getElementById('{detail_id}');if(r.classList.contains('expanded')){{r.classList.remove('expanded');d.classList.remove('open');}}else{{r.classList.add('expanded');d.classList.add('open');}}}})()"
|
||||
);
|
||||
document::eval(&js);
|
||||
},
|
||||
div { class: "ac-status-bar {status}" }
|
||||
div { class: "ac-tool-icon", "{emoji}" }
|
||||
div { class: "ac-tool-info",
|
||||
div { class: "ac-tool-name", "{tool_name_val}" }
|
||||
span { class: "ac-cat-chip {cat}", "{label}" }
|
||||
}
|
||||
div { class: "{duration_cls}", "{duration_text}" }
|
||||
div { span { class: "{pill_cls}", "{pill_text}" } }
|
||||
div { class: "{risk_cls}", "{risk_text}" }
|
||||
}
|
||||
div {
|
||||
class: "ac-tool-detail",
|
||||
id: "{detail_id_clone}",
|
||||
if !reasoning.is_empty() || !started.is_empty() {
|
||||
div { class: "ac-tool-detail-inner",
|
||||
if !reasoning.is_empty() {
|
||||
div { class: "ac-reasoning-block", "{reasoning}" }
|
||||
}
|
||||
if !started.is_empty() {
|
||||
div { class: "ac-detail-grid",
|
||||
span { class: "ac-detail-label", "Started" }
|
||||
span { class: "ac-detail-value", "{started}" }
|
||||
if !duration_text.is_empty() && status != "running" && duration_text != "\u{2014}" {
|
||||
span { class: "ac-detail-label", "Duration" }
|
||||
span { class: "ac-detail-value", "{duration_text}" }
|
||||
}
|
||||
span { class: "ac-detail-label", "Status" }
|
||||
if status == "completed" {
|
||||
span { class: "ac-detail-value", style: "color: var(--success, #16a34a);", "Completed" }
|
||||
} else if status == "failed" {
|
||||
span { class: "ac-detail-value", style: "color: var(--danger, #dc2626);", "Failed" }
|
||||
} else if status == "running" {
|
||||
span { class: "ac-detail-value", style: "color: var(--warning, #d97706);", "Running" }
|
||||
} else {
|
||||
span { class: "ac-detail-value", "{status}" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::dast_agent::{
|
||||
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
|
||||
};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -7,30 +9,52 @@ use crate::agents::api_fuzzer::ApiFuzzerAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing ApiFuzzerAgent.
|
||||
pub struct ApiFuzzerTool {
|
||||
http: reqwest::Client,
|
||||
_http: reqwest::Client,
|
||||
agent: ApiFuzzerAgent,
|
||||
}
|
||||
|
||||
impl ApiFuzzerTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = ApiFuzzerAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
Self { _http: http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let url = ep
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let method = ep
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
name: p
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
location: p
|
||||
.get("location")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string(),
|
||||
param_type: p
|
||||
.get("param_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
example_value: p
|
||||
.get("example_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -38,8 +62,14 @@ impl ApiFuzzerTool {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
content_type: ep
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
requires_auth: ep
|
||||
.get("requires_auth")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -98,49 +128,51 @@ impl PentestTool for ApiFuzzerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let mut endpoints = Self::parse_endpoints(&input);
|
||||
let mut endpoints = Self::parse_endpoints(&input);
|
||||
|
||||
// If a base_url is provided but no endpoints, create a default endpoint
|
||||
if endpoints.is_empty() {
|
||||
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url: base.to_string(),
|
||||
method: "GET".to_string(),
|
||||
parameters: Vec::new(),
|
||||
content_type: None,
|
||||
requires_auth: false,
|
||||
// If a base_url is provided but no endpoints, create a default endpoint
|
||||
if endpoints.is_empty() {
|
||||
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
|
||||
endpoints.push(DiscoveredEndpoint {
|
||||
url: base.to_string(),
|
||||
method: "GET".to_string(),
|
||||
parameters: Vec::new(),
|
||||
content_type: None,
|
||||
requires_auth: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints or base_url provided to fuzz.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints or base_url provided to fuzz.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} API misconfigurations or information disclosures.")
|
||||
} else {
|
||||
"No API misconfigurations detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} API misconfigurations or information disclosures.")
|
||||
} else {
|
||||
"No API misconfigurations detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::dast_agent::{
|
||||
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
|
||||
};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -7,30 +9,52 @@ use crate::agents::auth_bypass::AuthBypassAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing AuthBypassAgent.
|
||||
pub struct AuthBypassTool {
|
||||
http: reqwest::Client,
|
||||
_http: reqwest::Client,
|
||||
agent: AuthBypassAgent,
|
||||
}
|
||||
|
||||
impl AuthBypassTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = AuthBypassAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
Self { _http: http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let url = ep
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let method = ep
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
name: p
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
location: p
|
||||
.get("location")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string(),
|
||||
param_type: p
|
||||
.get("param_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
example_value: p
|
||||
.get("example_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -38,8 +62,14 @@ impl AuthBypassTool {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
content_type: ep
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
requires_auth: ep
|
||||
.get("requires_auth")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -96,35 +126,37 @@ impl PentestTool for AuthBypassTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} authentication bypass vulnerabilities.")
|
||||
} else {
|
||||
"No authentication bypass vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} authentication bypass vulnerabilities.")
|
||||
} else {
|
||||
"No authentication bypass vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ impl ConsoleLogDetectorTool {
|
||||
}
|
||||
|
||||
let quote = html.as_bytes().get(abs_start).copied();
|
||||
let (open, close) = match quote {
|
||||
let (_open, close) = match quote {
|
||||
Some(b'"') => ('"', '"'),
|
||||
Some(b'\'') => ('\'', '\''),
|
||||
_ => {
|
||||
@@ -122,6 +122,96 @@ impl ConsoleLogDetectorTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extract_js_urls_from_html() {
|
||||
let html = r#"
|
||||
<html>
|
||||
<head>
|
||||
<script src="/static/app.js"></script>
|
||||
<script src="https://cdn.example.com/lib.js"></script>
|
||||
<script src='//cdn2.example.com/vendor.js'></script>
|
||||
</head>
|
||||
</html>
|
||||
"#;
|
||||
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
|
||||
assert_eq!(urls.len(), 3);
|
||||
assert!(urls.contains(&"https://example.com/static/app.js".to_string()));
|
||||
assert!(urls.contains(&"https://cdn.example.com/lib.js".to_string()));
|
||||
assert!(urls.contains(&"https://cdn2.example.com/vendor.js".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_js_urls_no_scripts() {
|
||||
let html = "<html><body><p>Hello</p></body></html>";
|
||||
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
|
||||
assert!(urls.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_js_urls_filters_non_js() {
|
||||
let html = r#"<link src="/style.css"><script src="/app.js"></script>"#;
|
||||
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
|
||||
// Only .js files should be extracted
|
||||
assert_eq!(urls.len(), 1);
|
||||
assert!(urls[0].ends_with("/app.js"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_js_content_finds_console_log() {
|
||||
let js = r#"
|
||||
function init() {
|
||||
console.log("debug info");
|
||||
doStuff();
|
||||
}
|
||||
"#;
|
||||
let matches = ConsoleLogDetectorTool::scan_js_content(js, "https://example.com/app.js");
|
||||
assert_eq!(matches.len(), 1);
|
||||
assert_eq!(matches[0].pattern, "console.log");
|
||||
assert_eq!(matches[0].line_number, Some(3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_js_content_finds_multiple_patterns() {
|
||||
let js =
|
||||
"console.log('a');\nconsole.debug('b');\nconsole.error('c');\ndebugger;\nalert('x');";
|
||||
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
|
||||
assert_eq!(matches.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_js_content_skips_comments() {
|
||||
let js = "// console.log('commented out');\n* console.log('also comment');\n/* console.log('block comment') */";
|
||||
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
|
||||
assert!(matches.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_js_content_one_match_per_line() {
|
||||
let js = "console.log('a'); console.debug('b');";
|
||||
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
|
||||
// Only one match per line
|
||||
assert_eq!(matches.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scan_js_content_empty_input() {
|
||||
let matches = ConsoleLogDetectorTool::scan_js_content("", "test.js");
|
||||
assert!(matches.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patterns_list_is_not_empty() {
|
||||
let patterns = ConsoleLogDetectorTool::patterns();
|
||||
assert!(patterns.len() >= 8);
|
||||
assert!(patterns.contains(&"console.log("));
|
||||
assert!(patterns.contains(&"debugger;"));
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for ConsoleLogDetectorTool {
|
||||
fn name(&self) -> &str {
|
||||
"console_log_detector"
|
||||
@@ -154,173 +244,180 @@ impl PentestTool for ConsoleLogDetectorTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_js: Vec<String> = input
|
||||
.get("additional_js_urls")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Fetch the main page
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let html = response.text().await.unwrap_or_default();
|
||||
|
||||
// Scan inline scripts in the HTML
|
||||
let mut all_matches = Vec::new();
|
||||
let inline_matches = Self::scan_js_content(&html, url);
|
||||
all_matches.extend(inline_matches);
|
||||
|
||||
// Extract JS file URLs from the HTML
|
||||
let mut js_urls = Self::extract_js_urls(&html, url);
|
||||
js_urls.extend(additional_js);
|
||||
js_urls.dedup();
|
||||
|
||||
// Fetch and scan each JS file
|
||||
for js_url in &js_urls {
|
||||
match self.http.get(js_url).send().await {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
let js_content = resp.text().await.unwrap_or_default();
|
||||
// Only scan non-minified-looking files or files where we can still
|
||||
// find patterns (minifiers typically strip console calls, but not always)
|
||||
let file_matches = Self::scan_js_content(&js_content, js_url);
|
||||
all_matches.extend(file_matches);
|
||||
}
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let match_data: Vec<serde_json::Value> = all_matches
|
||||
.iter()
|
||||
.map(|m| {
|
||||
json!({
|
||||
"pattern": m.pattern,
|
||||
"file": m.file_url,
|
||||
"line": m.line_number,
|
||||
"snippet": m.line_snippet,
|
||||
let additional_js: Vec<String> = input
|
||||
.get("additional_js_urls")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.unwrap_or_default();
|
||||
|
||||
if !all_matches.is_empty() {
|
||||
// Group by file for the finding
|
||||
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
|
||||
std::collections::HashMap::new();
|
||||
for m in &all_matches {
|
||||
by_file.entry(&m.file_url).or_default().push(m);
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Fetch the main page
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let html = response.text().await.unwrap_or_default();
|
||||
|
||||
// Scan inline scripts in the HTML
|
||||
let mut all_matches = Vec::new();
|
||||
let inline_matches = Self::scan_js_content(&html, url);
|
||||
all_matches.extend(inline_matches);
|
||||
|
||||
// Extract JS file URLs from the HTML
|
||||
let mut js_urls = Self::extract_js_urls(&html, url);
|
||||
js_urls.extend(additional_js);
|
||||
js_urls.dedup();
|
||||
|
||||
// Fetch and scan each JS file
|
||||
for js_url in &js_urls {
|
||||
match self.http.get(js_url).send().await {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
let js_content = resp.text().await.unwrap_or_default();
|
||||
// Only scan non-minified-looking files or files where we can still
|
||||
// find patterns (minifiers typically strip console calls, but not always)
|
||||
let file_matches = Self::scan_js_content(&js_content, js_url);
|
||||
all_matches.extend(file_matches);
|
||||
}
|
||||
}
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
|
||||
for (file_url, matches) in &by_file {
|
||||
let pattern_summary: Vec<String> = matches
|
||||
.iter()
|
||||
.take(5)
|
||||
.map(|m| {
|
||||
format!(
|
||||
" Line {}: {} - {}",
|
||||
m.line_number.unwrap_or(0),
|
||||
m.pattern,
|
||||
if m.line_snippet.len() > 80 {
|
||||
format!("{}...", &m.line_snippet[..80])
|
||||
} else {
|
||||
m.line_snippet.clone()
|
||||
}
|
||||
)
|
||||
let mut findings = Vec::new();
|
||||
let match_data: Vec<serde_json::Value> = all_matches
|
||||
.iter()
|
||||
.map(|m| {
|
||||
json!({
|
||||
"pattern": m.pattern,
|
||||
"file": m.file_url,
|
||||
"line": m.line_number,
|
||||
"snippet": m.line_snippet,
|
||||
})
|
||||
.collect();
|
||||
})
|
||||
.collect();
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: file_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(pattern_summary.join("\n")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
if !all_matches.is_empty() {
|
||||
// Group by file for the finding
|
||||
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
|
||||
std::collections::HashMap::new();
|
||||
for m in &all_matches {
|
||||
by_file.entry(&m.file_url).or_default().push(m);
|
||||
}
|
||||
|
||||
let total = matches.len();
|
||||
let extra = if total > 5 {
|
||||
format!(" (and {} more)", total - 5)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
for (file_url, matches) in &by_file {
|
||||
let pattern_summary: Vec<String> = matches
|
||||
.iter()
|
||||
.take(5)
|
||||
.map(|m| {
|
||||
format!(
|
||||
" Line {}: {} - {}",
|
||||
m.line_number.unwrap_or(0),
|
||||
m.pattern,
|
||||
if m.line_snippet.len() > 80 {
|
||||
format!("{}...", &m.line_snippet[..80])
|
||||
} else {
|
||||
m.line_snippet.clone()
|
||||
}
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::ConsoleLogLeakage,
|
||||
format!("Console/debug statements in {}", file_url),
|
||||
format!(
|
||||
"Found {total} console/debug statements in {file_url}{extra}. \
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: file_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(pattern_summary.join("\n")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let total = matches.len();
|
||||
let extra = if total > 5 {
|
||||
format!(" (and {} more)", total - 5)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::ConsoleLogLeakage,
|
||||
format!("Console/debug statements in {}", file_url),
|
||||
format!(
|
||||
"Found {total} console/debug statements in {file_url}{extra}. \
|
||||
These can leak sensitive information such as API responses, user data, \
|
||||
or internal state to anyone with browser developer tools open."
|
||||
),
|
||||
Severity::Low,
|
||||
file_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-532".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove console.log/debug/error statements from production code. \
|
||||
),
|
||||
Severity::Low,
|
||||
file_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-532".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove console.log/debug/error statements from production code. \
|
||||
Use a build step (e.g., babel plugin, terser) to strip console calls \
|
||||
during the production build."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let total_matches = all_matches.len();
|
||||
let count = findings.len();
|
||||
info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete");
|
||||
let total_matches = all_matches.len();
|
||||
let count = findings.len();
|
||||
info!(
|
||||
url,
|
||||
js_files = js_urls.len(),
|
||||
total_matches,
|
||||
"Console log detection complete"
|
||||
);
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if total_matches > 0 {
|
||||
format!(
|
||||
"Found {total_matches} console/debug statements across {} files.",
|
||||
count
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"No console/debug statements found in HTML or {} JS files.",
|
||||
js_urls.len()
|
||||
)
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"total_matches": total_matches,
|
||||
"js_files_scanned": js_urls.len(),
|
||||
"matches": match_data,
|
||||
}),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if total_matches > 0 {
|
||||
format!(
|
||||
"Found {total_matches} console/debug statements across {} files.",
|
||||
count
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"No console/debug statements found in HTML or {} JS files.",
|
||||
js_urls.len()
|
||||
)
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"total_matches": total_matches,
|
||||
"js_files_scanned": js_urls.len(),
|
||||
"matches": match_data,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ pub struct CookieAnalyzerTool {
|
||||
#[derive(Debug)]
|
||||
struct ParsedCookie {
|
||||
name: String,
|
||||
#[allow(dead_code)]
|
||||
value: String,
|
||||
secure: bool,
|
||||
http_only: bool,
|
||||
@@ -92,6 +93,81 @@ impl CookieAnalyzerTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_simple_cookie() {
|
||||
let cookie = CookieAnalyzerTool::parse_set_cookie("session_id=abc123");
|
||||
assert_eq!(cookie.name, "session_id");
|
||||
assert_eq!(cookie.value, "abc123");
|
||||
assert!(!cookie.secure);
|
||||
assert!(!cookie.http_only);
|
||||
assert!(cookie.same_site.is_none());
|
||||
assert!(cookie.domain.is_none());
|
||||
assert!(cookie.path.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_cookie_with_all_attributes() {
|
||||
let raw = "token=xyz; Secure; HttpOnly; SameSite=Strict; Domain=.example.com; Path=/api";
|
||||
let cookie = CookieAnalyzerTool::parse_set_cookie(raw);
|
||||
assert_eq!(cookie.name, "token");
|
||||
assert_eq!(cookie.value, "xyz");
|
||||
assert!(cookie.secure);
|
||||
assert!(cookie.http_only);
|
||||
assert_eq!(cookie.same_site.as_deref(), Some("strict"));
|
||||
assert_eq!(cookie.domain.as_deref(), Some(".example.com"));
|
||||
assert_eq!(cookie.path.as_deref(), Some("/api"));
|
||||
assert_eq!(cookie.raw, raw);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_cookie_samesite_none() {
|
||||
let cookie = CookieAnalyzerTool::parse_set_cookie("id=1; SameSite=None; Secure");
|
||||
assert_eq!(cookie.same_site.as_deref(), Some("none"));
|
||||
assert!(cookie.secure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_cookie_with_equals_in_value() {
|
||||
let cookie = CookieAnalyzerTool::parse_set_cookie("data=a=b=c; HttpOnly");
|
||||
assert_eq!(cookie.name, "data");
|
||||
assert_eq!(cookie.value, "a=b=c");
|
||||
assert!(cookie.http_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_sensitive_cookie_known_names() {
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("session_id"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("PHPSESSID"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("JSESSIONID"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("connect.sid"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("asp.net_sessionid"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("auth_token"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("jwt_access"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("csrf_token"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("my_sess_cookie"));
|
||||
assert!(CookieAnalyzerTool::is_sensitive_cookie("SID"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_sensitive_cookie_non_sensitive() {
|
||||
assert!(!CookieAnalyzerTool::is_sensitive_cookie("theme"));
|
||||
assert!(!CookieAnalyzerTool::is_sensitive_cookie("language"));
|
||||
assert!(!CookieAnalyzerTool::is_sensitive_cookie("_ga"));
|
||||
assert!(!CookieAnalyzerTool::is_sensitive_cookie("tracking"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_cookie_header() {
|
||||
let cookie = CookieAnalyzerTool::parse_set_cookie("");
|
||||
assert_eq!(cookie.name, "");
|
||||
assert_eq!(cookie.value, "");
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CookieAnalyzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"cookie_analyzer"
|
||||
@@ -123,96 +199,96 @@ impl PentestTool for CookieAnalyzerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let login_url = input.get("login_url").and_then(|v| v.as_str());
|
||||
let login_url = input.get("login_url").and_then(|v| v.as_str());
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut cookie_data = Vec::new();
|
||||
let mut findings = Vec::new();
|
||||
let mut cookie_data = Vec::new();
|
||||
|
||||
// Collect Set-Cookie headers from the main URL and optional login URL
|
||||
let urls_to_check: Vec<&str> = std::iter::once(url)
|
||||
.chain(login_url.into_iter())
|
||||
.collect();
|
||||
// Collect Set-Cookie headers from the main URL and optional login URL
|
||||
let urls_to_check: Vec<&str> = std::iter::once(url).chain(login_url).collect();
|
||||
|
||||
for check_url in &urls_to_check {
|
||||
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
|
||||
let no_redirect_client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.timeout(std::time::Duration::from_secs(15))
|
||||
.build()
|
||||
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
|
||||
for check_url in &urls_to_check {
|
||||
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
|
||||
let no_redirect_client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.redirect(reqwest::redirect::Policy::none())
|
||||
.timeout(std::time::Duration::from_secs(15))
|
||||
.build()
|
||||
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
|
||||
|
||||
let response = match no_redirect_client.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
// Try with the main client that follows redirects
|
||||
match self.http.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
let response = match no_redirect_client.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(_e) => {
|
||||
// Try with the main client that follows redirects
|
||||
match self.http.get(*check_url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let set_cookie_headers: Vec<String> = response
|
||||
.headers()
|
||||
.get_all("set-cookie")
|
||||
.iter()
|
||||
.filter_map(|v| v.to_str().ok().map(String::from))
|
||||
.collect();
|
||||
let status = response.status().as_u16();
|
||||
let set_cookie_headers: Vec<String> = response
|
||||
.headers()
|
||||
.get_all("set-cookie")
|
||||
.iter()
|
||||
.filter_map(|v| v.to_str().ok().map(String::from))
|
||||
.collect();
|
||||
|
||||
for raw_cookie in &set_cookie_headers {
|
||||
let cookie = Self::parse_set_cookie(raw_cookie);
|
||||
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
|
||||
let is_https = check_url.starts_with("https://");
|
||||
for raw_cookie in &set_cookie_headers {
|
||||
let cookie = Self::parse_set_cookie(raw_cookie);
|
||||
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
|
||||
let is_https = check_url.starts_with("https://");
|
||||
|
||||
let cookie_info = json!({
|
||||
"name": cookie.name,
|
||||
"secure": cookie.secure,
|
||||
"http_only": cookie.http_only,
|
||||
"same_site": cookie.same_site,
|
||||
"domain": cookie.domain,
|
||||
"path": cookie.path,
|
||||
"is_sensitive": is_sensitive,
|
||||
"url": check_url,
|
||||
});
|
||||
cookie_data.push(cookie_info);
|
||||
let cookie_info = json!({
|
||||
"name": cookie.name,
|
||||
"secure": cookie.secure,
|
||||
"http_only": cookie.http_only,
|
||||
"same_site": cookie.same_site,
|
||||
"domain": cookie.domain,
|
||||
"path": cookie.path,
|
||||
"is_sensitive": is_sensitive,
|
||||
"url": check_url,
|
||||
});
|
||||
cookie_data.push(cookie_info);
|
||||
|
||||
// Check: missing Secure flag
|
||||
if !cookie.secure && (is_https || is_sensitive) {
|
||||
let severity = if is_sensitive {
|
||||
Severity::High
|
||||
} else {
|
||||
Severity::Medium
|
||||
};
|
||||
// Check: missing Secure flag
|
||||
if !cookie.secure && (is_https || is_sensitive) {
|
||||
let severity = if is_sensitive {
|
||||
Severity::High
|
||||
} else {
|
||||
Severity::Medium
|
||||
};
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
@@ -226,32 +302,32 @@ impl PentestTool for CookieAnalyzerTool {
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-614".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
|
||||
finding.cwe = Some("CWE-614".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
|
||||
cookie is only sent over HTTPS connections."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check: missing HttpOnly flag on sensitive cookies
|
||||
if !cookie.http_only && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
// Check: missing HttpOnly flag on sensitive cookies
|
||||
if !cookie.http_only && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
@@ -265,137 +341,137 @@ impl PentestTool for CookieAnalyzerTool {
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
|
||||
JavaScript access to the cookie."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
|
||||
// Check: missing or weak SameSite
|
||||
if is_sensitive {
|
||||
let weak_same_site = match &cookie.same_site {
|
||||
None => true,
|
||||
Some(ss) => ss == "none",
|
||||
};
|
||||
|
||||
if weak_same_site {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
// Check: missing or weak SameSite
|
||||
if is_sensitive {
|
||||
let weak_same_site = match &cookie.same_site {
|
||||
None => true,
|
||||
Some(ss) => ss == "none",
|
||||
};
|
||||
|
||||
let desc = if cookie.same_site.is_none() {
|
||||
format!(
|
||||
if weak_same_site {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let desc = if cookie.same_site.is_none() {
|
||||
format!(
|
||||
"The session/auth cookie '{}' does not have a SameSite attribute. \
|
||||
This may allow cross-site request forgery (CSRF) attacks.",
|
||||
cookie.name
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
} else {
|
||||
format!(
|
||||
"The session/auth cookie '{}' has SameSite=None, which allows it \
|
||||
to be sent in cross-site requests, enabling CSRF attacks.",
|
||||
cookie.name
|
||||
)
|
||||
};
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' missing or weak SameSite", cookie.name),
|
||||
desc,
|
||||
Severity::Medium,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1275".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' missing or weak SameSite", cookie.name),
|
||||
desc,
|
||||
Severity::Medium,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1275".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
|
||||
to prevent cross-site request inclusion."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check: overly broad domain
|
||||
if let Some(ref domain) = cookie.domain {
|
||||
// A domain starting with a dot applies to all subdomains
|
||||
let dot_domain = domain.starts_with('.');
|
||||
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
|
||||
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
|
||||
if dot_domain && parts.len() <= 2 && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
// Check: overly broad domain
|
||||
if let Some(ref domain) = cookie.domain {
|
||||
// A domain starting with a dot applies to all subdomains
|
||||
let dot_domain = domain.starts_with('.');
|
||||
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
|
||||
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
|
||||
if dot_domain && parts.len() <= 2 && is_sensitive {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: check_url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(cookie.raw.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' has overly broad domain", cookie.name),
|
||||
format!(
|
||||
"The cookie '{}' is scoped to domain '{}' which includes all \
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CookieSecurity,
|
||||
format!("Cookie '{}' has overly broad domain", cookie.name),
|
||||
format!(
|
||||
"The cookie '{}' is scoped to domain '{}' which includes all \
|
||||
subdomains. If any subdomain is compromised, the attacker can \
|
||||
access this cookie.",
|
||||
cookie.name, domain
|
||||
),
|
||||
Severity::Low,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
cookie.name, domain
|
||||
),
|
||||
Severity::Low,
|
||||
check_url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-1004".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Restrict the cookie domain to the specific subdomain that needs it \
|
||||
rather than the entire parent domain."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Cookie analysis complete");
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Cookie analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} cookie security issues.")
|
||||
} else if cookie_data.is_empty() {
|
||||
"No cookies were set by the target.".to_string()
|
||||
} else {
|
||||
"All cookies have proper security attributes.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"cookies": cookie_data,
|
||||
"total_cookies": cookie_data.len(),
|
||||
}),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} cookie security issues.")
|
||||
} else if cookie_data.is_empty() {
|
||||
"No cookies were set by the target.".to_string()
|
||||
} else {
|
||||
"All cookies have proper security attributes.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"cookies": cookie_data,
|
||||
"total_cookies": cookie_data.len(),
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,19 +22,60 @@ impl CorsCheckerTool {
|
||||
vec![
|
||||
("null_origin", "null".to_string()),
|
||||
("evil_domain", "https://evil.com".to_string()),
|
||||
(
|
||||
"subdomain_spoof",
|
||||
format!("https://{target_host}.evil.com"),
|
||||
),
|
||||
(
|
||||
"prefix_spoof",
|
||||
format!("https://evil-{target_host}"),
|
||||
),
|
||||
("subdomain_spoof", format!("https://{target_host}.evil.com")),
|
||||
("prefix_spoof", format!("https://evil-{target_host}")),
|
||||
("http_downgrade", format!("http://{target_host}")),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_origins_contains_expected_variants() {
|
||||
let origins = CorsCheckerTool::test_origins("example.com");
|
||||
assert_eq!(origins.len(), 5);
|
||||
|
||||
let names: Vec<&str> = origins.iter().map(|(name, _)| *name).collect();
|
||||
assert!(names.contains(&"null_origin"));
|
||||
assert!(names.contains(&"evil_domain"));
|
||||
assert!(names.contains(&"subdomain_spoof"));
|
||||
assert!(names.contains(&"prefix_spoof"));
|
||||
assert!(names.contains(&"http_downgrade"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origins_uses_target_host() {
|
||||
let origins = CorsCheckerTool::test_origins("myapp.io");
|
||||
|
||||
let subdomain = origins
|
||||
.iter()
|
||||
.find(|(n, _)| *n == "subdomain_spoof")
|
||||
.unwrap();
|
||||
assert_eq!(subdomain.1, "https://myapp.io.evil.com");
|
||||
|
||||
let prefix = origins.iter().find(|(n, _)| *n == "prefix_spoof").unwrap();
|
||||
assert_eq!(prefix.1, "https://evil-myapp.io");
|
||||
|
||||
let http_downgrade = origins
|
||||
.iter()
|
||||
.find(|(n, _)| *n == "http_downgrade")
|
||||
.unwrap();
|
||||
assert_eq!(http_downgrade.1, "http://myapp.io");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origins_null_and_evil_are_static() {
|
||||
let origins = CorsCheckerTool::test_origins("anything.com");
|
||||
let null_origin = origins.iter().find(|(n, _)| *n == "null_origin").unwrap();
|
||||
assert_eq!(null_origin.1, "null");
|
||||
let evil = origins.iter().find(|(n, _)| *n == "evil_domain").unwrap();
|
||||
assert_eq!(evil.1, "https://evil.com");
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CorsCheckerTool {
|
||||
fn name(&self) -> &str {
|
||||
"cors_checker"
|
||||
@@ -68,82 +109,82 @@ impl PentestTool for CorsCheckerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_origins: Vec<String> = input
|
||||
.get("additional_origins")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let additional_origins: Vec<String> = input
|
||||
.get("additional_origins")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let target_host = url::Url::parse(url)
|
||||
.ok()
|
||||
.and_then(|u| u.host_str().map(String::from))
|
||||
.unwrap_or_else(|| url.to_string());
|
||||
let target_host = url::Url::parse(url)
|
||||
.ok()
|
||||
.and_then(|u| u.host_str().map(String::from))
|
||||
.unwrap_or_else(|| url.to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut cors_data: Vec<serde_json::Value> = Vec::new();
|
||||
let mut findings = Vec::new();
|
||||
let mut cors_data: Vec<serde_json::Value> = Vec::new();
|
||||
|
||||
// First, send a request without Origin to get baseline
|
||||
let baseline = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
// First, send a request without Origin to get baseline
|
||||
let baseline = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let baseline_acao = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
let baseline_acao = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"origin": null,
|
||||
"acao": baseline_acao,
|
||||
}));
|
||||
cors_data.push(json!({
|
||||
"origin": null,
|
||||
"acao": baseline_acao,
|
||||
}));
|
||||
|
||||
// Check for wildcard + credentials (dangerous combo)
|
||||
if let Some(ref acao) = baseline_acao {
|
||||
if acao == "*" {
|
||||
let acac = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
// Check for wildcard + credentials (dangerous combo)
|
||||
if let Some(ref acao) = baseline_acao {
|
||||
if acao == "*" {
|
||||
let acac = baseline
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if acac.to_lowercase() == "true" {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: baseline.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
if acac.to_lowercase() == "true" {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: baseline.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some("Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
@@ -157,254 +198,251 @@ impl PentestTool for CorsCheckerTool {
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Never combine Access-Control-Allow-Origin: * with \
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Never combine Access-Control-Allow-Origin: * with \
|
||||
Access-Control-Allow-Credentials: true. Specify explicit allowed origins."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test with various Origin headers
|
||||
let mut test_origins = Self::test_origins(&target_host);
|
||||
for origin in &additional_origins {
|
||||
test_origins.push(("custom", origin.clone()));
|
||||
}
|
||||
// Test with various Origin headers
|
||||
let mut test_origins = Self::test_origins(&target_host);
|
||||
for origin in &additional_origins {
|
||||
test_origins.push(("custom", origin.clone()));
|
||||
}
|
||||
|
||||
for (test_name, origin) in &test_origins {
|
||||
let resp = match self
|
||||
for (test_name, origin) in &test_origins {
|
||||
let resp = match self
|
||||
.http
|
||||
.get(url)
|
||||
.header("Origin", origin.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
let acao = resp
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acac = resp
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"test": test_name,
|
||||
"origin": origin,
|
||||
"acao": acao,
|
||||
"acac": acac,
|
||||
"acam": acam,
|
||||
"status": status,
|
||||
}));
|
||||
|
||||
// Check if the origin was reflected back
|
||||
if let Some(ref acao_val) = acao {
|
||||
let origin_reflected = acao_val == origin;
|
||||
let credentials_allowed = acac
|
||||
.as_ref()
|
||||
.map(|v| v.to_lowercase() == "true")
|
||||
.unwrap_or(false);
|
||||
|
||||
if origin_reflected && *test_name != "http_downgrade" {
|
||||
let severity = if credentials_allowed {
|
||||
Severity::Critical
|
||||
} else {
|
||||
Severity::High
|
||||
};
|
||||
|
||||
let resp_headers: HashMap<String, String> = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(resp_headers),
|
||||
response_snippet: Some(format!(
|
||||
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
|
||||
Access-Control-Allow-Credentials: {}",
|
||||
acac.as_deref().unwrap_or("not set")
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let title = match *test_name {
|
||||
"null_origin" => "CORS accepts null origin".to_string(),
|
||||
"evil_domain" => "CORS reflects arbitrary origin".to_string(),
|
||||
"subdomain_spoof" => {
|
||||
"CORS vulnerable to subdomain spoofing".to_string()
|
||||
}
|
||||
"prefix_spoof" => "CORS vulnerable to prefix spoofing".to_string(),
|
||||
_ => format!("CORS reflects untrusted origin ({test_name})"),
|
||||
};
|
||||
|
||||
let cred_note = if credentials_allowed {
|
||||
" Combined with Access-Control-Allow-Credentials: true, this allows \
|
||||
the attacker to steal authenticated data."
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
title,
|
||||
format!(
|
||||
"The endpoint {url} reflects the Origin header '{origin}' back in \
|
||||
Access-Control-Allow-Origin, allowing cross-origin requests from \
|
||||
untrusted domains.{cred_note}"
|
||||
),
|
||||
severity,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.exploitable = credentials_allowed;
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Validate the Origin header against a whitelist of trusted origins. \
|
||||
Do not reflect the Origin header value directly. Use specific allowed \
|
||||
origins instead of wildcards."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
|
||||
warn!(
|
||||
url,
|
||||
test_name,
|
||||
origin,
|
||||
credentials = credentials_allowed,
|
||||
"CORS misconfiguration detected"
|
||||
);
|
||||
}
|
||||
|
||||
// Special case: HTTP downgrade
|
||||
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
"CORS allows HTTP origin with credentials".to_string(),
|
||||
format!(
|
||||
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
|
||||
credentials. An attacker performing a man-in-the-middle attack on \
|
||||
the HTTP version could steal authenticated data."
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
|
||||
origin validation enforces the https:// scheme."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also send a preflight OPTIONS request
|
||||
if let Ok(resp) = self
|
||||
.http
|
||||
.get(url)
|
||||
.header("Origin", origin.as_str())
|
||||
.request(reqwest::Method::OPTIONS, url)
|
||||
.header("Origin", "https://evil.com")
|
||||
.header("Access-Control-Request-Method", "POST")
|
||||
.header(
|
||||
"Access-Control-Request-Headers",
|
||||
"Authorization, Content-Type",
|
||||
)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(_) => continue,
|
||||
};
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let status = resp.status().as_u16();
|
||||
let acao = resp
|
||||
.headers()
|
||||
.get("access-control-allow-origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
let acah = resp
|
||||
.headers()
|
||||
.get("access-control-allow-headers")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acac = resp
|
||||
.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"test": test_name,
|
||||
"origin": origin,
|
||||
"acao": acao,
|
||||
"acac": acac,
|
||||
"acam": acam,
|
||||
"status": status,
|
||||
}));
|
||||
|
||||
// Check if the origin was reflected back
|
||||
if let Some(ref acao_val) = acao {
|
||||
let origin_reflected = acao_val == origin;
|
||||
let credentials_allowed = acac
|
||||
.as_ref()
|
||||
.map(|v| v.to_lowercase() == "true")
|
||||
.unwrap_or(false);
|
||||
|
||||
if origin_reflected && *test_name != "http_downgrade" {
|
||||
let severity = if credentials_allowed {
|
||||
Severity::Critical
|
||||
} else {
|
||||
Severity::High
|
||||
};
|
||||
|
||||
let resp_headers: HashMap<String, String> = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(resp_headers),
|
||||
response_snippet: Some(format!(
|
||||
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
|
||||
Access-Control-Allow-Credentials: {}",
|
||||
acac.as_deref().unwrap_or("not set")
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let title = match *test_name {
|
||||
"null_origin" => {
|
||||
"CORS accepts null origin".to_string()
|
||||
}
|
||||
"evil_domain" => {
|
||||
"CORS reflects arbitrary origin".to_string()
|
||||
}
|
||||
"subdomain_spoof" => {
|
||||
"CORS vulnerable to subdomain spoofing".to_string()
|
||||
}
|
||||
"prefix_spoof" => {
|
||||
"CORS vulnerable to prefix spoofing".to_string()
|
||||
}
|
||||
_ => format!("CORS reflects untrusted origin ({test_name})"),
|
||||
};
|
||||
|
||||
let cred_note = if credentials_allowed {
|
||||
" Combined with Access-Control-Allow-Credentials: true, this allows \
|
||||
the attacker to steal authenticated data."
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
title,
|
||||
format!(
|
||||
"The endpoint {url} reflects the Origin header '{origin}' back in \
|
||||
Access-Control-Allow-Origin, allowing cross-origin requests from \
|
||||
untrusted domains.{cred_note}"
|
||||
),
|
||||
severity,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.exploitable = credentials_allowed;
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Validate the Origin header against a whitelist of trusted origins. \
|
||||
Do not reflect the Origin header value directly. Use specific allowed \
|
||||
origins instead of wildcards."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
|
||||
warn!(
|
||||
url,
|
||||
test_name,
|
||||
origin,
|
||||
credentials = credentials_allowed,
|
||||
"CORS misconfiguration detected"
|
||||
);
|
||||
}
|
||||
|
||||
// Special case: HTTP downgrade
|
||||
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: Some(
|
||||
[("Origin".to_string(), origin.clone())]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
),
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: Some(origin.clone()),
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CorsMisconfiguration,
|
||||
"CORS allows HTTP origin with credentials".to_string(),
|
||||
format!(
|
||||
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
|
||||
credentials. An attacker performing a man-in-the-middle attack on \
|
||||
the HTTP version could steal authenticated data."
|
||||
),
|
||||
Severity::High,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-942".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
|
||||
origin validation enforces the https:// scheme."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
cors_data.push(json!({
|
||||
"test": "preflight",
|
||||
"status": resp.status().as_u16(),
|
||||
"allow_methods": acam,
|
||||
"allow_headers": acah,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Also send a preflight OPTIONS request
|
||||
if let Ok(resp) = self
|
||||
.http
|
||||
.request(reqwest::Method::OPTIONS, url)
|
||||
.header("Origin", "https://evil.com")
|
||||
.header("Access-Control-Request-Method", "POST")
|
||||
.header("Access-Control-Request-Headers", "Authorization, Content-Type")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
let acam = resp
|
||||
.headers()
|
||||
.get("access-control-allow-methods")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CORS check complete");
|
||||
|
||||
let acah = resp
|
||||
.headers()
|
||||
.get("access-control-allow-headers")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
cors_data.push(json!({
|
||||
"test": "preflight",
|
||||
"status": resp.status().as_u16(),
|
||||
"allow_methods": acam,
|
||||
"allow_headers": acah,
|
||||
}));
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CORS check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CORS misconfiguration issues for {url}.")
|
||||
} else {
|
||||
format!("CORS configuration appears secure for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"tests": cors_data,
|
||||
}),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CORS misconfiguration issues for {url}.")
|
||||
} else {
|
||||
format!("CORS configuration appears secure for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!({
|
||||
"tests": cors_data,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ impl CspAnalyzerTool {
|
||||
url: &str,
|
||||
target_id: &str,
|
||||
status: u16,
|
||||
csp_raw: &str,
|
||||
_csp_raw: &str,
|
||||
) -> Vec<DastFinding> {
|
||||
let mut findings = Vec::new();
|
||||
|
||||
@@ -216,12 +216,18 @@ impl CspAnalyzerTool {
|
||||
("object-src", "Controls plugins like Flash"),
|
||||
("base-uri", "Controls the base URL for relative URLs"),
|
||||
("form-action", "Controls where forms can submit to"),
|
||||
("frame-ancestors", "Controls who can embed this page in iframes"),
|
||||
(
|
||||
"frame-ancestors",
|
||||
"Controls who can embed this page in iframes",
|
||||
),
|
||||
];
|
||||
|
||||
for (dir_name, desc) in &important_directives {
|
||||
if !directive_names.contains(dir_name)
|
||||
&& !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action")
|
||||
&& (!has_default_src
|
||||
|| *dir_name == "frame-ancestors"
|
||||
|| *dir_name == "base-uri"
|
||||
|| *dir_name == "form-action")
|
||||
{
|
||||
let evidence = make_evidence(format!("CSP missing directive: {dir_name}"));
|
||||
let mut finding = DastFinding::new(
|
||||
@@ -258,6 +264,125 @@ impl CspAnalyzerTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_csp_basic() {
|
||||
let directives = CspAnalyzerTool::parse_csp(
|
||||
"default-src 'self'; script-src 'self' https://cdn.example.com",
|
||||
);
|
||||
assert_eq!(directives.len(), 2);
|
||||
assert_eq!(directives[0].name, "default-src");
|
||||
assert_eq!(directives[0].values, vec!["'self'"]);
|
||||
assert_eq!(directives[1].name, "script-src");
|
||||
assert_eq!(
|
||||
directives[1].values,
|
||||
vec!["'self'", "https://cdn.example.com"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_csp_empty() {
|
||||
let directives = CspAnalyzerTool::parse_csp("");
|
||||
assert!(directives.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_csp_trailing_semicolons() {
|
||||
let directives = CspAnalyzerTool::parse_csp("default-src 'none';;;");
|
||||
assert_eq!(directives.len(), 1);
|
||||
assert_eq!(directives[0].name, "default-src");
|
||||
assert_eq!(directives[0].values, vec!["'none'"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_csp_directive_without_value() {
|
||||
let directives = CspAnalyzerTool::parse_csp("upgrade-insecure-requests");
|
||||
assert_eq!(directives.len(), 1);
|
||||
assert_eq!(directives[0].name, "upgrade-insecure-requests");
|
||||
assert!(directives[0].values.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_csp_names_are_lowercased() {
|
||||
let directives = CspAnalyzerTool::parse_csp("Script-Src 'self'");
|
||||
assert_eq!(directives[0].name, "script-src");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_unsafe_inline() {
|
||||
let directives = CspAnalyzerTool::parse_csp("script-src 'self' 'unsafe-inline'");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
assert!(findings.iter().any(|f| f.title.contains("unsafe-inline")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_unsafe_eval() {
|
||||
let directives = CspAnalyzerTool::parse_csp("script-src 'self' 'unsafe-eval'");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
assert!(findings.iter().any(|f| f.title.contains("unsafe-eval")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_wildcard() {
|
||||
let directives = CspAnalyzerTool::parse_csp("img-src *");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
assert!(findings.iter().any(|f| f.title.contains("wildcard")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_data_uri_in_script_src() {
|
||||
let directives = CspAnalyzerTool::parse_csp("script-src 'self' data:");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
assert!(findings.iter().any(|f| f.title.contains("data:")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_http_sources() {
|
||||
let directives = CspAnalyzerTool::parse_csp("script-src http:");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
assert!(findings.iter().any(|f| f.title.contains("HTTP sources")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_detects_missing_directives_without_default_src() {
|
||||
let directives = CspAnalyzerTool::parse_csp("img-src 'self'");
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
let missing_names: Vec<&str> = findings
|
||||
.iter()
|
||||
.filter(|f| f.title.contains("missing"))
|
||||
.map(|f| f.title.as_str())
|
||||
.collect();
|
||||
// Should flag script-src, object-src, base-uri, form-action, frame-ancestors
|
||||
assert!(missing_names.len() >= 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn analyze_good_csp_no_unsafe_findings() {
|
||||
let directives = CspAnalyzerTool::parse_csp(
|
||||
"default-src 'none'; script-src 'self'; style-src 'self'; \
|
||||
img-src 'self'; object-src 'none'; base-uri 'self'; \
|
||||
form-action 'self'; frame-ancestors 'none'",
|
||||
);
|
||||
let findings =
|
||||
CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "");
|
||||
// A well-configured CSP should not produce unsafe-inline/eval/wildcard findings
|
||||
assert!(findings.iter().all(|f| {
|
||||
!f.title.contains("unsafe-inline")
|
||||
&& !f.title.contains("unsafe-eval")
|
||||
&& !f.title.contains("wildcard")
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for CspAnalyzerTool {
|
||||
fn name(&self) -> &str {
|
||||
"csp_analyzer"
|
||||
@@ -285,163 +410,167 @@ impl PentestTool for CspAnalyzerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let status = response.status().as_u16();
|
||||
|
||||
// Check for CSP header
|
||||
let csp_header = response
|
||||
.headers()
|
||||
.get("content-security-policy")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
// Check for CSP header
|
||||
let csp_header = response
|
||||
.headers()
|
||||
.get("content-security-policy")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
// Also check for report-only variant
|
||||
let csp_report_only = response
|
||||
.headers()
|
||||
.get("content-security-policy-report-only")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
// Also check for report-only variant
|
||||
let csp_report_only = response
|
||||
.headers()
|
||||
.get("content-security-policy-report-only")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(String::from);
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut csp_data = json!({});
|
||||
let mut findings = Vec::new();
|
||||
let mut csp_data = json!({});
|
||||
|
||||
match &csp_header {
|
||||
Some(csp) => {
|
||||
let directives = Self::parse_csp(csp);
|
||||
let directive_map: serde_json::Value = directives
|
||||
.iter()
|
||||
.map(|d| (d.name.clone(), json!(d.values)))
|
||||
.collect::<serde_json::Map<String, serde_json::Value>>()
|
||||
.into();
|
||||
match &csp_header {
|
||||
Some(csp) => {
|
||||
let directives = Self::parse_csp(csp);
|
||||
let directive_map: serde_json::Value = directives
|
||||
.iter()
|
||||
.map(|d| (d.name.clone(), json!(d.values)))
|
||||
.collect::<serde_json::Map<String, serde_json::Value>>()
|
||||
.into();
|
||||
|
||||
csp_data["csp_header"] = json!(csp);
|
||||
csp_data["directives"] = directive_map;
|
||||
csp_data["csp_header"] = json!(csp);
|
||||
csp_data["directives"] = directive_map;
|
||||
|
||||
findings.extend(Self::analyze_directives(
|
||||
&directives,
|
||||
url,
|
||||
&target_id,
|
||||
status,
|
||||
csp,
|
||||
));
|
||||
}
|
||||
None => {
|
||||
csp_data["csp_header"] = json!(null);
|
||||
findings.extend(Self::analyze_directives(
|
||||
&directives,
|
||||
url,
|
||||
&target_id,
|
||||
status,
|
||||
csp,
|
||||
));
|
||||
}
|
||||
None => {
|
||||
csp_data["csp_header"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some("Content-Security-Policy header is missing".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(
|
||||
"Content-Security-Policy header is missing".to_string(),
|
||||
),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"Missing Content-Security-Policy header".to_string(),
|
||||
format!(
|
||||
"No Content-Security-Policy header is present on {url}. \
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"Missing Content-Security-Policy header".to_string(),
|
||||
format!(
|
||||
"No Content-Security-Policy header is present on {url}. \
|
||||
Without CSP, the browser has no instructions on which sources are \
|
||||
trusted, making XSS exploitation much easier."
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add a Content-Security-Policy header. Start with a restrictive policy like \
|
||||
\"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \
|
||||
object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref report_only) = csp_report_only {
|
||||
csp_data["csp_report_only"] = json!(report_only);
|
||||
// If ONLY report-only exists (no enforcing CSP), warn
|
||||
if csp_header.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Content-Security-Policy-Report-Only: {}",
|
||||
report_only
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
if let Some(ref report_only) = csp_report_only {
|
||||
csp_data["csp_report_only"] = json!(report_only);
|
||||
// If ONLY report-only exists (no enforcing CSP), warn
|
||||
if csp_header.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Content-Security-Policy-Report-Only: {}",
|
||||
report_only
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"CSP is report-only, not enforcing".to_string(),
|
||||
"A Content-Security-Policy-Report-Only header is present but no enforcing \
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::CspIssue,
|
||||
"CSP is report-only, not enforcing".to_string(),
|
||||
"A Content-Security-Policy-Report-Only header is present but no enforcing \
|
||||
Content-Security-Policy header exists. Report-only mode only logs violations \
|
||||
but does not block them."
|
||||
.to_string(),
|
||||
Severity::Low,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
.to_string(),
|
||||
Severity::Low,
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-16".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Once you have verified the CSP policy works correctly in report-only mode, \
|
||||
deploy it as an enforcing Content-Security-Policy header."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CSP analysis complete");
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "CSP analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CSP issues for {url}.")
|
||||
} else {
|
||||
format!("Content-Security-Policy looks good for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: csp_data,
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} CSP issues for {url}.")
|
||||
} else {
|
||||
format!("Content-Security-Policy looks good for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: csp_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,12 @@ use tracing::{info, warn};
|
||||
/// Tool that checks email security configuration (DMARC and SPF records).
|
||||
pub struct DmarcCheckerTool;
|
||||
|
||||
impl Default for DmarcCheckerTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl DmarcCheckerTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
@@ -78,6 +84,105 @@ impl DmarcCheckerTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_policy_reject() {
|
||||
let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com";
|
||||
assert_eq!(
|
||||
DmarcCheckerTool::parse_dmarc_policy(record),
|
||||
Some("reject".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_policy_none() {
|
||||
let record = "v=DMARC1; p=none";
|
||||
assert_eq!(
|
||||
DmarcCheckerTool::parse_dmarc_policy(record),
|
||||
Some("none".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_policy_quarantine() {
|
||||
let record = "v=DMARC1; p=quarantine; sp=none";
|
||||
assert_eq!(
|
||||
DmarcCheckerTool::parse_dmarc_policy(record),
|
||||
Some("quarantine".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_policy_missing() {
|
||||
let record = "v=DMARC1; rua=mailto:test@example.com";
|
||||
assert_eq!(DmarcCheckerTool::parse_dmarc_policy(record), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_subdomain_policy() {
|
||||
let record = "v=DMARC1; p=reject; sp=quarantine";
|
||||
assert_eq!(
|
||||
DmarcCheckerTool::parse_dmarc_subdomain_policy(record),
|
||||
Some("quarantine".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_subdomain_policy_missing() {
|
||||
let record = "v=DMARC1; p=reject";
|
||||
assert_eq!(DmarcCheckerTool::parse_dmarc_subdomain_policy(record), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_rua_present() {
|
||||
let record = "v=DMARC1; p=reject; rua=mailto:dmarc@example.com";
|
||||
assert_eq!(
|
||||
DmarcCheckerTool::parse_dmarc_rua(record),
|
||||
Some("mailto:dmarc@example.com".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_dmarc_rua_missing() {
|
||||
let record = "v=DMARC1; p=none";
|
||||
assert_eq!(DmarcCheckerTool::parse_dmarc_rua(record), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_spf_record_valid() {
|
||||
assert!(DmarcCheckerTool::is_spf_record(
|
||||
"v=spf1 include:_spf.google.com -all"
|
||||
));
|
||||
assert!(DmarcCheckerTool::is_spf_record("v=spf1 -all"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_spf_record_invalid() {
|
||||
assert!(!DmarcCheckerTool::is_spf_record("v=DMARC1; p=reject"));
|
||||
assert!(!DmarcCheckerTool::is_spf_record("some random txt record"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spf_soft_fail_detection() {
|
||||
assert!(DmarcCheckerTool::spf_uses_soft_fail(
|
||||
"v=spf1 include:_spf.google.com ~all"
|
||||
));
|
||||
assert!(!DmarcCheckerTool::spf_uses_soft_fail(
|
||||
"v=spf1 include:_spf.google.com -all"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spf_allows_all_detection() {
|
||||
assert!(DmarcCheckerTool::spf_allows_all("v=spf1 +all"));
|
||||
assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 -all"));
|
||||
assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 ~all"));
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for DmarcCheckerTool {
|
||||
fn name(&self) -> &str {
|
||||
"dmarc_checker"
|
||||
@@ -105,43 +210,89 @@ impl PentestTool for DmarcCheckerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
CoreError::Dast("Missing required 'domain' parameter".to_string())
|
||||
})?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut email_data = json!({});
|
||||
let mut findings = Vec::new();
|
||||
let mut email_data = json!({});
|
||||
|
||||
// ---- DMARC check ----
|
||||
let dmarc_domain = format!("_dmarc.{domain}");
|
||||
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
|
||||
// ---- DMARC check ----
|
||||
let dmarc_domain = format!("_dmarc.{domain}");
|
||||
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
|
||||
|
||||
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
|
||||
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
|
||||
|
||||
match dmarc_record {
|
||||
Some(record) => {
|
||||
email_data["dmarc_record"] = json!(record);
|
||||
match dmarc_record {
|
||||
Some(record) => {
|
||||
email_data["dmarc_record"] = json!(record);
|
||||
|
||||
let policy = Self::parse_dmarc_policy(record);
|
||||
let sp = Self::parse_dmarc_subdomain_policy(record);
|
||||
let rua = Self::parse_dmarc_rua(record);
|
||||
let policy = Self::parse_dmarc_policy(record);
|
||||
let sp = Self::parse_dmarc_subdomain_policy(record);
|
||||
let rua = Self::parse_dmarc_rua(record);
|
||||
|
||||
email_data["dmarc_policy"] = json!(policy);
|
||||
email_data["dmarc_subdomain_policy"] = json!(sp);
|
||||
email_data["dmarc_rua"] = json!(rua);
|
||||
email_data["dmarc_policy"] = json!(policy);
|
||||
email_data["dmarc_subdomain_policy"] = json!(sp);
|
||||
email_data["dmarc_rua"] = json!(rua);
|
||||
|
||||
// Warn on weak policy
|
||||
if let Some(ref p) = policy {
|
||||
if p == "none" {
|
||||
// Warn on weak policy
|
||||
if let Some(ref p) = policy {
|
||||
if p == "none" {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Weak DMARC policy for {domain}"),
|
||||
format!(
|
||||
"The DMARC policy for {domain} is set to 'none', which only monitors \
|
||||
but does not enforce email authentication. Attackers can spoof emails \
|
||||
from this domain."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
|
||||
'p=reject' after verifying legitimate email flows are properly \
|
||||
authenticated."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "DMARC policy is 'none'");
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if no reporting URI
|
||||
if rua.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
@@ -156,48 +307,6 @@ impl PentestTool for DmarcCheckerTool {
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Weak DMARC policy for {domain}"),
|
||||
format!(
|
||||
"The DMARC policy for {domain} is set to 'none', which only monitors \
|
||||
but does not enforce email authentication. Attackers can spoof emails \
|
||||
from this domain."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
|
||||
'p=reject' after verifying legitimate email flows are properly \
|
||||
authenticated."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "DMARC policy is 'none'");
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if no reporting URI
|
||||
if rua.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
@@ -211,80 +320,80 @@ impl PentestTool for DmarcCheckerTool {
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-778".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
|
||||
finding.cwe = Some("CWE-778".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
|
||||
Example: 'rua=mailto:dmarc-reports@example.com'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
email_data["dmarc_record"] = json!(null);
|
||||
None => {
|
||||
email_data["dmarc_record"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No DMARC record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Missing DMARC record for {domain}"),
|
||||
format!(
|
||||
"No DMARC record was found for {domain}. Without DMARC, there is no \
|
||||
policy to prevent email spoofing and phishing attacks using this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
|
||||
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No DMARC record found");
|
||||
}
|
||||
}
|
||||
|
||||
// ---- SPF check ----
|
||||
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
|
||||
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
|
||||
|
||||
match spf_record {
|
||||
Some(record) => {
|
||||
email_data["spf_record"] = json!(record);
|
||||
|
||||
if Self::spf_allows_all(record) {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_url: dmarc_domain.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
response_snippet: Some("No DMARC record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Missing DMARC record for {domain}"),
|
||||
format!(
|
||||
"No DMARC record was found for {domain}. Without DMARC, there is no \
|
||||
policy to prevent email spoofing and phishing attacks using this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
|
||||
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No DMARC record found");
|
||||
}
|
||||
}
|
||||
|
||||
// ---- SPF check ----
|
||||
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
|
||||
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
|
||||
|
||||
match spf_record {
|
||||
Some(record) => {
|
||||
email_data["spf_record"] = json!(record);
|
||||
|
||||
if Self::spf_allows_all(record) {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
@@ -297,15 +406,55 @@ impl PentestTool for DmarcCheckerTool {
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \
|
||||
Only list authorized mail servers."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
} else if Self::spf_uses_soft_fail(record) {
|
||||
findings.push(finding);
|
||||
} else if Self::spf_uses_soft_fail(record) {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("SPF soft fail for {domain}"),
|
||||
format!(
|
||||
"The SPF record for {domain} uses '~all' (soft fail) instead of \
|
||||
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
|
||||
but does not reject them."
|
||||
),
|
||||
Severity::Low,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Consider changing '~all' to '-all' in your SPF record once you have \
|
||||
confirmed all legitimate mail sources are listed."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
email_data["spf_record"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
@@ -313,7 +462,7 @@ impl PentestTool for DmarcCheckerTool {
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(record.clone()),
|
||||
response_snippet: Some("No SPF record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
@@ -323,79 +472,39 @@ impl PentestTool for DmarcCheckerTool {
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("SPF soft fail for {domain}"),
|
||||
format!("Missing SPF record for {domain}"),
|
||||
format!(
|
||||
"The SPF record for {domain} uses '~all' (soft fail) instead of \
|
||||
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
|
||||
but does not reject them."
|
||||
),
|
||||
Severity::Low,
|
||||
"No SPF record was found for {domain}. Without SPF, any server can claim \
|
||||
to send email on behalf of this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Consider changing '~all' to '-all' in your SPF record once you have \
|
||||
confirmed all legitimate mail sources are listed."
|
||||
"Create an SPF TXT record for your domain. Example: \
|
||||
'v=spf1 include:_spf.google.com -all'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No SPF record found");
|
||||
}
|
||||
}
|
||||
None => {
|
||||
email_data["spf_record"] = json!(null);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No SPF record found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DMARC/SPF check complete");
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::EmailSecurity,
|
||||
format!("Missing SPF record for {domain}"),
|
||||
format!(
|
||||
"No SPF record was found for {domain}. Without SPF, any server can claim \
|
||||
to send email on behalf of this domain."
|
||||
),
|
||||
Severity::High,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-290".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Create an SPF TXT record for your domain. Example: \
|
||||
'v=spf1 include:_spf.google.com -all'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, "No SPF record found");
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DMARC/SPF check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} email security issues for {domain}.")
|
||||
} else {
|
||||
format!("Email security configuration looks good for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: email_data,
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} email security issues for {domain}.")
|
||||
} else {
|
||||
format!("Email security configuration looks good for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: email_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,12 @@ use tracing::{info, warn};
|
||||
/// `tokio::process::Command` wrapper around `dig` where available.
|
||||
pub struct DnsCheckerTool;
|
||||
|
||||
impl Default for DnsCheckerTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl DnsCheckerTool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
@@ -54,7 +60,9 @@ impl DnsCheckerTool {
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}")));
|
||||
return Err(CoreError::Dast(format!(
|
||||
"DNS resolution failed for {domain}: {e}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,107 +102,111 @@ impl PentestTool for DnsCheckerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
|
||||
let domain = input
|
||||
.get("domain")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
CoreError::Dast("Missing required 'domain' parameter".to_string())
|
||||
})?;
|
||||
|
||||
let subdomains: Vec<String> = input
|
||||
.get("subdomains")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let subdomains: Vec<String> = input
|
||||
.get("subdomains")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
let mut findings = Vec::new();
|
||||
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// --- A / AAAA records ---
|
||||
match Self::resolve_addresses(domain).await {
|
||||
Ok((ipv4, ipv6)) => {
|
||||
dns_data.insert("a_records".to_string(), json!(ipv4));
|
||||
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
|
||||
// --- A / AAAA records ---
|
||||
match Self::resolve_addresses(domain).await {
|
||||
Ok((ipv4, ipv6)) => {
|
||||
dns_data.insert("a_records".to_string(), json!(ipv4));
|
||||
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- MX records ---
|
||||
match Self::dig_query(domain, "MX").await {
|
||||
Ok(mx) => {
|
||||
dns_data.insert("mx_records".to_string(), json!(mx));
|
||||
// --- MX records ---
|
||||
match Self::dig_query(domain, "MX").await {
|
||||
Ok(mx) => {
|
||||
dns_data.insert("mx_records".to_string(), json!(mx));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// --- NS records ---
|
||||
let ns_records = match Self::dig_query(domain, "NS").await {
|
||||
Ok(ns) => {
|
||||
dns_data.insert("ns_records".to_string(), json!(ns));
|
||||
ns
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
// --- NS records ---
|
||||
let ns_records = match Self::dig_query(domain, "NS").await {
|
||||
Ok(ns) => {
|
||||
dns_data.insert("ns_records".to_string(), json!(ns));
|
||||
ns
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// --- TXT records ---
|
||||
match Self::dig_query(domain, "TXT").await {
|
||||
Ok(txt) => {
|
||||
dns_data.insert("txt_records".to_string(), json!(txt));
|
||||
// --- TXT records ---
|
||||
match Self::dig_query(domain, "TXT").await {
|
||||
Ok(txt) => {
|
||||
dns_data.insert("txt_records".to_string(), json!(txt));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
|
||||
|
||||
// --- CNAME records (for subdomains) ---
|
||||
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
|
||||
let mut domains_to_check = vec![domain.to_string()];
|
||||
for sub in &subdomains {
|
||||
domains_to_check.push(format!("{sub}.{domain}"));
|
||||
}
|
||||
}
|
||||
|
||||
// --- CNAME records (for subdomains) ---
|
||||
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
|
||||
let mut domains_to_check = vec![domain.to_string()];
|
||||
for sub in &subdomains {
|
||||
domains_to_check.push(format!("{sub}.{domain}"));
|
||||
}
|
||||
for fqdn in &domains_to_check {
|
||||
match Self::dig_query(fqdn, "CNAME").await {
|
||||
Ok(cnames) if !cnames.is_empty() => {
|
||||
// Check for dangling CNAME
|
||||
for cname in &cnames {
|
||||
let cname_clean = cname.trim_end_matches('.');
|
||||
let check_addr = format!("{cname_clean}:443");
|
||||
let is_dangling = lookup_host(&check_addr).await.is_err();
|
||||
if is_dangling {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: fqdn.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"CNAME {fqdn} -> {cname} (target does not resolve)"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
for fqdn in &domains_to_check {
|
||||
match Self::dig_query(fqdn, "CNAME").await {
|
||||
Ok(cnames) if !cnames.is_empty() => {
|
||||
// Check for dangling CNAME
|
||||
for cname in &cnames {
|
||||
let cname_clean = cname.trim_end_matches('.');
|
||||
let check_addr = format!("{cname_clean}:443");
|
||||
let is_dangling = lookup_host(&check_addr).await.is_err();
|
||||
if is_dangling {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: fqdn.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"CNAME {fqdn} -> {cname} (target does not resolve)"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
@@ -207,44 +219,47 @@ impl PentestTool for DnsCheckerTool {
|
||||
fqdn.clone(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove dangling CNAME records or ensure the target hostname is \
|
||||
properly configured and resolvable."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover");
|
||||
findings.push(finding);
|
||||
warn!(
|
||||
fqdn,
|
||||
cname, "Dangling CNAME detected - potential subdomain takeover"
|
||||
);
|
||||
}
|
||||
}
|
||||
cname_data.insert(fqdn.clone(), cnames);
|
||||
}
|
||||
cname_data.insert(fqdn.clone(), cnames);
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if !cname_data.is_empty() {
|
||||
dns_data.insert("cname_records".to_string(), json!(cname_data));
|
||||
}
|
||||
if !cname_data.is_empty() {
|
||||
dns_data.insert("cname_records".to_string(), json!(cname_data));
|
||||
}
|
||||
|
||||
// --- CAA records ---
|
||||
match Self::dig_query(domain, "CAA").await {
|
||||
Ok(caa) => {
|
||||
if caa.is_empty() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No CAA records found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
// --- CAA records ---
|
||||
match Self::dig_query(domain, "CAA").await {
|
||||
Ok(caa) => {
|
||||
if caa.is_empty() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No CAA records found".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
@@ -257,35 +272,83 @@ impl PentestTool for DnsCheckerTool {
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add CAA DNS records to restrict which certificate authorities can issue \
|
||||
certificates for your domain. Example: '0 issue \"letsencrypt.org\"'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
findings.push(finding);
|
||||
}
|
||||
dns_data.insert("caa_records".to_string(), json!(caa));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
|
||||
}
|
||||
dns_data.insert("caa_records".to_string(), json!(caa));
|
||||
}
|
||||
Err(e) => {
|
||||
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
|
||||
|
||||
// --- DNSSEC check ---
|
||||
let dnssec_output = tokio::process::Command::new("dig")
|
||||
.args(["+dnssec", "+short", "DNSKEY", domain])
|
||||
.output()
|
||||
.await;
|
||||
|
||||
match dnssec_output {
|
||||
Ok(output) => {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let has_dnssec = !stdout.trim().is_empty();
|
||||
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
|
||||
|
||||
if !has_dnssec {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(
|
||||
"No DNSKEY records found - DNSSEC not enabled".to_string(),
|
||||
),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("DNSSEC not enabled for {domain}"),
|
||||
format!(
|
||||
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
|
||||
can be spoofed, allowing man-in-the-middle attacks."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-350".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
|
||||
with your DNS provider and domain registrar."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- DNSSEC check ---
|
||||
let dnssec_output = tokio::process::Command::new("dig")
|
||||
.args(["+dnssec", "+short", "DNSKEY", domain])
|
||||
.output()
|
||||
.await;
|
||||
|
||||
match dnssec_output {
|
||||
Ok(output) => {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let has_dnssec = !stdout.trim().is_empty();
|
||||
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
|
||||
|
||||
if !has_dnssec {
|
||||
// --- Check NS records for dangling ---
|
||||
for ns in &ns_records {
|
||||
let ns_clean = ns.trim_end_matches('.');
|
||||
let check_addr = format!("{ns_clean}:53");
|
||||
if lookup_host(&check_addr).await.is_err() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
@@ -293,61 +356,13 @@ impl PentestTool for DnsCheckerTool {
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()),
|
||||
response_snippet: Some(format!("NS record {ns} does not resolve")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
format!("DNSSEC not enabled for {domain}"),
|
||||
format!(
|
||||
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
|
||||
can be spoofed, allowing man-in-the-middle attacks."
|
||||
),
|
||||
Severity::Medium,
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-350".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
|
||||
with your DNS provider and domain registrar."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
|
||||
}
|
||||
}
|
||||
|
||||
// --- Check NS records for dangling ---
|
||||
for ns in &ns_records {
|
||||
let ns_clean = ns.trim_end_matches('.');
|
||||
let check_addr = format!("{ns_clean}:53");
|
||||
if lookup_host(&check_addr).await.is_err() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "DNS".to_string(),
|
||||
request_url: domain.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"NS record {ns} does not resolve"
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::DnsMisconfiguration,
|
||||
@@ -360,30 +375,33 @@ impl PentestTool for DnsCheckerTool {
|
||||
domain.to_string(),
|
||||
"DNS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove dangling NS records or ensure the nameserver hostname is properly \
|
||||
finding.cwe = Some("CWE-923".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Remove dangling NS records or ensure the nameserver hostname is properly \
|
||||
configured. Dangling NS records can lead to full domain takeover."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(domain, ns, "Dangling NS record detected - potential domain takeover");
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(
|
||||
domain,
|
||||
ns, "Dangling NS record detected - potential domain takeover"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DNS check complete");
|
||||
let count = findings.len();
|
||||
info!(domain, findings = count, "DNS check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} DNS configuration issues for {domain}.")
|
||||
} else {
|
||||
format!("No DNS configuration issues found for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(dns_data),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} DNS configuration issues for {domain}.")
|
||||
} else {
|
||||
format!("No DNS configuration issues found for {domain}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(dns_data),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,8 +33,15 @@ pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn PentestTool>>,
|
||||
}
|
||||
|
||||
impl Default for ToolRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Create a new registry with all built-in tools pre-registered.
|
||||
#[allow(clippy::expect_used)]
|
||||
pub fn new() -> Self {
|
||||
let http = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
@@ -67,13 +74,10 @@ impl ToolRegistry {
|
||||
);
|
||||
|
||||
// New infrastructure / analysis tools
|
||||
register(&mut tools, Box::<dns_checker::DnsCheckerTool>::default());
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(dns_checker::DnsCheckerTool::new()),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(dmarc_checker::DmarcCheckerTool::new()),
|
||||
Box::<dmarc_checker::DmarcCheckerTool>::default(),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
@@ -109,10 +113,7 @@ impl ToolRegistry {
|
||||
&mut tools,
|
||||
Box::new(openapi_parser::OpenApiParserTool::new(http.clone())),
|
||||
);
|
||||
register(
|
||||
&mut tools,
|
||||
Box::new(recon::ReconTool::new(http)),
|
||||
);
|
||||
register(&mut tools, Box::new(recon::ReconTool::new(http)));
|
||||
|
||||
Self { tools }
|
||||
}
|
||||
|
||||
@@ -92,7 +92,10 @@ impl OpenApiParserTool {
|
||||
|
||||
// If content type suggests YAML, we can't easily parse without a YAML dep,
|
||||
// so just report the URL as found
|
||||
if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") {
|
||||
if content_type.contains("yaml")
|
||||
|| body.starts_with("openapi:")
|
||||
|| body.starts_with("swagger:")
|
||||
{
|
||||
// Return a minimal JSON indicating YAML was found
|
||||
return Some((
|
||||
url.to_string(),
|
||||
@@ -107,7 +110,7 @@ impl OpenApiParserTool {
|
||||
}
|
||||
|
||||
/// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints.
|
||||
fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec<ParsedEndpoint> {
|
||||
fn parse_spec(spec: &serde_json::Value, _base_url: &str) -> Vec<ParsedEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
|
||||
// Determine base path
|
||||
@@ -258,6 +261,166 @@ impl OpenApiParserTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn common_spec_paths_not_empty() {
|
||||
let paths = OpenApiParserTool::common_spec_paths();
|
||||
assert!(paths.len() >= 5);
|
||||
assert!(paths.contains(&"/openapi.json"));
|
||||
assert!(paths.contains(&"/swagger.json"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_openapi3_basic() {
|
||||
let spec = json!({
|
||||
"openapi": "3.0.0",
|
||||
"info": { "title": "Test API", "version": "1.0" },
|
||||
"paths": {
|
||||
"/users": {
|
||||
"get": {
|
||||
"operationId": "listUsers",
|
||||
"summary": "List all users",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": { "type": "integer" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": { "description": "OK" },
|
||||
"401": { "description": "Unauthorized" }
|
||||
},
|
||||
"tags": ["users"]
|
||||
},
|
||||
"post": {
|
||||
"operationId": "createUser",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {}
|
||||
}
|
||||
},
|
||||
"responses": { "201": {} },
|
||||
"security": [{ "bearerAuth": [] }]
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
|
||||
assert_eq!(endpoints.len(), 2);
|
||||
|
||||
let get_ep = endpoints.iter().find(|e| e.method == "GET").unwrap();
|
||||
assert_eq!(get_ep.path, "/users");
|
||||
assert_eq!(get_ep.operation_id.as_deref(), Some("listUsers"));
|
||||
assert_eq!(get_ep.summary.as_deref(), Some("List all users"));
|
||||
assert_eq!(get_ep.parameters.len(), 1);
|
||||
assert_eq!(get_ep.parameters[0].name, "limit");
|
||||
assert_eq!(get_ep.parameters[0].location, "query");
|
||||
assert!(!get_ep.parameters[0].required);
|
||||
assert_eq!(get_ep.parameters[0].param_type.as_deref(), Some("integer"));
|
||||
assert_eq!(get_ep.response_codes.len(), 2);
|
||||
assert_eq!(get_ep.tags, vec!["users"]);
|
||||
|
||||
let post_ep = endpoints.iter().find(|e| e.method == "POST").unwrap();
|
||||
assert_eq!(
|
||||
post_ep.request_body_content_type.as_deref(),
|
||||
Some("application/json")
|
||||
);
|
||||
assert_eq!(post_ep.security, vec!["bearerAuth"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_swagger2_with_base_path() {
|
||||
let spec = json!({
|
||||
"swagger": "2.0",
|
||||
"basePath": "/api/v1",
|
||||
"paths": {
|
||||
"/items": {
|
||||
"get": {
|
||||
"parameters": [
|
||||
{ "name": "id", "in": "path", "required": true, "type": "string" }
|
||||
],
|
||||
"responses": { "200": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
|
||||
assert_eq!(endpoints.len(), 1);
|
||||
assert_eq!(endpoints[0].path, "/api/v1/items");
|
||||
assert_eq!(
|
||||
endpoints[0].parameters[0].param_type.as_deref(),
|
||||
Some("string")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_empty_paths() {
|
||||
let spec = json!({ "openapi": "3.0.0", "paths": {} });
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
|
||||
assert!(endpoints.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_no_paths_key() {
|
||||
let spec = json!({ "openapi": "3.0.0" });
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
|
||||
assert!(endpoints.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_servers_base_url() {
|
||||
let spec = json!({
|
||||
"openapi": "3.0.0",
|
||||
"servers": [{ "url": "/api/v2" }],
|
||||
"paths": {
|
||||
"/health": {
|
||||
"get": { "responses": { "200": {} } }
|
||||
}
|
||||
}
|
||||
});
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
|
||||
assert_eq!(endpoints[0].path, "/api/v2/health");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_spec_path_level_parameters_merged() {
|
||||
let spec = json!({
|
||||
"openapi": "3.0.0",
|
||||
"paths": {
|
||||
"/items/{id}": {
|
||||
"parameters": [
|
||||
{ "name": "id", "in": "path", "required": true, "schema": { "type": "string" } }
|
||||
],
|
||||
"get": {
|
||||
"parameters": [
|
||||
{ "name": "fields", "in": "query", "schema": { "type": "string" } }
|
||||
],
|
||||
"responses": { "200": {} }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
|
||||
assert_eq!(endpoints[0].parameters.len(), 2);
|
||||
assert!(endpoints[0]
|
||||
.parameters
|
||||
.iter()
|
||||
.any(|p| p.name == "id" && p.location == "path"));
|
||||
assert!(endpoints[0]
|
||||
.parameters
|
||||
.iter()
|
||||
.any(|p| p.name == "fields" && p.location == "query"));
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for OpenApiParserTool {
|
||||
fn name(&self) -> &str {
|
||||
"openapi_parser"
|
||||
@@ -289,134 +452,138 @@ impl PentestTool for OpenApiParserTool {
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
_context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let base_url = input
|
||||
.get("base_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?;
|
||||
let base_url = input
|
||||
.get("base_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| {
|
||||
CoreError::Dast("Missing required 'base_url' parameter".to_string())
|
||||
})?;
|
||||
|
||||
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
|
||||
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
|
||||
|
||||
let base_url_trimmed = base_url.trim_end_matches('/');
|
||||
let base_url_trimmed = base_url.trim_end_matches('/');
|
||||
|
||||
// If an explicit spec URL is provided, try it first
|
||||
let mut spec_result: Option<(String, serde_json::Value)> = None;
|
||||
if let Some(spec_url) = explicit_spec_url {
|
||||
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
|
||||
}
|
||||
// If an explicit spec URL is provided, try it first
|
||||
let mut spec_result: Option<(String, serde_json::Value)> = None;
|
||||
if let Some(spec_url) = explicit_spec_url {
|
||||
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
|
||||
}
|
||||
|
||||
// If no explicit URL or it failed, try common paths
|
||||
if spec_result.is_none() {
|
||||
for path in Self::common_spec_paths() {
|
||||
let url = format!("{base_url_trimmed}{path}");
|
||||
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
|
||||
spec_result = Some(result);
|
||||
break;
|
||||
// If no explicit URL or it failed, try common paths
|
||||
if spec_result.is_none() {
|
||||
for path in Self::common_spec_paths() {
|
||||
let url = format!("{base_url_trimmed}{path}");
|
||||
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
|
||||
spec_result = Some(result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match spec_result {
|
||||
Some((spec_url, spec)) => {
|
||||
let spec_version = spec
|
||||
.get("openapi")
|
||||
.or_else(|| spec.get("swagger"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
match spec_result {
|
||||
Some((spec_url, spec)) => {
|
||||
let spec_version = spec
|
||||
.get("openapi")
|
||||
.or_else(|| spec.get("swagger"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let api_title = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("title"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("Unknown API");
|
||||
let api_title = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("title"))
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("Unknown API");
|
||||
|
||||
let api_version = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("version"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
let api_version = spec
|
||||
.get("info")
|
||||
.and_then(|i| i.get("version"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
|
||||
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
|
||||
|
||||
let endpoint_data: Vec<serde_json::Value> = endpoints
|
||||
.iter()
|
||||
.map(|ep| {
|
||||
let params: Vec<serde_json::Value> = ep
|
||||
.parameters
|
||||
.iter()
|
||||
.map(|p| {
|
||||
json!({
|
||||
"name": p.name,
|
||||
"in": p.location,
|
||||
"required": p.required,
|
||||
"type": p.param_type,
|
||||
"description": p.description,
|
||||
let endpoint_data: Vec<serde_json::Value> = endpoints
|
||||
.iter()
|
||||
.map(|ep| {
|
||||
let params: Vec<serde_json::Value> = ep
|
||||
.parameters
|
||||
.iter()
|
||||
.map(|p| {
|
||||
json!({
|
||||
"name": p.name,
|
||||
"in": p.location,
|
||||
"required": p.required,
|
||||
"type": p.param_type,
|
||||
"description": p.description,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"path": ep.path,
|
||||
"method": ep.method,
|
||||
"operation_id": ep.operation_id,
|
||||
"summary": ep.summary,
|
||||
"parameters": params,
|
||||
"request_body_content_type": ep.request_body_content_type,
|
||||
"response_codes": ep.response_codes,
|
||||
"security": ep.security,
|
||||
"tags": ep.tags,
|
||||
})
|
||||
.collect();
|
||||
|
||||
json!({
|
||||
"path": ep.path,
|
||||
"method": ep.method,
|
||||
"operation_id": ep.operation_id,
|
||||
"summary": ep.summary,
|
||||
"parameters": params,
|
||||
"request_body_content_type": ep.request_body_content_type,
|
||||
"response_codes": ep.response_codes,
|
||||
"security": ep.security,
|
||||
"tags": ep.tags,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect();
|
||||
|
||||
let endpoint_count = endpoints.len();
|
||||
info!(
|
||||
spec_url = %spec_url,
|
||||
spec_version,
|
||||
api_title,
|
||||
endpoints = endpoint_count,
|
||||
"OpenAPI spec parsed"
|
||||
);
|
||||
let endpoint_count = endpoints.len();
|
||||
info!(
|
||||
spec_url = %spec_url,
|
||||
spec_version,
|
||||
api_title,
|
||||
endpoints = endpoint_count,
|
||||
"OpenAPI spec parsed"
|
||||
);
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
|
||||
API: {api_title} v{api_version}. \
|
||||
Parsed {endpoint_count} endpoints."
|
||||
),
|
||||
findings: Vec::new(), // This tool produces data, not findings
|
||||
data: json!({
|
||||
"spec_url": spec_url,
|
||||
"spec_version": spec_version,
|
||||
"api_title": api_title,
|
||||
"api_version": api_version,
|
||||
"endpoint_count": endpoint_count,
|
||||
"endpoints": endpoint_data,
|
||||
"security_schemes": spec.get("components")
|
||||
.and_then(|c| c.get("securitySchemes"))
|
||||
.or_else(|| spec.get("securityDefinitions")),
|
||||
}),
|
||||
})
|
||||
}
|
||||
None => {
|
||||
info!(base_url, "No OpenAPI spec found");
|
||||
),
|
||||
findings: Vec::new(), // This tool produces data, not findings
|
||||
data: json!({
|
||||
"spec_url": spec_url,
|
||||
"spec_version": spec_version,
|
||||
"api_title": api_title,
|
||||
"api_version": api_version,
|
||||
"endpoint_count": endpoint_count,
|
||||
"endpoints": endpoint_data,
|
||||
"security_schemes": spec.get("components")
|
||||
.and_then(|c| c.get("securitySchemes"))
|
||||
.or_else(|| spec.get("securityDefinitions")),
|
||||
}),
|
||||
})
|
||||
}
|
||||
None => {
|
||||
info!(base_url, "No OpenAPI spec found");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"No OpenAPI/Swagger specification found for {base_url}. \
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"No OpenAPI/Swagger specification found for {base_url}. \
|
||||
Tried {} common paths.",
|
||||
Self::common_spec_paths().len()
|
||||
),
|
||||
findings: Vec::new(),
|
||||
data: json!({
|
||||
"spec_found": false,
|
||||
"paths_tried": Self::common_spec_paths(),
|
||||
}),
|
||||
})
|
||||
Self::common_spec_paths().len()
|
||||
),
|
||||
findings: Vec::new(),
|
||||
data: json!({
|
||||
"spec_found": false,
|
||||
"paths_tried": Self::common_spec_paths(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,224 +62,229 @@ impl PentestTool for RateLimitTesterTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let method = input
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
let method = input
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET");
|
||||
|
||||
let request_count = input
|
||||
.get("request_count")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(50)
|
||||
.min(200) as usize;
|
||||
let request_count = input
|
||||
.get("request_count")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(50)
|
||||
.min(200) as usize;
|
||||
|
||||
let body = input.get("body").and_then(|v| v.as_str());
|
||||
let body = input.get("body").and_then(|v| v.as_str());
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
// Respect the context rate limit if set
|
||||
let max_requests = if context.rate_limit > 0 {
|
||||
request_count.min(context.rate_limit as usize * 5)
|
||||
} else {
|
||||
request_count
|
||||
};
|
||||
|
||||
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
|
||||
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
|
||||
let mut got_429 = false;
|
||||
let mut rate_limit_at_request: Option<usize> = None;
|
||||
|
||||
for i in 0..max_requests {
|
||||
let start = Instant::now();
|
||||
|
||||
let request = match method {
|
||||
"POST" => {
|
||||
let mut req = self.http.post(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PUT" => {
|
||||
let mut req = self.http.put(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PATCH" => {
|
||||
let mut req = self.http.patch(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"DELETE" => self.http.delete(url),
|
||||
_ => self.http.get(url),
|
||||
// Respect the context rate limit if set
|
||||
let max_requests = if context.rate_limit > 0 {
|
||||
request_count.min(context.rate_limit as usize * 5)
|
||||
} else {
|
||||
request_count
|
||||
};
|
||||
|
||||
match request.send().await {
|
||||
Ok(resp) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
let status = resp.status().as_u16();
|
||||
status_codes.push(status);
|
||||
response_times.push(elapsed);
|
||||
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
|
||||
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
|
||||
let mut got_429 = false;
|
||||
let mut rate_limit_at_request: Option<usize> = None;
|
||||
|
||||
if status == 429 && !got_429 {
|
||||
got_429 = true;
|
||||
rate_limit_at_request = Some(i + 1);
|
||||
info!(url, request_num = i + 1, "Rate limit triggered (429)");
|
||||
for i in 0..max_requests {
|
||||
let start = Instant::now();
|
||||
|
||||
let request = match method {
|
||||
"POST" => {
|
||||
let mut req = self.http.post(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PUT" => {
|
||||
let mut req = self.http.put(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"PATCH" => {
|
||||
let mut req = self.http.patch(url);
|
||||
if let Some(b) = body {
|
||||
req = req.body(b.to_string());
|
||||
}
|
||||
req
|
||||
}
|
||||
"DELETE" => self.http.delete(url),
|
||||
_ => self.http.get(url),
|
||||
};
|
||||
|
||||
// Check for rate limit headers even on 200
|
||||
if !got_429 {
|
||||
let headers = resp.headers();
|
||||
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|
||||
|| headers.contains_key("x-ratelimit-remaining")
|
||||
|| headers.contains_key("ratelimit-limit")
|
||||
|| headers.contains_key("ratelimit-remaining")
|
||||
|| headers.contains_key("retry-after");
|
||||
match request.send().await {
|
||||
Ok(resp) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
let status = resp.status().as_u16();
|
||||
status_codes.push(status);
|
||||
response_times.push(elapsed);
|
||||
|
||||
if has_rate_headers && rate_limit_at_request.is_none() {
|
||||
// Server has rate limit headers but hasn't blocked yet
|
||||
if status == 429 && !got_429 {
|
||||
got_429 = true;
|
||||
rate_limit_at_request = Some(i + 1);
|
||||
info!(url, request_num = i + 1, "Rate limit triggered (429)");
|
||||
}
|
||||
|
||||
// Check for rate limit headers even on 200
|
||||
if !got_429 {
|
||||
let headers = resp.headers();
|
||||
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|
||||
|| headers.contains_key("x-ratelimit-remaining")
|
||||
|| headers.contains_key("ratelimit-limit")
|
||||
|| headers.contains_key("ratelimit-remaining")
|
||||
|| headers.contains_key("retry-after");
|
||||
|
||||
if has_rate_headers && rate_limit_at_request.is_none() {
|
||||
// Server has rate limit headers but hasn't blocked yet
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
status_codes.push(0);
|
||||
response_times.push(elapsed);
|
||||
Err(_e) => {
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
status_codes.push(0);
|
||||
response_times.push(elapsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let total_sent = status_codes.len();
|
||||
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
|
||||
let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count();
|
||||
let mut findings = Vec::new();
|
||||
let total_sent = status_codes.len();
|
||||
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
|
||||
let count_success = status_codes
|
||||
.iter()
|
||||
.filter(|&&s| (200..300).contains(&s))
|
||||
.count();
|
||||
|
||||
// Calculate response time statistics
|
||||
let avg_time = if !response_times.is_empty() {
|
||||
response_times.iter().sum::<u128>() / response_times.len() as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let first_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[..half].iter().sum::<u128>() / half as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
let second_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
// Significant time degradation suggests possible (weak) rate limiting
|
||||
let time_degradation = if first_half_avg > 0 {
|
||||
(second_half_avg as f64 / first_half_avg as f64) - 1.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let rate_data = json!({
|
||||
"total_requests_sent": total_sent,
|
||||
"status_429_count": count_429,
|
||||
"success_count": count_success,
|
||||
"rate_limit_at_request": rate_limit_at_request,
|
||||
"avg_response_time_ms": avg_time,
|
||||
"first_half_avg_ms": first_half_avg,
|
||||
"second_half_avg_ms": second_half_avg,
|
||||
"time_degradation_pct": (time_degradation * 100.0).round(),
|
||||
});
|
||||
|
||||
if !got_429 && count_success == total_sent {
|
||||
// No rate limiting detected at all
|
||||
let evidence = DastEvidence {
|
||||
request_method: method.to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: body.map(String::from),
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Sent {total_sent} rapid requests. All returned success (2xx). \
|
||||
No 429 responses received. Avg response time: {avg_time}ms."
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: Some(avg_time as u64),
|
||||
// Calculate response time statistics
|
||||
let avg_time = if !response_times.is_empty() {
|
||||
response_times.iter().sum::<u128>() / response_times.len() as u128
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::RateLimitAbsent,
|
||||
format!("No rate limiting on {} {}", method, url),
|
||||
format!(
|
||||
"The endpoint {} {} does not enforce rate limiting. \
|
||||
let first_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[..half].iter().sum::<u128>() / half as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
let second_half_avg = if response_times.len() >= 4 {
|
||||
let half = response_times.len() / 2;
|
||||
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
|
||||
} else {
|
||||
avg_time
|
||||
};
|
||||
|
||||
// Significant time degradation suggests possible (weak) rate limiting
|
||||
let time_degradation = if first_half_avg > 0 {
|
||||
(second_half_avg as f64 / first_half_avg as f64) - 1.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let rate_data = json!({
|
||||
"total_requests_sent": total_sent,
|
||||
"status_429_count": count_429,
|
||||
"success_count": count_success,
|
||||
"rate_limit_at_request": rate_limit_at_request,
|
||||
"avg_response_time_ms": avg_time,
|
||||
"first_half_avg_ms": first_half_avg,
|
||||
"second_half_avg_ms": second_half_avg,
|
||||
"time_degradation_pct": (time_degradation * 100.0).round(),
|
||||
});
|
||||
|
||||
if !got_429 && count_success == total_sent {
|
||||
// No rate limiting detected at all
|
||||
let evidence = DastEvidence {
|
||||
request_method: method.to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: body.map(String::from),
|
||||
response_status: 200,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Sent {total_sent} rapid requests. All returned success (2xx). \
|
||||
No 429 responses received. Avg response time: {avg_time}ms."
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: Some(avg_time as u64),
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::RateLimitAbsent,
|
||||
format!("No rate limiting on {} {}", method, url),
|
||||
format!(
|
||||
"The endpoint {} {} does not enforce rate limiting. \
|
||||
{total_sent} rapid requests were all accepted with no 429 responses \
|
||||
or noticeable degradation. This makes the endpoint vulnerable to \
|
||||
brute force attacks and abuse.",
|
||||
method, url
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
method.to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-770".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
|
||||
method, url
|
||||
),
|
||||
Severity::Medium,
|
||||
url.to_string(),
|
||||
method.to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-770".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
|
||||
algorithms. Return 429 Too Many Requests with a Retry-After header when the \
|
||||
limit is exceeded."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(url, method, total_sent, "No rate limiting detected");
|
||||
} else if got_429 {
|
||||
info!(
|
||||
url,
|
||||
method,
|
||||
rate_limit_at = ?rate_limit_at_request,
|
||||
"Rate limiting is enforced"
|
||||
);
|
||||
}
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(url, method, total_sent, "No rate limiting detected");
|
||||
} else if got_429 {
|
||||
info!(
|
||||
url,
|
||||
method,
|
||||
rate_limit_at = ?rate_limit_at_request,
|
||||
"Rate limiting is enforced"
|
||||
);
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if got_429 {
|
||||
format!(
|
||||
"Rate limiting is enforced on {method} {url}. \
|
||||
Ok(PentestToolResult {
|
||||
summary: if got_429 {
|
||||
format!(
|
||||
"Rate limiting is enforced on {method} {url}. \
|
||||
429 response received after {} requests.",
|
||||
rate_limit_at_request.unwrap_or(0)
|
||||
)
|
||||
} else if count > 0 {
|
||||
format!(
|
||||
"No rate limiting detected on {method} {url} after {total_sent} requests."
|
||||
)
|
||||
} else {
|
||||
format!("Rate limit testing complete for {method} {url}.")
|
||||
},
|
||||
findings,
|
||||
data: rate_data,
|
||||
})
|
||||
rate_limit_at_request.unwrap_or(0)
|
||||
)
|
||||
} else if count > 0 {
|
||||
format!(
|
||||
"No rate limiting detected on {method} {url} after {total_sent} requests."
|
||||
)
|
||||
} else {
|
||||
format!("Rate limit testing complete for {method} {url}.")
|
||||
},
|
||||
findings,
|
||||
data: rate_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,72 +54,75 @@ impl PentestTool for ReconTool {
|
||||
fn execute<'a>(
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
_context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let additional_paths: Vec<String> = input
|
||||
.get("additional_paths")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let additional_paths: Vec<String> = input
|
||||
.get("additional_paths")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let result = self.agent.scan(url).await?;
|
||||
let result = self.agent.scan(url).await?;
|
||||
|
||||
// Scan additional paths for more technology signals
|
||||
let mut extra_technologies: Vec<String> = Vec::new();
|
||||
let mut extra_headers: HashMap<String, String> = HashMap::new();
|
||||
// Scan additional paths for more technology signals
|
||||
let mut extra_technologies: Vec<String> = Vec::new();
|
||||
let mut extra_headers: HashMap<String, String> = HashMap::new();
|
||||
|
||||
let base_url = url.trim_end_matches('/');
|
||||
for path in &additional_paths {
|
||||
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
|
||||
if let Ok(resp) = self.http.get(&probe_url).send().await {
|
||||
for (key, value) in resp.headers() {
|
||||
let k = key.to_string().to_lowercase();
|
||||
let v = value.to_str().unwrap_or("").to_string();
|
||||
let base_url = url.trim_end_matches('/');
|
||||
for path in &additional_paths {
|
||||
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
|
||||
if let Ok(resp) = self.http.get(&probe_url).send().await {
|
||||
for (key, value) in resp.headers() {
|
||||
let k = key.to_string().to_lowercase();
|
||||
let v = value.to_str().unwrap_or("").to_string();
|
||||
|
||||
// Look for technology indicators
|
||||
if k == "x-powered-by" || k == "server" || k == "x-generator" {
|
||||
if !result.technologies.contains(&v) && !extra_technologies.contains(&v) {
|
||||
// Look for technology indicators
|
||||
if (k == "x-powered-by" || k == "server" || k == "x-generator")
|
||||
&& !result.technologies.contains(&v)
|
||||
&& !extra_technologies.contains(&v)
|
||||
{
|
||||
extra_technologies.push(v.clone());
|
||||
}
|
||||
extra_headers.insert(format!("{probe_url} -> {k}"), v);
|
||||
}
|
||||
extra_headers.insert(format!("{probe_url} -> {k}"), v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut all_technologies = result.technologies.clone();
|
||||
all_technologies.extend(extra_technologies);
|
||||
all_technologies.dedup();
|
||||
let mut all_technologies = result.technologies.clone();
|
||||
all_technologies.extend(extra_technologies);
|
||||
all_technologies.dedup();
|
||||
|
||||
let tech_count = all_technologies.len();
|
||||
info!(url, technologies = tech_count, "Recon complete");
|
||||
let tech_count = all_technologies.len();
|
||||
info!(url, technologies = tech_count, "Recon complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Recon complete for {url}. Detected {} technologies. Server: {}.",
|
||||
tech_count,
|
||||
result.server.as_deref().unwrap_or("unknown")
|
||||
),
|
||||
findings: Vec::new(), // Recon produces data, not findings
|
||||
data: json!({
|
||||
"base_url": url,
|
||||
"server": result.server,
|
||||
"technologies": all_technologies,
|
||||
"interesting_headers": result.interesting_headers,
|
||||
"extra_headers": extra_headers,
|
||||
"open_ports": result.open_ports,
|
||||
}),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: format!(
|
||||
"Recon complete for {url}. Detected {} technologies. Server: {}.",
|
||||
tech_count,
|
||||
result.server.as_deref().unwrap_or("unknown")
|
||||
),
|
||||
findings: Vec::new(), // Recon produces data, not findings
|
||||
data: json!({
|
||||
"base_url": url,
|
||||
"server": result.server,
|
||||
"technologies": all_technologies,
|
||||
"interesting_headers": result.interesting_headers,
|
||||
"extra_headers": extra_headers,
|
||||
"open_ports": result.open_ports,
|
||||
}),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,57 +111,107 @@ impl PentestTool for SecurityHeadersTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
let response = self
|
||||
.http
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
|
||||
|
||||
let status = response.status().as_u16();
|
||||
let response_headers: HashMap<String, String> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
let status = response.status().as_u16();
|
||||
let response_headers: HashMap<String, String> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
(
|
||||
k.to_string().to_lowercase(),
|
||||
v.to_str().unwrap_or("").to_string(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
let mut findings = Vec::new();
|
||||
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
|
||||
for expected in Self::expected_headers() {
|
||||
let header_value = response_headers.get(expected.name);
|
||||
for expected in Self::expected_headers() {
|
||||
let header_value = response_headers.get(expected.name);
|
||||
|
||||
match header_value {
|
||||
Some(value) => {
|
||||
let mut is_valid = true;
|
||||
if let Some(ref valid) = expected.valid_values {
|
||||
let lower = value.to_lowercase();
|
||||
is_valid = valid.iter().any(|v| lower.contains(v));
|
||||
match header_value {
|
||||
Some(value) => {
|
||||
let mut is_valid = true;
|
||||
if let Some(ref valid) = expected.valid_values {
|
||||
let lower = value.to_lowercase();
|
||||
is_valid = valid.iter().any(|v| lower.contains(v));
|
||||
}
|
||||
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": true,
|
||||
"value": value,
|
||||
"valid": is_valid,
|
||||
}),
|
||||
);
|
||||
|
||||
if !is_valid {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{}: {}", expected.name, value)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Invalid {} header value", expected.name),
|
||||
format!(
|
||||
"The {} header is present but has an invalid or weak value: '{}'. \
|
||||
{}",
|
||||
expected.name, value, expected.description
|
||||
),
|
||||
expected.severity.clone(),
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some(expected.cwe.to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(expected.remediation.to_string());
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": false,
|
||||
"value": null,
|
||||
"valid": false,
|
||||
}),
|
||||
);
|
||||
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": true,
|
||||
"value": value,
|
||||
"valid": is_valid,
|
||||
}),
|
||||
);
|
||||
|
||||
if !is_valid {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
@@ -169,7 +219,7 @@ impl PentestTool for SecurityHeadersTool {
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{}: {}", expected.name, value)),
|
||||
response_snippet: Some(format!("{} header is missing", expected.name)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
@@ -179,11 +229,10 @@ impl PentestTool for SecurityHeadersTool {
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Invalid {} header value", expected.name),
|
||||
format!("Missing {} header", expected.name),
|
||||
format!(
|
||||
"The {} header is present but has an invalid or weak value: '{}'. \
|
||||
{}",
|
||||
expected.name, value, expected.description
|
||||
"The {} header is not present in the response. {}",
|
||||
expected.name, expected.description
|
||||
),
|
||||
expected.severity.clone(),
|
||||
url.to_string(),
|
||||
@@ -195,14 +244,20 @@ impl PentestTool for SecurityHeadersTool {
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
}
|
||||
|
||||
// Also check for information disclosure headers
|
||||
let disclosure_headers = [
|
||||
"server",
|
||||
"x-powered-by",
|
||||
"x-aspnet-version",
|
||||
"x-aspnetmvc-version",
|
||||
];
|
||||
for h in &disclosure_headers {
|
||||
if let Some(value) = response_headers.get(*h) {
|
||||
header_results.insert(
|
||||
expected.name.to_string(),
|
||||
json!({
|
||||
"present": false,
|
||||
"value": null,
|
||||
"valid": false,
|
||||
}),
|
||||
format!("{h}_disclosure"),
|
||||
json!({ "present": true, "value": value }),
|
||||
);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
@@ -212,56 +267,13 @@ impl PentestTool for SecurityHeadersTool {
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{} header is missing", expected.name)),
|
||||
response_snippet: Some(format!("{h}: {value}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
format!("Missing {} header", expected.name),
|
||||
format!(
|
||||
"The {} header is not present in the response. {}",
|
||||
expected.name, expected.description
|
||||
),
|
||||
expected.severity.clone(),
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some(expected.cwe.to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(expected.remediation.to_string());
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for information disclosure headers
|
||||
let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"];
|
||||
for h in &disclosure_headers {
|
||||
if let Some(value) = response_headers.get(*h) {
|
||||
header_results.insert(
|
||||
format!("{h}_disclosure"),
|
||||
json!({ "present": true, "value": value }),
|
||||
);
|
||||
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: url.to_string(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: status,
|
||||
response_headers: Some(response_headers.clone()),
|
||||
response_snippet: Some(format!("{h}: {value}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::SecurityHeaderMissing,
|
||||
@@ -274,27 +286,27 @@ impl PentestTool for SecurityHeadersTool {
|
||||
url.to_string(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-200".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Remove or suppress the {h} header in your server configuration."
|
||||
));
|
||||
findings.push(finding);
|
||||
finding.cwe = Some("CWE-200".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(format!(
|
||||
"Remove or suppress the {h} header in your server configuration."
|
||||
));
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Security headers check complete");
|
||||
let count = findings.len();
|
||||
info!(url, findings = count, "Security headers check complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} security header issues for {url}.")
|
||||
} else {
|
||||
format!("All checked security headers are present and valid for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(header_results),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} security header issues for {url}.")
|
||||
} else {
|
||||
format!("All checked security headers are present and valid for {url}.")
|
||||
},
|
||||
findings,
|
||||
data: json!(header_results),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::dast_agent::{
|
||||
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
|
||||
};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -7,29 +9,51 @@ use crate::agents::injection::SqlInjectionAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing SqlInjectionAgent.
|
||||
pub struct SqlInjectionTool {
|
||||
http: reqwest::Client,
|
||||
_http: reqwest::Client,
|
||||
agent: SqlInjectionAgent,
|
||||
}
|
||||
|
||||
impl SqlInjectionTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = SqlInjectionAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
Self { _http: http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let url = ep
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let method = ep
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string();
|
||||
let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from);
|
||||
let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from);
|
||||
let name = p
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let location = p
|
||||
.get("location")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string();
|
||||
let param_type = p
|
||||
.get("param_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
let example_value = p
|
||||
.get("example_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
parameters.push(EndpointParameter {
|
||||
name,
|
||||
location,
|
||||
@@ -42,8 +66,14 @@ impl SqlInjectionTool {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
content_type: ep
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
requires_auth: ep
|
||||
.get("requires_auth")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -51,6 +81,62 @@ impl SqlInjectionTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_basic() {
|
||||
let input = json!({
|
||||
"endpoints": [{
|
||||
"url": "https://example.com/api/users",
|
||||
"method": "POST",
|
||||
"parameters": [
|
||||
{ "name": "id", "location": "body", "param_type": "integer" }
|
||||
]
|
||||
}]
|
||||
});
|
||||
let endpoints = SqlInjectionTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints.len(), 1);
|
||||
assert_eq!(endpoints[0].url, "https://example.com/api/users");
|
||||
assert_eq!(endpoints[0].method, "POST");
|
||||
assert_eq!(endpoints[0].parameters[0].name, "id");
|
||||
assert_eq!(endpoints[0].parameters[0].location, "body");
|
||||
assert_eq!(
|
||||
endpoints[0].parameters[0].param_type.as_deref(),
|
||||
Some("integer")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_empty_input() {
|
||||
assert!(SqlInjectionTool::parse_endpoints(&json!({})).is_empty());
|
||||
assert!(SqlInjectionTool::parse_endpoints(&json!({ "endpoints": [] })).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_multiple() {
|
||||
let input = json!({
|
||||
"endpoints": [
|
||||
{ "url": "https://a.com/1", "method": "GET", "parameters": [] },
|
||||
{ "url": "https://b.com/2", "method": "DELETE", "parameters": [] }
|
||||
]
|
||||
});
|
||||
let endpoints = SqlInjectionTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints.len(), 2);
|
||||
assert_eq!(endpoints[0].url, "https://a.com/1");
|
||||
assert_eq!(endpoints[1].method, "DELETE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_default_method() {
|
||||
let input = json!({ "endpoints": [{ "url": "https://x.com", "parameters": [] }] });
|
||||
let endpoints = SqlInjectionTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints[0].method, "GET");
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for SqlInjectionTool {
|
||||
fn name(&self) -> &str {
|
||||
"sql_injection_scanner"
|
||||
@@ -104,35 +190,37 @@ impl PentestTool for SqlInjectionTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SQL injection vulnerabilities.")
|
||||
} else {
|
||||
"No SQL injection vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SQL injection vulnerabilities.")
|
||||
} else {
|
||||
"No SQL injection vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::dast_agent::{
|
||||
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
|
||||
};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -7,30 +9,52 @@ use crate::agents::ssrf::SsrfAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing SsrfAgent.
|
||||
pub struct SsrfTool {
|
||||
http: reqwest::Client,
|
||||
_http: reqwest::Client,
|
||||
agent: SsrfAgent,
|
||||
}
|
||||
|
||||
impl SsrfTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = SsrfAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
Self { _http: http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let url = ep
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let method = ep
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
name: p
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
location: p
|
||||
.get("location")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string(),
|
||||
param_type: p
|
||||
.get("param_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
example_value: p
|
||||
.get("example_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -38,8 +62,14 @@ impl SsrfTool {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
content_type: ep
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
requires_auth: ep
|
||||
.get("requires_auth")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -100,35 +130,37 @@ impl PentestTool for SsrfTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SSRF vulnerabilities.")
|
||||
} else {
|
||||
"No SSRF vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} SSRF vulnerabilities.")
|
||||
} else {
|
||||
"No SSRF vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,10 +39,7 @@ impl TlsAnalyzerTool {
|
||||
/// TLS client hello. We test SSLv3 / old protocol support by attempting
|
||||
/// connection with the system's native-tls which typically negotiates the
|
||||
/// best available, then inspect what was negotiated.
|
||||
async fn check_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
) -> Result<TlsInfo, CoreError> {
|
||||
async fn check_tls(host: &str, port: u16) -> Result<TlsInfo, CoreError> {
|
||||
let addr = format!("{host}:{port}");
|
||||
|
||||
let tcp = TcpStream::connect(&addr)
|
||||
@@ -62,11 +59,13 @@ impl TlsAnalyzerTool {
|
||||
.await
|
||||
.map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?;
|
||||
|
||||
let peer_cert = tls_stream.get_ref().peer_certificate()
|
||||
let peer_cert = tls_stream
|
||||
.get_ref()
|
||||
.peer_certificate()
|
||||
.map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?;
|
||||
|
||||
let mut tls_info = TlsInfo {
|
||||
protocol_version: String::new(),
|
||||
_protocol_version: String::new(),
|
||||
cert_subject: String::new(),
|
||||
cert_issuer: String::new(),
|
||||
cert_not_before: String::new(),
|
||||
@@ -78,7 +77,8 @@ impl TlsAnalyzerTool {
|
||||
};
|
||||
|
||||
if let Some(cert) = peer_cert {
|
||||
let der = cert.to_der()
|
||||
let der = cert
|
||||
.to_der()
|
||||
.map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?;
|
||||
|
||||
// native_tls doesn't give rich access, so we parse what we can
|
||||
@@ -93,7 +93,7 @@ impl TlsAnalyzerTool {
|
||||
|
||||
/// Best-effort parse of DER-encoded X.509 certificate for dates and subject.
|
||||
/// This is a simplified parser; in production you would use a proper x509 crate.
|
||||
fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo {
|
||||
fn parse_cert_der(_der: &[u8], mut info: TlsInfo) -> TlsInfo {
|
||||
// We rely on the native_tls debug output stored in cert_subject
|
||||
// and just mark fields as "see certificate details"
|
||||
if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") {
|
||||
@@ -104,7 +104,7 @@ impl TlsAnalyzerTool {
|
||||
}
|
||||
|
||||
struct TlsInfo {
|
||||
protocol_version: String,
|
||||
_protocol_version: String,
|
||||
cert_subject: String,
|
||||
cert_issuer: String,
|
||||
cert_not_before: String,
|
||||
@@ -152,111 +152,194 @@ impl PentestTool for TlsAnalyzerTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
let url = input
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
|
||||
|
||||
let host = Self::extract_host(url)
|
||||
.unwrap_or_else(|| url.to_string());
|
||||
let host = Self::extract_host(url).unwrap_or_else(|| url.to_string());
|
||||
|
||||
let port = input
|
||||
.get("port")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|p| p as u16)
|
||||
.unwrap_or_else(|| Self::extract_port(url));
|
||||
let port = input
|
||||
.get("port")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|p| p as u16)
|
||||
.unwrap_or_else(|| Self::extract_port(url));
|
||||
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let target_id = context
|
||||
.target
|
||||
.id
|
||||
.map(|oid| oid.to_hex())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let mut findings = Vec::new();
|
||||
let mut tls_data = json!({});
|
||||
let mut findings = Vec::new();
|
||||
let mut tls_data = json!({});
|
||||
|
||||
// First check: does the server even support HTTPS?
|
||||
let https_url = if url.starts_with("https://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("http://") {
|
||||
url.replace("http://", "https://")
|
||||
} else {
|
||||
format!("https://{url}")
|
||||
};
|
||||
// First check: does the server even support HTTPS?
|
||||
let https_url = if url.starts_with("https://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("http://") {
|
||||
url.replace("http://", "https://")
|
||||
} else {
|
||||
format!("https://{url}")
|
||||
};
|
||||
|
||||
// Check if HTTP redirects to HTTPS
|
||||
let http_url = if url.starts_with("http://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("https://") {
|
||||
url.replace("https://", "http://")
|
||||
} else {
|
||||
format!("http://{url}")
|
||||
};
|
||||
// Check if HTTP redirects to HTTPS
|
||||
let http_url = if url.starts_with("http://") {
|
||||
url.to_string()
|
||||
} else if url.starts_with("https://") {
|
||||
url.replace("https://", "http://")
|
||||
} else {
|
||||
format!("http://{url}")
|
||||
};
|
||||
|
||||
match self.http.get(&http_url).send().await {
|
||||
Ok(resp) => {
|
||||
let final_url = resp.url().to_string();
|
||||
let redirects_to_https = final_url.starts_with("https://");
|
||||
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
|
||||
match self.http.get(&http_url).send().await {
|
||||
Ok(resp) => {
|
||||
let final_url = resp.url().to_string();
|
||||
let redirects_to_https = final_url.starts_with("https://");
|
||||
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
|
||||
|
||||
if !redirects_to_https {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: http_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!("Final URL: {final_url}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
if !redirects_to_https {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: http_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!("Final URL: {final_url}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("HTTP does not redirect to HTTPS for {host}"),
|
||||
format!(
|
||||
"HTTP requests to {host} are not redirected to HTTPS. \
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("HTTP does not redirect to HTTPS for {host}"),
|
||||
format!(
|
||||
"HTTP requests to {host} are not redirected to HTTPS. \
|
||||
Users accessing the site via HTTP will have their traffic \
|
||||
transmitted in cleartext."
|
||||
),
|
||||
Severity::Medium,
|
||||
http_url.clone(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Configure the web server to redirect all HTTP requests to HTTPS \
|
||||
),
|
||||
Severity::Medium,
|
||||
http_url.clone(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Configure the web server to redirect all HTTP requests to HTTPS \
|
||||
using a 301 redirect."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["http_check_error"] = json!("Could not connect via HTTP");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["http_check_error"] = json!("Could not connect via HTTP");
|
||||
}
|
||||
}
|
||||
|
||||
// Perform TLS analysis
|
||||
match Self::check_tls(&host, port).await {
|
||||
Ok(tls_info) => {
|
||||
tls_data["host"] = json!(host);
|
||||
tls_data["port"] = json!(port);
|
||||
tls_data["cert_subject"] = json!(tls_info.cert_subject);
|
||||
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
|
||||
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
|
||||
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
|
||||
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
|
||||
tls_data["san_names"] = json!(tls_info.san_names);
|
||||
// Perform TLS analysis
|
||||
match Self::check_tls(&host, port).await {
|
||||
Ok(tls_info) => {
|
||||
tls_data["host"] = json!(host);
|
||||
tls_data["port"] = json!(port);
|
||||
tls_data["cert_subject"] = json!(tls_info.cert_subject);
|
||||
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
|
||||
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
|
||||
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
|
||||
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
|
||||
tls_data["san_names"] = json!(tls_info.san_names);
|
||||
|
||||
if tls_info.cert_expired {
|
||||
if tls_info.cert_expired {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Certificate expired. Not After: {}",
|
||||
tls_info.cert_not_after
|
||||
)),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Expired TLS certificate for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} has expired. \
|
||||
Browsers will show security warnings to users."
|
||||
),
|
||||
Severity::High,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Renew the TLS certificate. Consider using automated certificate \
|
||||
management with Let's Encrypt or a similar CA."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Expired TLS certificate");
|
||||
}
|
||||
|
||||
if tls_info.cert_self_signed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("Self-signed certificate detected".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Self-signed TLS certificate for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} is self-signed and not issued by a \
|
||||
trusted certificate authority. Browsers will show security warnings."
|
||||
),
|
||||
Severity::Medium,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Replace the self-signed certificate with one issued by a trusted \
|
||||
certificate authority."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Self-signed certificate");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tls_data["tls_error"] = json!(e.to_string());
|
||||
|
||||
// TLS handshake failure itself is a finding
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
@@ -264,10 +347,7 @@ impl PentestTool for TlsAnalyzerTool {
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!(
|
||||
"Certificate expired. Not After: {}",
|
||||
tls_info.cert_not_after
|
||||
)),
|
||||
response_snippet: Some(format!("TLS error: {e}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
@@ -277,10 +357,9 @@ impl PentestTool for TlsAnalyzerTool {
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Expired TLS certificate for {host}"),
|
||||
format!("TLS handshake failure for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} has expired. \
|
||||
Browsers will show security warnings to users."
|
||||
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
|
||||
),
|
||||
Severity::High,
|
||||
format!("https://{host}:{port}"),
|
||||
@@ -289,115 +368,37 @@ impl PentestTool for TlsAnalyzerTool {
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Renew the TLS certificate. Consider using automated certificate \
|
||||
management with Let's Encrypt or a similar CA."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Expired TLS certificate");
|
||||
}
|
||||
|
||||
if tls_info.cert_self_signed {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some("Self-signed certificate detected".to_string()),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("Self-signed TLS certificate for {host}"),
|
||||
format!(
|
||||
"The TLS certificate for {host} is self-signed and not issued by a \
|
||||
trusted certificate authority. Browsers will show security warnings."
|
||||
),
|
||||
Severity::Medium,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Replace the self-signed certificate with one issued by a trusted \
|
||||
certificate authority."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
warn!(host, "Self-signed certificate");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tls_data["tls_error"] = json!(e.to_string());
|
||||
|
||||
// TLS handshake failure itself is a finding
|
||||
let evidence = DastEvidence {
|
||||
request_method: "TLS".to_string(),
|
||||
request_url: format!("{host}:{port}"),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: 0,
|
||||
response_headers: None,
|
||||
response_snippet: Some(format!("TLS error: {e}")),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
format!("TLS handshake failure for {host}"),
|
||||
format!(
|
||||
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
|
||||
),
|
||||
Severity::High,
|
||||
format!("https://{host}:{port}"),
|
||||
"TLS".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-295".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Ensure TLS is properly configured on the server. Check that the \
|
||||
"Ensure TLS is properly configured on the server. Check that the \
|
||||
certificate is valid and the server supports modern TLS versions."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check strict transport security via an HTTPS request
|
||||
match self.http.get(&https_url).send().await {
|
||||
Ok(resp) => {
|
||||
let hsts = resp.headers().get("strict-transport-security");
|
||||
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
|
||||
// Check strict transport security via an HTTPS request
|
||||
match self.http.get(&https_url).send().await {
|
||||
Ok(resp) => {
|
||||
let hsts = resp.headers().get("strict-transport-security");
|
||||
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
|
||||
|
||||
if hsts.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: https_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(
|
||||
"Strict-Transport-Security header not present".to_string(),
|
||||
),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
if hsts.is_none() {
|
||||
let evidence = DastEvidence {
|
||||
request_method: "GET".to_string(),
|
||||
request_url: https_url.clone(),
|
||||
request_headers: None,
|
||||
request_body: None,
|
||||
response_status: resp.status().as_u16(),
|
||||
response_headers: None,
|
||||
response_snippet: Some(
|
||||
"Strict-Transport-Security header not present".to_string(),
|
||||
),
|
||||
screenshot_path: None,
|
||||
payload: None,
|
||||
response_time_ms: None,
|
||||
};
|
||||
|
||||
let mut finding = DastFinding::new(
|
||||
let mut finding = DastFinding::new(
|
||||
String::new(),
|
||||
target_id.clone(),
|
||||
DastVulnType::TlsMisconfiguration,
|
||||
@@ -410,33 +411,33 @@ impl PentestTool for TlsAnalyzerTool {
|
||||
https_url.clone(),
|
||||
"GET".to_string(),
|
||||
);
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
finding.cwe = Some("CWE-319".to_string());
|
||||
finding.evidence = vec![evidence];
|
||||
finding.remediation = Some(
|
||||
"Add the Strict-Transport-Security header with an appropriate max-age. \
|
||||
Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'."
|
||||
.to_string(),
|
||||
);
|
||||
findings.push(finding);
|
||||
findings.push(finding);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
|
||||
}
|
||||
}
|
||||
|
||||
let count = findings.len();
|
||||
info!(host = %host, findings = count, "TLS analysis complete");
|
||||
let count = findings.len();
|
||||
info!(host = %host, findings = count, "TLS analysis complete");
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} TLS configuration issues for {host}.")
|
||||
} else {
|
||||
format!("TLS configuration looks good for {host}.")
|
||||
},
|
||||
findings,
|
||||
data: tls_data,
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} TLS configuration issues for {host}.")
|
||||
} else {
|
||||
format!("TLS configuration looks good for {host}.")
|
||||
},
|
||||
findings,
|
||||
data: tls_data,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use compliance_core::error::CoreError;
|
||||
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
|
||||
use compliance_core::traits::dast_agent::{
|
||||
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
|
||||
};
|
||||
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
|
||||
use serde_json::json;
|
||||
|
||||
@@ -7,30 +9,52 @@ use crate::agents::xss::XssAgent;
|
||||
|
||||
/// PentestTool wrapper around the existing XssAgent.
|
||||
pub struct XssTool {
|
||||
http: reqwest::Client,
|
||||
_http: reqwest::Client,
|
||||
agent: XssAgent,
|
||||
}
|
||||
|
||||
impl XssTool {
|
||||
pub fn new(http: reqwest::Client) -> Self {
|
||||
let agent = XssAgent::new(http.clone());
|
||||
Self { http, agent }
|
||||
Self { _http: http, agent }
|
||||
}
|
||||
|
||||
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
|
||||
let mut endpoints = Vec::new();
|
||||
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
|
||||
for ep in arr {
|
||||
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
|
||||
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
|
||||
let url = ep
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
let method = ep
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_string();
|
||||
let mut parameters = Vec::new();
|
||||
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
|
||||
for p in params {
|
||||
parameters.push(EndpointParameter {
|
||||
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
|
||||
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
|
||||
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
|
||||
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
|
||||
name: p
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or_default()
|
||||
.to_string(),
|
||||
location: p
|
||||
.get("location")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("query")
|
||||
.to_string(),
|
||||
param_type: p
|
||||
.get("param_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
example_value: p
|
||||
.get("example_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -38,8 +62,14 @@ impl XssTool {
|
||||
url,
|
||||
method,
|
||||
parameters,
|
||||
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
|
||||
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
content_type: ep
|
||||
.get("content_type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
requires_auth: ep
|
||||
.get("requires_auth")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -47,6 +77,91 @@ impl XssTool {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_basic() {
|
||||
let input = json!({
|
||||
"endpoints": [
|
||||
{
|
||||
"url": "https://example.com/search",
|
||||
"method": "GET",
|
||||
"parameters": [
|
||||
{ "name": "q", "location": "query" }
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
let endpoints = XssTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints.len(), 1);
|
||||
assert_eq!(endpoints[0].url, "https://example.com/search");
|
||||
assert_eq!(endpoints[0].method, "GET");
|
||||
assert_eq!(endpoints[0].parameters.len(), 1);
|
||||
assert_eq!(endpoints[0].parameters[0].name, "q");
|
||||
assert_eq!(endpoints[0].parameters[0].location, "query");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_empty() {
|
||||
let input = json!({ "endpoints": [] });
|
||||
assert!(XssTool::parse_endpoints(&input).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_missing_key() {
|
||||
let input = json!({});
|
||||
assert!(XssTool::parse_endpoints(&input).is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_defaults() {
|
||||
let input = json!({
|
||||
"endpoints": [
|
||||
{ "url": "https://example.com/api", "parameters": [] }
|
||||
]
|
||||
});
|
||||
let endpoints = XssTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints[0].method, "GET"); // default
|
||||
assert!(!endpoints[0].requires_auth); // default false
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_endpoints_full_params() {
|
||||
let input = json!({
|
||||
"endpoints": [{
|
||||
"url": "https://example.com",
|
||||
"method": "POST",
|
||||
"content_type": "application/json",
|
||||
"requires_auth": true,
|
||||
"parameters": [{
|
||||
"name": "body",
|
||||
"location": "body",
|
||||
"param_type": "string",
|
||||
"example_value": "test"
|
||||
}]
|
||||
}]
|
||||
});
|
||||
let endpoints = XssTool::parse_endpoints(&input);
|
||||
assert_eq!(endpoints[0].method, "POST");
|
||||
assert_eq!(
|
||||
endpoints[0].content_type.as_deref(),
|
||||
Some("application/json")
|
||||
);
|
||||
assert!(endpoints[0].requires_auth);
|
||||
assert_eq!(
|
||||
endpoints[0].parameters[0].param_type.as_deref(),
|
||||
Some("string")
|
||||
);
|
||||
assert_eq!(
|
||||
endpoints[0].parameters[0].example_value.as_deref(),
|
||||
Some("test")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl PentestTool for XssTool {
|
||||
fn name(&self) -> &str {
|
||||
"xss_scanner"
|
||||
@@ -100,35 +215,37 @@ impl PentestTool for XssTool {
|
||||
&'a self,
|
||||
input: serde_json::Value,
|
||||
context: &'a PentestToolContext,
|
||||
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
|
||||
> {
|
||||
Box::pin(async move {
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
let endpoints = Self::parse_endpoints(&input);
|
||||
if endpoints.is_empty() {
|
||||
return Ok(PentestToolResult {
|
||||
summary: "No endpoints provided to test.".to_string(),
|
||||
findings: Vec::new(),
|
||||
data: json!({}),
|
||||
});
|
||||
}
|
||||
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
let dast_context = DastContext {
|
||||
endpoints,
|
||||
technologies: Vec::new(),
|
||||
sast_hints: Vec::new(),
|
||||
};
|
||||
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
let findings = self.agent.run(&context.target, &dast_context).await?;
|
||||
let count = findings.len();
|
||||
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} XSS vulnerabilities.")
|
||||
} else {
|
||||
"No XSS vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
Ok(PentestToolResult {
|
||||
summary: if count > 0 {
|
||||
format!("Found {count} XSS vulnerabilities.")
|
||||
} else {
|
||||
"No XSS vulnerabilities detected.".to_string()
|
||||
},
|
||||
findings,
|
||||
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
4
compliance-dast/tests/agents.rs
Normal file
4
compliance-dast/tests/agents.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
// Integration tests for DAST agents.
|
||||
//
|
||||
// Test individual security testing agents (XSS, SQLi, SSRF, etc.)
|
||||
// against controlled test targets.
|
||||
@@ -94,3 +94,64 @@ fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> St
|
||||
format!("// {file_path} | {kind}")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_build_context_header_with_parent() {
|
||||
let result =
|
||||
build_context_header("src/main.rs", "src/main.rs::MyStruct::my_method", "method");
|
||||
assert_eq!(result, "// src/main.rs | method in src/main.rs::MyStruct");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_context_header_top_level() {
|
||||
let result = build_context_header("src/lib.rs", "main", "function");
|
||||
assert_eq!(result, "// src/lib.rs | function");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_context_header_single_parent() {
|
||||
let result = build_context_header("src/lib.rs", "src/lib.rs::do_stuff", "function");
|
||||
assert_eq!(result, "// src/lib.rs | function in src/lib.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_context_header_deep_nesting() {
|
||||
let result = build_context_header(
|
||||
"src/mod.rs",
|
||||
"src/mod.rs::Outer::Inner::deep_fn",
|
||||
"function",
|
||||
);
|
||||
assert_eq!(
|
||||
result,
|
||||
"// src/mod.rs | function in src/mod.rs::Outer::Inner"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_context_header_empty_strings() {
|
||||
let result = build_context_header("", "", "function");
|
||||
assert_eq!(result, "// | function");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_code_chunk_struct_fields() {
|
||||
let chunk = CodeChunk {
|
||||
qualified_name: "main".to_string(),
|
||||
kind: "function".to_string(),
|
||||
file_path: "src/main.rs".to_string(),
|
||||
start_line: 1,
|
||||
end_line: 10,
|
||||
language: "rust".to_string(),
|
||||
content: "fn main() {}".to_string(),
|
||||
context_header: "// src/main.rs | function".to_string(),
|
||||
token_estimate: 3,
|
||||
};
|
||||
assert_eq!(chunk.start_line, 1);
|
||||
assert_eq!(chunk.end_line, 10);
|
||||
assert_eq!(chunk.language, "rust");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,3 +253,215 @@ fn detect_communities_with_assignment(code_graph: &mut CodeGraph) -> u32 {
|
||||
|
||||
next_id
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};
|
||||
use petgraph::graph::DiGraph;
|
||||
|
||||
fn make_node(qualified_name: &str, graph_index: u32) -> CodeNode {
|
||||
CodeNode {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
qualified_name: qualified_name.to_string(),
|
||||
name: qualified_name.to_string(),
|
||||
kind: CodeNodeKind::Function,
|
||||
file_path: "test.rs".to_string(),
|
||||
start_line: 1,
|
||||
end_line: 10,
|
||||
language: "rust".to_string(),
|
||||
community_id: None,
|
||||
is_entry_point: false,
|
||||
graph_index: Some(graph_index),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_empty_code_graph() -> CodeGraph {
|
||||
CodeGraph {
|
||||
graph: DiGraph::new(),
|
||||
node_map: HashMap::new(),
|
||||
nodes: Vec::new(),
|
||||
edges: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_communities_empty_graph() {
|
||||
let cg = make_empty_code_graph();
|
||||
assert_eq!(detect_communities(&cg), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_communities_single_node_no_edges() {
|
||||
let mut graph = DiGraph::new();
|
||||
let idx = graph.add_node("a".to_string());
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), idx);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![make_node("a", 0)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
// Single node with no edges => 1 community (itself)
|
||||
assert_eq!(detect_communities(&cg), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_communities_isolated_nodes() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
// 3 isolated nodes => 3 communities
|
||||
assert_eq!(detect_communities(&cg), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_communities_fully_connected() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, c, CodeEdgeKind::Calls);
|
||||
graph.add_edge(c, a, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let num = detect_communities(&cg);
|
||||
// Fully connected triangle should converge to 1 community
|
||||
assert!(num >= 1);
|
||||
assert!(num <= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_communities_two_clusters() {
|
||||
// Two separate triangles connected by a single weak edge
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
let d = graph.add_node("d".to_string());
|
||||
let e = graph.add_node("e".to_string());
|
||||
let f = graph.add_node("f".to_string());
|
||||
|
||||
// Cluster 1: a-b-c fully connected
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, a, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, c, CodeEdgeKind::Calls);
|
||||
graph.add_edge(c, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(a, c, CodeEdgeKind::Calls);
|
||||
graph.add_edge(c, a, CodeEdgeKind::Calls);
|
||||
|
||||
// Cluster 2: d-e-f fully connected
|
||||
graph.add_edge(d, e, CodeEdgeKind::Calls);
|
||||
graph.add_edge(e, d, CodeEdgeKind::Calls);
|
||||
graph.add_edge(e, f, CodeEdgeKind::Calls);
|
||||
graph.add_edge(f, e, CodeEdgeKind::Calls);
|
||||
graph.add_edge(d, f, CodeEdgeKind::Calls);
|
||||
graph.add_edge(f, d, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
node_map.insert("d".to_string(), d);
|
||||
node_map.insert("e".to_string(), e);
|
||||
node_map.insert("f".to_string(), f);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![
|
||||
make_node("a", 0),
|
||||
make_node("b", 1),
|
||||
make_node("c", 2),
|
||||
make_node("d", 3),
|
||||
make_node("e", 4),
|
||||
make_node("f", 5),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let num = detect_communities(&cg);
|
||||
// Two disconnected clusters should yield 2 communities
|
||||
assert_eq!(num, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_communities_assigns_ids() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
|
||||
let mut cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![make_node("a", 0), make_node("b", 1)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let count = apply_communities(&mut cg);
|
||||
assert!(count >= 1);
|
||||
// All nodes should have a community_id assigned
|
||||
for node in &cg.nodes {
|
||||
assert!(node.community_id.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_communities_empty() {
|
||||
let mut cg = make_empty_code_graph();
|
||||
assert_eq!(apply_communities(&mut cg), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_communities_isolated_nodes_get_own_community() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
|
||||
let mut cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![make_node("a", 0), make_node("b", 1)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let count = apply_communities(&mut cg);
|
||||
assert_eq!(count, 2);
|
||||
// Each isolated node should be in a different community
|
||||
let c0 = cg.nodes[0].community_id.unwrap();
|
||||
let c1 = cg.nodes[1].community_id.unwrap();
|
||||
assert_ne!(c0, c1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,3 +172,185 @@ impl GraphEngine {
|
||||
ImpactAnalyzer::new(code_graph)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};
|
||||
|
||||
fn make_node(qualified_name: &str) -> CodeNode {
|
||||
CodeNode {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
qualified_name: qualified_name.to_string(),
|
||||
name: qualified_name
|
||||
.split("::")
|
||||
.last()
|
||||
.unwrap_or(qualified_name)
|
||||
.to_string(),
|
||||
kind: CodeNodeKind::Function,
|
||||
file_path: "src/main.rs".to_string(),
|
||||
start_line: 1,
|
||||
end_line: 10,
|
||||
language: "rust".to_string(),
|
||||
community_id: None,
|
||||
is_entry_point: false,
|
||||
graph_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_test_node_map(names: &[&str]) -> HashMap<String, NodeIndex> {
|
||||
let mut graph: DiGraph<String, String> = DiGraph::new();
|
||||
let mut map = HashMap::new();
|
||||
for name in names {
|
||||
let idx = graph.add_node(name.to_string());
|
||||
map.insert(name.to_string(), idx);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_direct_match() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/main.rs::foo", "src/main.rs::bar"]);
|
||||
let result = engine.resolve_edge_target("src/main.rs::foo", &node_map);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_short_name_match() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/main.rs::foo", "src/main.rs::bar"]);
|
||||
let result = engine.resolve_edge_target("foo", &node_map);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_method_match() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/main.rs::MyStruct::do_thing"]);
|
||||
let result = engine.resolve_edge_target("do_thing", &node_map);
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_self_method() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/main.rs::MyStruct::process"]);
|
||||
let result = engine.resolve_edge_target("self.process", &node_map);
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_no_match() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/main.rs::foo"]);
|
||||
let result = engine.resolve_edge_target("nonexistent", &node_map);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_empty_map() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = HashMap::new();
|
||||
let result = engine.resolve_edge_target("anything", &node_map);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_edge_target_dot_notation() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let node_map = build_test_node_map(&["src/app.js.handler"]);
|
||||
let result = engine.resolve_edge_target("handler", &node_map);
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_petgraph_empty() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let output = ParseOutput::default();
|
||||
let code_graph = engine.build_petgraph(output).unwrap();
|
||||
assert_eq!(code_graph.nodes.len(), 0);
|
||||
assert_eq!(code_graph.edges.len(), 0);
|
||||
assert_eq!(code_graph.graph.node_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_petgraph_nodes_get_graph_index() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let mut output = ParseOutput::default();
|
||||
output.nodes.push(make_node("src/main.rs::foo"));
|
||||
output.nodes.push(make_node("src/main.rs::bar"));
|
||||
|
||||
let code_graph = engine.build_petgraph(output).unwrap();
|
||||
assert_eq!(code_graph.nodes.len(), 2);
|
||||
assert_eq!(code_graph.graph.node_count(), 2);
|
||||
// All nodes should have a graph_index assigned
|
||||
for node in &code_graph.nodes {
|
||||
assert!(node.graph_index.is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_petgraph_resolves_edges() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let mut output = ParseOutput::default();
|
||||
output.nodes.push(make_node("src/main.rs::foo"));
|
||||
output.nodes.push(make_node("src/main.rs::bar"));
|
||||
output.edges.push(CodeEdge {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
source: "src/main.rs::foo".to_string(),
|
||||
target: "bar".to_string(), // short name, should resolve
|
||||
kind: CodeEdgeKind::Calls,
|
||||
file_path: "src/main.rs".to_string(),
|
||||
line_number: Some(5),
|
||||
});
|
||||
|
||||
let code_graph = engine.build_petgraph(output).unwrap();
|
||||
assert_eq!(code_graph.edges.len(), 1);
|
||||
assert_eq!(code_graph.graph.edge_count(), 1);
|
||||
// The resolved edge target should be the full qualified name
|
||||
assert_eq!(code_graph.edges[0].target, "src/main.rs::bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_petgraph_skips_unresolved_edges() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let mut output = ParseOutput::default();
|
||||
output.nodes.push(make_node("src/main.rs::foo"));
|
||||
output.edges.push(CodeEdge {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
source: "src/main.rs::foo".to_string(),
|
||||
target: "external_crate::something".to_string(),
|
||||
kind: CodeEdgeKind::Calls,
|
||||
file_path: "src/main.rs".to_string(),
|
||||
line_number: Some(5),
|
||||
});
|
||||
|
||||
let code_graph = engine.build_petgraph(output).unwrap();
|
||||
assert_eq!(code_graph.edges.len(), 0);
|
||||
assert_eq!(code_graph.graph.edge_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_code_graph_node_map_consistency() {
|
||||
let engine = GraphEngine::new(1000);
|
||||
let mut output = ParseOutput::default();
|
||||
output.nodes.push(make_node("a::b"));
|
||||
output.nodes.push(make_node("a::c"));
|
||||
output.nodes.push(make_node("a::d"));
|
||||
|
||||
let code_graph = engine.build_petgraph(output).unwrap();
|
||||
assert_eq!(code_graph.node_map.len(), 3);
|
||||
assert!(code_graph.node_map.contains_key("a::b"));
|
||||
assert!(code_graph.node_map.contains_key("a::c"));
|
||||
assert!(code_graph.node_map.contains_key("a::d"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,3 +222,378 @@ impl<'a> ImpactAnalyzer<'a> {
|
||||
.find(|n| n.graph_index == Some(target_gi))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};
|
||||
use petgraph::graph::DiGraph;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn make_node(
|
||||
qualified_name: &str,
|
||||
file_path: &str,
|
||||
start: u32,
|
||||
end: u32,
|
||||
graph_index: u32,
|
||||
is_entry: bool,
|
||||
kind: CodeNodeKind,
|
||||
) -> CodeNode {
|
||||
CodeNode {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
qualified_name: qualified_name.to_string(),
|
||||
name: qualified_name
|
||||
.split("::")
|
||||
.last()
|
||||
.unwrap_or(qualified_name)
|
||||
.to_string(),
|
||||
kind,
|
||||
file_path: file_path.to_string(),
|
||||
start_line: start,
|
||||
end_line: end,
|
||||
language: "rust".to_string(),
|
||||
community_id: None,
|
||||
is_entry_point: is_entry,
|
||||
graph_index: Some(graph_index),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_fn_node(
|
||||
qualified_name: &str,
|
||||
file_path: &str,
|
||||
start: u32,
|
||||
end: u32,
|
||||
gi: u32,
|
||||
) -> CodeNode {
|
||||
make_node(
|
||||
qualified_name,
|
||||
file_path,
|
||||
start,
|
||||
end,
|
||||
gi,
|
||||
false,
|
||||
CodeNodeKind::Function,
|
||||
)
|
||||
}
|
||||
|
||||
/// Build a simple linear graph: A -> B -> C
|
||||
fn build_linear_graph() -> CodeGraph {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, c, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
|
||||
CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![
|
||||
make_fn_node("a", "src/main.rs", 1, 5, 0),
|
||||
make_fn_node("b", "src/main.rs", 7, 12, 1),
|
||||
make_fn_node("c", "src/main.rs", 14, 20, 2),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bfs_reachable_outgoing_linear() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let start = cg.node_map["a"];
|
||||
let reachable = analyzer.bfs_reachable(start, Direction::Outgoing);
|
||||
// From a, we can reach b and c
|
||||
assert_eq!(reachable.len(), 2);
|
||||
assert!(reachable.contains(&cg.node_map["b"]));
|
||||
assert!(reachable.contains(&cg.node_map["c"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bfs_reachable_incoming_linear() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let start = cg.node_map["c"];
|
||||
let reachable = analyzer.bfs_reachable(start, Direction::Incoming);
|
||||
// c is reached by a and b
|
||||
assert_eq!(reachable.len(), 2);
|
||||
assert!(reachable.contains(&cg.node_map["a"]));
|
||||
assert!(reachable.contains(&cg.node_map["b"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bfs_reachable_no_neighbors() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map: [("a".to_string(), a)].into_iter().collect(),
|
||||
nodes: vec![make_fn_node("a", "src/main.rs", 1, 5, 0)],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let reachable = analyzer.bfs_reachable(a, Direction::Outgoing);
|
||||
assert!(reachable.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bfs_reachable_cycle() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, a, CodeEdgeKind::Calls);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map: [("a".to_string(), a), ("b".to_string(), b)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
nodes: vec![
|
||||
make_fn_node("a", "f.rs", 1, 5, 0),
|
||||
make_fn_node("b", "f.rs", 6, 10, 1),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let reachable = analyzer.bfs_reachable(a, Direction::Outgoing);
|
||||
// Should handle cycle without infinite loop
|
||||
assert_eq!(reachable.len(), 1);
|
||||
assert!(reachable.contains(&b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_path_exists() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["c"], 10);
|
||||
assert!(path.is_some());
|
||||
let names = path.unwrap();
|
||||
assert_eq!(names, vec!["a", "b", "c"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_path_direct() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["b"], 10);
|
||||
assert!(path.is_some());
|
||||
let names = path.unwrap();
|
||||
assert_eq!(names, vec!["a", "b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_path_same_node() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["a"], 10);
|
||||
assert!(path.is_some());
|
||||
let names = path.unwrap();
|
||||
assert_eq!(names, vec!["a"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_path_no_connection() {
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
// No edge between a and b
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map: [("a".to_string(), a), ("b".to_string(), b)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
nodes: vec![
|
||||
make_fn_node("a", "f.rs", 1, 5, 0),
|
||||
make_fn_node("b", "f.rs", 6, 10, 1),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let path = analyzer.find_path(a, b, 10);
|
||||
assert!(path.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_path_depth_limited() {
|
||||
// Build a long chain: a -> b -> c -> d -> e
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
let d = graph.add_node("d".to_string());
|
||||
let e = graph.add_node("e".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, c, CodeEdgeKind::Calls);
|
||||
graph.add_edge(c, d, CodeEdgeKind::Calls);
|
||||
graph.add_edge(d, e, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
node_map.insert("d".to_string(), d);
|
||||
node_map.insert("e".to_string(), e);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![
|
||||
make_fn_node("a", "f.rs", 1, 2, 0),
|
||||
make_fn_node("b", "f.rs", 3, 4, 1),
|
||||
make_fn_node("c", "f.rs", 5, 6, 2),
|
||||
make_fn_node("d", "f.rs", 7, 8, 3),
|
||||
make_fn_node("e", "f.rs", 9, 10, 4),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
// Depth 3 won't reach e from a (path length 5)
|
||||
let path = analyzer.find_path(a, e, 3);
|
||||
assert!(path.is_none());
|
||||
// Depth 5 should reach
|
||||
let path = analyzer.find_path(a, e, 5);
|
||||
assert!(path.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_node_at_location_exact_line() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
// Node "b" is at lines 7-12
|
||||
let result = analyzer.find_node_at_location("src/main.rs", Some(9));
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap(), cg.node_map["b"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_node_at_location_narrowest_match() {
|
||||
// Outer function 1-20, inner nested 5-10
|
||||
let mut graph = DiGraph::new();
|
||||
let outer = graph.add_node("outer".to_string());
|
||||
let inner = graph.add_node("inner".to_string());
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map: [("outer".to_string(), outer), ("inner".to_string(), inner)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
nodes: vec![
|
||||
make_fn_node("outer", "src/main.rs", 1, 20, 0),
|
||||
make_fn_node("inner", "src/main.rs", 5, 10, 1),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
// Line 7 is inside both, but inner is narrower
|
||||
let result = analyzer.find_node_at_location("src/main.rs", Some(7));
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap(), inner);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_node_at_location_no_line_returns_file_node() {
|
||||
let mut graph = DiGraph::new();
|
||||
let file_node = graph.add_node("src/main.rs".to_string());
|
||||
let fn_node = graph.add_node("src/main.rs::foo".to_string());
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map: [
|
||||
("src/main.rs".to_string(), file_node),
|
||||
("src/main.rs::foo".to_string(), fn_node),
|
||||
]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
nodes: vec![
|
||||
make_node(
|
||||
"src/main.rs",
|
||||
"src/main.rs",
|
||||
1,
|
||||
100,
|
||||
0,
|
||||
false,
|
||||
CodeNodeKind::File,
|
||||
),
|
||||
make_fn_node("src/main.rs::foo", "src/main.rs", 5, 10, 1),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let result = analyzer.find_node_at_location("src/main.rs", None);
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap(), file_node);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_node_at_location_wrong_file() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let result = analyzer.find_node_at_location("nonexistent.rs", Some(5));
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_node_at_location_line_out_of_range() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let result = analyzer.find_node_at_location("src/main.rs", Some(999));
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyze_basic() {
|
||||
// A (entry) -> B -> C
|
||||
let mut graph = DiGraph::new();
|
||||
let a = graph.add_node("a".to_string());
|
||||
let b = graph.add_node("b".to_string());
|
||||
let c = graph.add_node("c".to_string());
|
||||
graph.add_edge(a, b, CodeEdgeKind::Calls);
|
||||
graph.add_edge(b, c, CodeEdgeKind::Calls);
|
||||
|
||||
let mut node_map = HashMap::new();
|
||||
node_map.insert("a".to_string(), a);
|
||||
node_map.insert("b".to_string(), b);
|
||||
node_map.insert("c".to_string(), c);
|
||||
|
||||
let cg = CodeGraph {
|
||||
graph,
|
||||
node_map,
|
||||
nodes: vec![
|
||||
make_node("a", "src/main.rs", 1, 5, 0, true, CodeNodeKind::Function),
|
||||
make_fn_node("b", "src/main.rs", 7, 12, 1),
|
||||
make_fn_node("c", "src/main.rs", 14, 20, 2),
|
||||
],
|
||||
edges: Vec::new(),
|
||||
};
|
||||
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let result = analyzer.analyze("repo1", "finding1", "build1", "src/main.rs", Some(9));
|
||||
// B's blast radius: C is reachable forward
|
||||
assert_eq!(result.blast_radius, 1);
|
||||
// B has A as direct caller
|
||||
assert_eq!(result.direct_callers, vec!["a"]);
|
||||
// B calls C
|
||||
assert_eq!(result.direct_callees, vec!["c"]);
|
||||
// A is an entry point that reaches B
|
||||
assert_eq!(result.affected_entry_points, vec!["a"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_analyze_no_matching_node() {
|
||||
let cg = build_linear_graph();
|
||||
let analyzer = ImpactAnalyzer::new(&cg);
|
||||
let result = analyzer.analyze("repo1", "f1", "b1", "nonexistent.rs", Some(1));
|
||||
assert_eq!(result.blast_radius, 0);
|
||||
assert!(result.affected_entry_points.is_empty());
|
||||
assert!(result.direct_callers.is_empty());
|
||||
assert!(result.direct_callees.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,3 +184,115 @@ impl Default for ParserRegistry {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[test]
|
||||
fn test_supports_rust_extension() {
|
||||
let registry = ParserRegistry::new();
|
||||
assert!(registry.supports_extension("rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_python_extension() {
|
||||
let registry = ParserRegistry::new();
|
||||
assert!(registry.supports_extension("py"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_javascript_extension() {
|
||||
let registry = ParserRegistry::new();
|
||||
assert!(registry.supports_extension("js"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supports_typescript_extension() {
|
||||
let registry = ParserRegistry::new();
|
||||
assert!(registry.supports_extension("ts"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_does_not_support_unknown_extension() {
|
||||
let registry = ParserRegistry::new();
|
||||
assert!(!registry.supports_extension("go"));
|
||||
assert!(!registry.supports_extension("java"));
|
||||
assert!(!registry.supports_extension("cpp"));
|
||||
assert!(!registry.supports_extension(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supported_extensions_includes_all() {
|
||||
let registry = ParserRegistry::new();
|
||||
let exts = registry.supported_extensions();
|
||||
assert!(exts.contains(&"rs"));
|
||||
assert!(exts.contains(&"py"));
|
||||
assert!(exts.contains(&"js"));
|
||||
assert!(exts.contains(&"ts"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_supported_extensions_count() {
|
||||
let registry = ParserRegistry::new();
|
||||
let exts = registry.supported_extensions();
|
||||
// At least 4 extensions (rs, py, js, ts); could be more if tsx, jsx etc.
|
||||
assert!(exts.len() >= 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_returns_none_for_unsupported() {
|
||||
let registry = ParserRegistry::new();
|
||||
let path = PathBuf::from("test.go");
|
||||
let result = registry.parse_file(&path, "package main", "repo1", "build1");
|
||||
assert!(result.is_ok());
|
||||
assert!(result.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_rust_source() {
|
||||
let registry = ParserRegistry::new();
|
||||
let path = PathBuf::from("src/main.rs");
|
||||
let source = "fn main() {\n println!(\"hello\");\n}\n";
|
||||
let result = registry.parse_file(&path, source, "repo1", "build1");
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
let output = output.unwrap();
|
||||
// Should have at least the file node and the main function node
|
||||
assert!(output.nodes.len() >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_python_source() {
|
||||
let registry = ParserRegistry::new();
|
||||
let path = PathBuf::from("app.py");
|
||||
let source = "def hello():\n print('hi')\n";
|
||||
let result = registry.parse_file(&path, source, "repo1", "build1");
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
let output = output.unwrap();
|
||||
assert!(!output.nodes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_empty_source() {
|
||||
let registry = ParserRegistry::new();
|
||||
let path = PathBuf::from("empty.rs");
|
||||
let result = registry.parse_file(&path, "", "repo1", "build1");
|
||||
assert!(result.is_ok());
|
||||
let output = result.unwrap();
|
||||
assert!(output.is_some());
|
||||
// At minimum the file node
|
||||
let output = output.unwrap();
|
||||
assert!(!output.nodes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_trait() {
|
||||
let registry = ParserRegistry::default();
|
||||
assert!(registry.supports_extension("rs"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,6 +363,214 @@ impl RustParser {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::traits::graph_builder::LanguageParser;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn parse_rust(source: &str) -> ParseOutput {
|
||||
let parser = RustParser::new();
|
||||
parser
|
||||
.parse_file(&PathBuf::from("test.rs"), source, "repo1", "build1")
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_use_path_simple() {
|
||||
let parser = RustParser::new();
|
||||
assert_eq!(
|
||||
parser.extract_use_path("use std::collections::HashMap;"),
|
||||
Some("std::collections::HashMap".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_use_path_nested() {
|
||||
let parser = RustParser::new();
|
||||
assert_eq!(
|
||||
parser.extract_use_path("use crate::models::graph::CodeNode;"),
|
||||
Some("crate::models::graph::CodeNode".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_use_path_no_prefix() {
|
||||
let parser = RustParser::new();
|
||||
assert_eq!(parser.extract_use_path("let x = 5;"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_use_path_empty() {
|
||||
let parser = RustParser::new();
|
||||
assert_eq!(parser.extract_use_path(""), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_function() {
|
||||
let output = parse_rust("fn hello() {\n let x = 1;\n}\n");
|
||||
let fn_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Function)
|
||||
.collect();
|
||||
assert_eq!(fn_nodes.len(), 1);
|
||||
assert_eq!(fn_nodes[0].name, "hello");
|
||||
assert!(fn_nodes[0].qualified_name.contains("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_struct() {
|
||||
let output = parse_rust("struct Foo {\n x: i32,\n}\n");
|
||||
let struct_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Struct)
|
||||
.collect();
|
||||
assert_eq!(struct_nodes.len(), 1);
|
||||
assert_eq!(struct_nodes[0].name, "Foo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_enum() {
|
||||
let output = parse_rust("enum Color {\n Red,\n Blue,\n}\n");
|
||||
let enum_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Enum)
|
||||
.collect();
|
||||
assert_eq!(enum_nodes.len(), 1);
|
||||
assert_eq!(enum_nodes[0].name, "Color");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_trait() {
|
||||
let output = parse_rust("trait Drawable {\n fn draw(&self);\n}\n");
|
||||
let trait_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Trait)
|
||||
.collect();
|
||||
assert_eq!(trait_nodes.len(), 1);
|
||||
assert_eq!(trait_nodes[0].name, "Drawable");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_file_node_always_created() {
|
||||
let output = parse_rust("");
|
||||
let file_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::File)
|
||||
.collect();
|
||||
assert_eq!(file_nodes.len(), 1);
|
||||
assert_eq!(file_nodes[0].language, "rust");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_multiple_functions() {
|
||||
let source = "fn foo() {}\nfn bar() {}\nfn baz() {}\n";
|
||||
let output = parse_rust(source);
|
||||
let fn_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Function)
|
||||
.collect();
|
||||
assert_eq!(fn_nodes.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_main_is_entry_point() {
|
||||
let output = parse_rust("fn main() {\n println!(\"hi\");\n}\n");
|
||||
let main_node = output.nodes.iter().find(|n| n.name == "main").unwrap();
|
||||
assert!(main_node.is_entry_point);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pub_fn_is_entry_point() {
|
||||
let output = parse_rust("pub fn handler() {}\n");
|
||||
let node = output.nodes.iter().find(|n| n.name == "handler").unwrap();
|
||||
assert!(node.is_entry_point);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_private_fn_is_not_entry_point() {
|
||||
let output = parse_rust("fn helper() {}\n");
|
||||
let node = output.nodes.iter().find(|n| n.name == "helper").unwrap();
|
||||
assert!(!node.is_entry_point);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_function_calls_create_edges() {
|
||||
let source = "fn caller() {\n callee();\n}\nfn callee() {}\n";
|
||||
let output = parse_rust(source);
|
||||
let call_edges: Vec<_> = output
|
||||
.edges
|
||||
.iter()
|
||||
.filter(|e| e.kind == CodeEdgeKind::Calls)
|
||||
.collect();
|
||||
assert!(!call_edges.is_empty());
|
||||
assert!(call_edges.iter().any(|e| e.target.contains("callee")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_use_declaration_creates_import_edge() {
|
||||
let source = "use std::collections::HashMap;\nfn foo() {}\n";
|
||||
let output = parse_rust(source);
|
||||
let import_edges: Vec<_> = output
|
||||
.edges
|
||||
.iter()
|
||||
.filter(|e| e.kind == CodeEdgeKind::Imports)
|
||||
.collect();
|
||||
assert!(!import_edges.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_impl_methods() {
|
||||
let source = "struct Foo {}\nimpl Foo {\n fn do_thing(&self) {}\n}\n";
|
||||
let output = parse_rust(source);
|
||||
let fn_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Function)
|
||||
.collect();
|
||||
assert_eq!(fn_nodes.len(), 1);
|
||||
assert_eq!(fn_nodes[0].name, "do_thing");
|
||||
// Method should be qualified under the impl type
|
||||
assert!(fn_nodes[0].qualified_name.contains("Foo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mod_item() {
|
||||
let source = "mod inner {\n fn nested() {}\n}\n";
|
||||
let output = parse_rust(source);
|
||||
let mod_nodes: Vec<_> = output
|
||||
.nodes
|
||||
.iter()
|
||||
.filter(|n| n.kind == CodeNodeKind::Module)
|
||||
.collect();
|
||||
assert_eq!(mod_nodes.len(), 1);
|
||||
assert_eq!(mod_nodes[0].name, "inner");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_numbers() {
|
||||
let source = "fn first() {}\n\n\nfn second() {}\n";
|
||||
let output = parse_rust(source);
|
||||
let first = output.nodes.iter().find(|n| n.name == "first").unwrap();
|
||||
let second = output.nodes.iter().find(|n| n.name == "second").unwrap();
|
||||
assert_eq!(first.start_line, 1);
|
||||
assert!(second.start_line > first.start_line);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_language_and_extensions() {
|
||||
let parser = RustParser::new();
|
||||
assert_eq!(parser.language(), "rust");
|
||||
assert_eq!(parser.extensions(), &["rs"]);
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageParser for RustParser {
|
||||
fn language(&self) -> &str {
|
||||
"rust"
|
||||
|
||||
@@ -128,3 +128,186 @@ impl SymbolIndex {
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use compliance_core::models::graph::CodeNodeKind;
|
||||
|
||||
fn make_node(
|
||||
qualified_name: &str,
|
||||
name: &str,
|
||||
kind: CodeNodeKind,
|
||||
file_path: &str,
|
||||
language: &str,
|
||||
) -> CodeNode {
|
||||
CodeNode {
|
||||
id: None,
|
||||
repo_id: "test".to_string(),
|
||||
graph_build_id: "build1".to_string(),
|
||||
qualified_name: qualified_name.to_string(),
|
||||
name: name.to_string(),
|
||||
kind,
|
||||
file_path: file_path.to_string(),
|
||||
start_line: 1,
|
||||
end_line: 10,
|
||||
language: language.to_string(),
|
||||
community_id: None,
|
||||
is_entry_point: false,
|
||||
graph_index: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_creates_index() {
|
||||
let index = SymbolIndex::new();
|
||||
assert!(index.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_empty_nodes() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let result = index.index_nodes(&[]);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_and_search_single_node() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![make_node(
|
||||
"src/main.rs::main",
|
||||
"main",
|
||||
CodeNodeKind::Function,
|
||||
"src/main.rs",
|
||||
"rust",
|
||||
)];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("main", 10).unwrap();
|
||||
assert!(!results.is_empty());
|
||||
assert_eq!(results[0].name, "main");
|
||||
assert_eq!(results[0].qualified_name, "src/main.rs::main");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_no_results() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![make_node(
|
||||
"src/main.rs::foo",
|
||||
"foo",
|
||||
CodeNodeKind::Function,
|
||||
"src/main.rs",
|
||||
"rust",
|
||||
)];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("zzzznonexistent", 10).unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_multiple_nodes() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![
|
||||
make_node(
|
||||
"a.rs::handle_request",
|
||||
"handle_request",
|
||||
CodeNodeKind::Function,
|
||||
"a.rs",
|
||||
"rust",
|
||||
),
|
||||
make_node(
|
||||
"b.rs::handle_response",
|
||||
"handle_response",
|
||||
CodeNodeKind::Function,
|
||||
"b.rs",
|
||||
"rust",
|
||||
),
|
||||
make_node(
|
||||
"c.rs::process_data",
|
||||
"process_data",
|
||||
CodeNodeKind::Function,
|
||||
"c.rs",
|
||||
"rust",
|
||||
),
|
||||
];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("handle", 10).unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_limit() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let mut nodes = Vec::new();
|
||||
for i in 0..20 {
|
||||
nodes.push(make_node(
|
||||
&format!("mod::func_{i}"),
|
||||
&format!("func_{i}"),
|
||||
CodeNodeKind::Function,
|
||||
"mod.rs",
|
||||
"rust",
|
||||
));
|
||||
}
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("func", 5).unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_result_has_score() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![make_node(
|
||||
"src/lib.rs::compute",
|
||||
"compute",
|
||||
CodeNodeKind::Function,
|
||||
"src/lib.rs",
|
||||
"rust",
|
||||
)];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("compute", 10).unwrap();
|
||||
assert!(!results.is_empty());
|
||||
assert!(results[0].score > 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_result_fields() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![make_node(
|
||||
"src/app.py::MyClass",
|
||||
"MyClass",
|
||||
CodeNodeKind::Class,
|
||||
"src/app.py",
|
||||
"python",
|
||||
)];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
let results = index.search("MyClass", 10).unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].name, "MyClass");
|
||||
assert_eq!(results[0].kind, "class");
|
||||
assert_eq!(results[0].file_path, "src/app.py");
|
||||
assert_eq!(results[0].language, "python");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_empty_query() {
|
||||
let index = SymbolIndex::new().unwrap();
|
||||
let nodes = vec![make_node(
|
||||
"src/lib.rs::foo",
|
||||
"foo",
|
||||
CodeNodeKind::Function,
|
||||
"src/lib.rs",
|
||||
"rust",
|
||||
)];
|
||||
index.index_nodes(&nodes).unwrap();
|
||||
|
||||
// Empty query may parse error or return empty - both acceptable
|
||||
let result = index.search("", 10);
|
||||
// Just verify it doesn't panic
|
||||
let _ = result;
|
||||
}
|
||||
}
|
||||
|
||||
4
compliance-graph/tests/parsers.rs
Normal file
4
compliance-graph/tests/parsers.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
// Tests for language parsers (Rust, TypeScript, JavaScript, Python).
|
||||
//
|
||||
// Test AST parsing, symbol extraction, and dependency graph construction
|
||||
// using fixture source files.
|
||||
@@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cap_limit_default() {
|
||||
assert_eq!(cap_limit(None), DEFAULT_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_high() {
|
||||
assert_eq!(cap_limit(Some(300)), MAX_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_low() {
|
||||
assert_eq!(cap_limit(Some(0)), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_dast_findings_params_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"target_id": "t1",
|
||||
"scan_run_id": "sr1",
|
||||
"severity": "critical",
|
||||
"exploitable": true,
|
||||
"vuln_type": "sql_injection",
|
||||
"limit": 10
|
||||
});
|
||||
let params: ListDastFindingsParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.target_id.as_deref(), Some("t1"));
|
||||
assert_eq!(params.scan_run_id.as_deref(), Some("sr1"));
|
||||
assert_eq!(params.severity.as_deref(), Some("critical"));
|
||||
assert_eq!(params.exploitable, Some(true));
|
||||
assert_eq!(params.vuln_type.as_deref(), Some("sql_injection"));
|
||||
assert_eq!(params.limit, Some(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_dast_findings_params_all_optional() {
|
||||
let params: ListDastFindingsParams = serde_json::from_value(serde_json::json!({})).unwrap();
|
||||
assert!(params.target_id.is_none());
|
||||
assert!(params.scan_run_id.is_none());
|
||||
assert!(params.severity.is_none());
|
||||
assert!(params.exploitable.is_none());
|
||||
assert!(params.vuln_type.is_none());
|
||||
assert!(params.limit.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dast_scan_summary_params_deserialize() {
|
||||
let params: DastScanSummaryParams =
|
||||
serde_json::from_value(serde_json::json!({ "target_id": "abc" })).unwrap();
|
||||
assert_eq!(params.target_id.as_deref(), Some("abc"));
|
||||
|
||||
let params2: DastScanSummaryParams = serde_json::from_value(serde_json::json!({})).unwrap();
|
||||
assert!(params2.target_id.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct ListDastFindingsParams {
|
||||
/// Filter by DAST target ID
|
||||
|
||||
@@ -12,6 +12,89 @@ fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cap_limit_default() {
|
||||
assert_eq!(cap_limit(None), DEFAULT_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_normal_value() {
|
||||
assert_eq!(cap_limit(Some(100)), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_exceeds_max() {
|
||||
assert_eq!(cap_limit(Some(500)), MAX_LIMIT);
|
||||
assert_eq!(cap_limit(Some(201)), MAX_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_zero_clamped_to_one() {
|
||||
assert_eq!(cap_limit(Some(0)), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_negative_clamped_to_one() {
|
||||
assert_eq!(cap_limit(Some(-10)), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_boundary_values() {
|
||||
assert_eq!(cap_limit(Some(1)), 1);
|
||||
assert_eq!(cap_limit(Some(MAX_LIMIT)), MAX_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_findings_params_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"repo_id": "abc",
|
||||
"severity": "high",
|
||||
"status": "open",
|
||||
"scan_type": "sast",
|
||||
"limit": 25
|
||||
});
|
||||
let params: ListFindingsParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.repo_id.as_deref(), Some("abc"));
|
||||
assert_eq!(params.severity.as_deref(), Some("high"));
|
||||
assert_eq!(params.status.as_deref(), Some("open"));
|
||||
assert_eq!(params.scan_type.as_deref(), Some("sast"));
|
||||
assert_eq!(params.limit, Some(25));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_findings_params_all_optional() {
|
||||
let json = serde_json::json!({});
|
||||
let params: ListFindingsParams = serde_json::from_value(json).unwrap();
|
||||
assert!(params.repo_id.is_none());
|
||||
assert!(params.severity.is_none());
|
||||
assert!(params.status.is_none());
|
||||
assert!(params.scan_type.is_none());
|
||||
assert!(params.limit.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_finding_params_deserialize() {
|
||||
let json = serde_json::json!({ "id": "507f1f77bcf86cd799439011" });
|
||||
let params: GetFindingParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.id, "507f1f77bcf86cd799439011");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn findings_summary_params_deserialize() {
|
||||
let json = serde_json::json!({ "repo_id": "r1" });
|
||||
let params: FindingsSummaryParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.repo_id.as_deref(), Some("r1"));
|
||||
|
||||
let json2 = serde_json::json!({});
|
||||
let params2: FindingsSummaryParams = serde_json::from_value(json2).unwrap();
|
||||
assert!(params2.repo_id.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct ListFindingsParams {
|
||||
/// Filter by repository ID
|
||||
|
||||
@@ -12,6 +12,90 @@ fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cap_limit_default() {
|
||||
assert_eq!(cap_limit(None), DEFAULT_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_high() {
|
||||
assert_eq!(cap_limit(Some(1000)), MAX_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_low() {
|
||||
assert_eq!(cap_limit(Some(-100)), 1);
|
||||
assert_eq!(cap_limit(Some(0)), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_normal() {
|
||||
assert_eq!(cap_limit(Some(42)), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_pentest_sessions_params_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"target_id": "tgt",
|
||||
"status": "running",
|
||||
"strategy": "aggressive",
|
||||
"limit": 5
|
||||
});
|
||||
let params: ListPentestSessionsParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.target_id.as_deref(), Some("tgt"));
|
||||
assert_eq!(params.status.as_deref(), Some("running"));
|
||||
assert_eq!(params.strategy.as_deref(), Some("aggressive"));
|
||||
assert_eq!(params.limit, Some(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_pentest_sessions_params_all_optional() {
|
||||
let params: ListPentestSessionsParams =
|
||||
serde_json::from_value(serde_json::json!({})).unwrap();
|
||||
assert!(params.target_id.is_none());
|
||||
assert!(params.status.is_none());
|
||||
assert!(params.strategy.is_none());
|
||||
assert!(params.limit.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_pentest_session_params_deserialize() {
|
||||
let params: GetPentestSessionParams =
|
||||
serde_json::from_value(serde_json::json!({ "id": "abc123" })).unwrap();
|
||||
assert_eq!(params.id, "abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_attack_chain_params_deserialize() {
|
||||
let params: GetAttackChainParams =
|
||||
serde_json::from_value(serde_json::json!({ "session_id": "s1", "limit": 20 })).unwrap();
|
||||
assert_eq!(params.session_id, "s1");
|
||||
assert_eq!(params.limit, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_pentest_messages_params_deserialize() {
|
||||
let params: GetPentestMessagesParams =
|
||||
serde_json::from_value(serde_json::json!({ "session_id": "s2" })).unwrap();
|
||||
assert_eq!(params.session_id, "s2");
|
||||
assert!(params.limit.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pentest_stats_params_deserialize() {
|
||||
let params: PentestStatsParams =
|
||||
serde_json::from_value(serde_json::json!({ "target_id": "t1" })).unwrap();
|
||||
assert_eq!(params.target_id.as_deref(), Some("t1"));
|
||||
|
||||
let params2: PentestStatsParams = serde_json::from_value(serde_json::json!({})).unwrap();
|
||||
assert!(params2.target_id.is_none());
|
||||
}
|
||||
}
|
||||
|
||||
// ── List Pentest Sessions ──────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
|
||||
@@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 {
|
||||
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cap_limit_default() {
|
||||
assert_eq!(cap_limit(None), DEFAULT_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_high() {
|
||||
assert_eq!(cap_limit(Some(999)), MAX_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_clamps_low() {
|
||||
assert_eq!(cap_limit(Some(0)), 1);
|
||||
assert_eq!(cap_limit(Some(-5)), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_limit_normal() {
|
||||
assert_eq!(cap_limit(Some(75)), 75);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_sbom_params_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"repo_id": "repo1",
|
||||
"has_vulns": true,
|
||||
"package_manager": "npm",
|
||||
"license": "MIT",
|
||||
"limit": 30
|
||||
});
|
||||
let params: ListSbomPackagesParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.repo_id.as_deref(), Some("repo1"));
|
||||
assert_eq!(params.has_vulns, Some(true));
|
||||
assert_eq!(params.package_manager.as_deref(), Some("npm"));
|
||||
assert_eq!(params.license.as_deref(), Some("MIT"));
|
||||
assert_eq!(params.limit, Some(30));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_sbom_params_all_optional() {
|
||||
let params: ListSbomPackagesParams = serde_json::from_value(serde_json::json!({})).unwrap();
|
||||
assert!(params.repo_id.is_none());
|
||||
assert!(params.has_vulns.is_none());
|
||||
assert!(params.package_manager.is_none());
|
||||
assert!(params.license.is_none());
|
||||
assert!(params.limit.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sbom_vuln_report_params_deserialize() {
|
||||
let json = serde_json::json!({ "repo_id": "my-repo" });
|
||||
let params: SbomVulnReportParams = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(params.repo_id, "my-repo");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
pub struct ListSbomPackagesParams {
|
||||
/// Filter by repository ID
|
||||
|
||||
4
compliance-mcp/tests/tools.rs
Normal file
4
compliance-mcp/tests/tools.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
// Tests for MCP tool implementations.
|
||||
//
|
||||
// Test tool request/response formats, parameter validation,
|
||||
// and database query construction.
|
||||
16
fuzz/Cargo.toml
Normal file
16
fuzz/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "compliance-fuzz"
|
||||
version = "0.0.0"
|
||||
publish = false
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libfuzzer-sys = "0.4"
|
||||
compliance-core = { path = "../compliance-core" }
|
||||
|
||||
# Fuzz targets are defined below. Add new targets as [[bin]] entries.
|
||||
|
||||
[[bin]]
|
||||
name = "fuzz_finding_dedup"
|
||||
path = "fuzz_targets/fuzz_finding_dedup.rs"
|
||||
doc = false
|
||||
12
fuzz/fuzz_targets/fuzz_finding_dedup.rs
Normal file
12
fuzz/fuzz_targets/fuzz_finding_dedup.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
#![no_main]
|
||||
use libfuzzer_sys::fuzz_target;
|
||||
|
||||
// Example fuzz target stub for finding deduplication logic.
|
||||
// Replace with actual dedup function calls once ready.
|
||||
|
||||
fuzz_target!(|data: &[u8]| {
|
||||
if let Ok(s) = std::str::from_utf8(data) {
|
||||
// TODO: Call dedup/fingerprint functions with fuzzed input
|
||||
let _ = s.len();
|
||||
}
|
||||
});
|
||||
Reference in New Issue
Block a user