refactor: modularize codebase and add 404 unit tests (#13)
All checks were successful
CI / Format (push) Successful in 4s
CI / Clippy (push) Successful in 4m19s
CI / Security Audit (push) Successful in 1m44s
CI / Tests (push) Successful in 5m15s
CI / Detect Changes (push) Successful in 5s
CI / Deploy Agent (push) Successful in 2s
CI / Deploy Dashboard (push) Successful in 2s
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Successful in 2s

This commit was merged in pull request #13.
This commit is contained in:
2026-03-13 08:03:45 +00:00
parent acc5b86aa4
commit 3bb690e5bb
89 changed files with 11884 additions and 6046 deletions

View File

@@ -0,0 +1,481 @@
use compliance_core::models::TrackerType;
use serde::{Deserialize, Serialize};
use compliance_core::models::ScanRun;
/// Common query-string pagination parameters (`?page=&limit=`).
#[derive(Deserialize)]
pub struct PaginationParams {
/// 1-based page index; defaults to 1 when absent.
#[serde(default = "default_page")]
pub page: u64,
/// Maximum number of items returned per page; defaults to 50 when absent.
#[serde(default = "default_limit")]
pub limit: i64,
}
/// Serde default for `page`: pagination is 1-based, so start at page 1.
pub(crate) fn default_page() -> u64 {
1
}
/// Serde default for `limit`: 50 items per page.
pub(crate) fn default_limit() -> i64 {
50
}
/// Query-string filter for the findings list endpoint.
///
/// All filter fields are optional; absent fields apply no constraint.
#[derive(Deserialize)]
pub struct FindingsFilter {
/// Restrict to findings belonging to this repository id.
#[serde(default)]
pub repo_id: Option<String>,
/// Restrict to a severity level (e.g. "critical", "high").
#[serde(default)]
pub severity: Option<String>,
/// Restrict to a scan type (e.g. "sast").
#[serde(default)]
pub scan_type: Option<String>,
/// Restrict to a workflow status (e.g. "open", "resolved").
#[serde(default)]
pub status: Option<String>,
/// Free-text search term applied across several text fields.
#[serde(default)]
pub q: Option<String>,
/// Field name to sort by; the handler falls back to "created_at".
#[serde(default)]
pub sort_by: Option<String>,
/// "asc" for ascending; anything else sorts descending.
#[serde(default)]
pub sort_order: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
/// Generic JSON envelope returned by API endpoints.
///
/// `total` and `page` are emitted only by paginated list endpoints;
/// detail endpoints leave them `None`, and serde omits them entirely.
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
/// The response payload.
pub data: T,
/// Total number of matching records, for paginated responses.
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
/// The 1-based page this response covers, for paginated responses.
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
}
/// Aggregated dashboard counters returned by the stats overview endpoint.
#[derive(Serialize)]
pub struct OverviewStats {
pub total_repositories: u64,
pub total_findings: u64,
/// Finding counts broken down by severity level.
pub critical_findings: u64,
pub high_findings: u64,
pub medium_findings: u64,
pub low_findings: u64,
pub total_sbom_entries: u64,
pub total_cve_alerts: u64,
pub total_issues: u64,
/// The most recent scan runs (the handler fetches up to 10, newest first).
pub recent_scans: Vec<ScanRun>,
}
/// Request body for registering a new repository to be scanned.
#[derive(Deserialize)]
pub struct AddRepositoryRequest {
/// Display name for the repository.
pub name: String,
/// Clone URL of the repository.
pub git_url: String,
/// Branch to scan; defaults to "main" when omitted.
#[serde(default = "default_branch")]
pub default_branch: String,
/// Optional credentials for cloning private repositories.
pub auth_token: Option<String>,
pub auth_username: Option<String>,
/// Optional issue-tracker integration settings.
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
/// Optional scan schedule (format consumed elsewhere; presumably cron-like — TODO confirm).
pub scan_schedule: Option<String>,
}
/// Request body for updating an existing repository.
///
/// Every field is optional; absent fields are presumably left unchanged
/// by the handler (not visible here — confirm against the update handler).
#[derive(Deserialize)]
pub struct UpdateRepositoryRequest {
pub name: Option<String>,
pub default_branch: Option<String>,
pub auth_token: Option<String>,
pub auth_username: Option<String>,
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
pub scan_schedule: Option<String>,
}
/// Serde default for `AddRepositoryRequest::default_branch`: track "main"
/// unless the caller specifies otherwise.
fn default_branch() -> String {
    String::from("main")
}
/// Request body for changing a single finding's workflow status.
#[derive(Deserialize)]
pub struct UpdateStatusRequest {
/// The new status value, stored as-is.
pub status: String,
}
/// Request body for changing the status of several findings at once.
#[derive(Deserialize)]
pub struct BulkUpdateStatusRequest {
/// ObjectId strings of the findings to update.
pub ids: Vec<String>,
/// The new status applied to every listed finding.
pub status: String,
}
/// Request body for attaching developer feedback to a finding.
#[derive(Deserialize)]
pub struct UpdateFeedbackRequest {
pub feedback: String,
}
/// Query-string filter for the SBOM entries list endpoint.
#[derive(Deserialize)]
pub struct SbomFilter {
/// Restrict to entries belonging to this repository id.
#[serde(default)]
pub repo_id: Option<String>,
/// Restrict to a package manager (e.g. "npm", "cargo").
#[serde(default)]
pub package_manager: Option<String>,
/// Free-text search term.
#[serde(default)]
pub q: Option<String>,
/// When set, filter by whether the entry has known vulnerabilities.
#[serde(default)]
pub has_vulns: Option<bool>,
/// Restrict to a license identifier.
#[serde(default)]
pub license: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
/// Query parameters for exporting a repository's SBOM.
#[derive(Deserialize)]
pub struct SbomExportParams {
/// Repository whose SBOM should be exported (required).
pub repo_id: String,
/// Export format; defaults to "cyclonedx".
#[serde(default = "default_export_format")]
pub format: String,
}
/// Serde default for `SbomExportParams::format`: CycloneDX output.
fn default_export_format() -> String {
    String::from("cyclonedx")
}
/// Query parameters for diffing the SBOMs of two repositories.
#[derive(Deserialize)]
pub struct SbomDiffParams {
/// Left-hand repository id of the comparison.
pub repo_a: String,
/// Right-hand repository id of the comparison.
pub repo_b: String,
}
/// Per-license aggregation of SBOM packages.
#[derive(Serialize)]
pub struct LicenseSummary {
/// License identifier (e.g. "MIT").
pub license: String,
/// Number of packages under this license.
pub count: u64,
/// Whether the license is classified as copyleft.
pub is_copyleft: bool,
/// Names of the packages under this license.
pub packages: Vec<String>,
}
/// Result of comparing two repositories' SBOMs.
#[derive(Serialize)]
pub struct SbomDiffResult {
/// Packages present only in repository A.
pub only_in_a: Vec<SbomDiffEntry>,
/// Packages present only in repository B.
pub only_in_b: Vec<SbomDiffEntry>,
/// Packages present in both, but at different versions.
pub version_changed: Vec<SbomVersionDiff>,
/// Number of packages identical in both SBOMs.
pub common_count: u64,
}
/// A package that appears on only one side of an SBOM diff.
#[derive(Serialize)]
pub struct SbomDiffEntry {
pub name: String,
pub version: String,
pub package_manager: String,
}
/// A package present in both SBOMs at different versions.
#[derive(Serialize)]
pub struct SbomVersionDiff {
pub name: String,
pub package_manager: String,
/// Version found in repository A.
pub version_a: String,
/// Version found in repository B.
pub version_b: String,
}
/// Shared agent state, injected into handlers via axum's `Extension` layer.
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
/// Standard handler result: a JSON `ApiResponse` envelope or a bare status code.
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
/// Drain a MongoDB cursor into a `Vec`, logging and skipping any document
/// that fails to deserialize instead of aborting the whole read.
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
    cursor: mongodb::Cursor<T>,
) -> Vec<T> {
    use futures_util::StreamExt;
    cursor
        .filter_map(|document| async move {
            match document {
                Ok(value) => Some(value),
                Err(e) => {
                    tracing::warn!("Failed to deserialize document: {e}");
                    None
                }
            }
        })
        .collect::<Vec<_>>()
        .await
}
// Unit tests for the DTO defaults, deserialization behavior, and
// serialization envelopes. No database or network access required.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ── PaginationParams ─────────────────────────────────────────
#[test]
fn pagination_params_defaults() {
let p: PaginationParams = serde_json::from_str("{}").unwrap();
assert_eq!(p.page, 1);
assert_eq!(p.limit, 50);
}
#[test]
fn pagination_params_custom_values() {
let p: PaginationParams = serde_json::from_str(r#"{"page":3,"limit":10}"#).unwrap();
assert_eq!(p.page, 3);
assert_eq!(p.limit, 10);
}
#[test]
fn pagination_params_partial_override() {
let p: PaginationParams = serde_json::from_str(r#"{"page":5}"#).unwrap();
assert_eq!(p.page, 5);
assert_eq!(p.limit, 50);
}
// Note: page 0 is accepted by deserialization; the handlers clamp the
// skip computation with saturating_sub rather than rejecting it here.
#[test]
fn pagination_params_zero_page() {
let p: PaginationParams = serde_json::from_str(r#"{"page":0}"#).unwrap();
assert_eq!(p.page, 0);
}
// ── FindingsFilter ───────────────────────────────────────────
#[test]
fn findings_filter_all_defaults() {
let f: FindingsFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.severity.is_none());
assert!(f.scan_type.is_none());
assert!(f.status.is_none());
assert!(f.q.is_none());
assert!(f.sort_by.is_none());
assert!(f.sort_order.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn findings_filter_with_all_fields() {
let f: FindingsFilter = serde_json::from_str(
r#"{
"repo_id": "abc",
"severity": "high",
"scan_type": "sast",
"status": "open",
"q": "sql injection",
"sort_by": "severity",
"sort_order": "desc",
"page": 2,
"limit": 25
}"#,
)
.unwrap();
assert_eq!(f.repo_id.as_deref(), Some("abc"));
assert_eq!(f.severity.as_deref(), Some("high"));
assert_eq!(f.scan_type.as_deref(), Some("sast"));
assert_eq!(f.status.as_deref(), Some("open"));
assert_eq!(f.q.as_deref(), Some("sql injection"));
assert_eq!(f.sort_by.as_deref(), Some("severity"));
assert_eq!(f.sort_order.as_deref(), Some("desc"));
assert_eq!(f.page, 2);
assert_eq!(f.limit, 25);
}
// Empty strings deserialize to Some("") rather than None — handlers must
// treat them accordingly (list_findings only skips text search for "").
#[test]
fn findings_filter_empty_string_fields() {
let f: FindingsFilter = serde_json::from_str(r#"{"repo_id":"","severity":""}"#).unwrap();
assert_eq!(f.repo_id.as_deref(), Some(""));
assert_eq!(f.severity.as_deref(), Some(""));
}
// ── ApiResponse ──────────────────────────────────────────────
#[test]
fn api_response_serializes_with_all_fields() {
let resp = ApiResponse {
data: vec!["a", "b"],
total: Some(100),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], json!(["a", "b"]));
assert_eq!(v["total"], 100);
assert_eq!(v["page"], 1);
}
// `skip_serializing_if` must drop the keys entirely, not emit null.
#[test]
fn api_response_skips_none_fields() {
let resp = ApiResponse {
data: "hello",
total: None,
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], "hello");
assert!(v.get("total").is_none());
assert!(v.get("page").is_none());
}
#[test]
fn api_response_with_nested_struct() {
#[derive(Serialize)]
struct Item {
id: u32,
}
let resp = ApiResponse {
data: Item { id: 42 },
total: Some(1),
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"]["id"], 42);
assert_eq!(v["total"], 1);
assert!(v.get("page").is_none());
}
#[test]
fn api_response_empty_vec() {
let resp: ApiResponse<Vec<String>> = ApiResponse {
data: vec![],
total: Some(0),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert!(v["data"].as_array().unwrap().is_empty());
}
// ── SbomFilter ───────────────────────────────────────────────
#[test]
fn sbom_filter_defaults() {
let f: SbomFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.package_manager.is_none());
assert!(f.q.is_none());
assert!(f.has_vulns.is_none());
assert!(f.license.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn sbom_filter_has_vulns_bool() {
let f: SbomFilter = serde_json::from_str(r#"{"has_vulns": true}"#).unwrap();
assert_eq!(f.has_vulns, Some(true));
}
// ── SbomExportParams ─────────────────────────────────────────
#[test]
fn sbom_export_params_default_format() {
let p: SbomExportParams = serde_json::from_str(r#"{"repo_id":"r1"}"#).unwrap();
assert_eq!(p.repo_id, "r1");
assert_eq!(p.format, "cyclonedx");
}
#[test]
fn sbom_export_params_custom_format() {
let p: SbomExportParams =
serde_json::from_str(r#"{"repo_id":"r1","format":"spdx"}"#).unwrap();
assert_eq!(p.format, "spdx");
}
// ── AddRepositoryRequest ─────────────────────────────────────
#[test]
fn add_repository_request_defaults() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "my-repo",
"git_url": "https://github.com/x/y.git"
}"#,
)
.unwrap();
assert_eq!(r.name, "my-repo");
assert_eq!(r.git_url, "https://github.com/x/y.git");
assert_eq!(r.default_branch, "main");
assert!(r.auth_token.is_none());
assert!(r.tracker_type.is_none());
assert!(r.scan_schedule.is_none());
}
#[test]
fn add_repository_request_custom_branch() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "repo",
"git_url": "url",
"default_branch": "develop"
}"#,
)
.unwrap();
assert_eq!(r.default_branch, "develop");
}
// ── UpdateStatusRequest / BulkUpdateStatusRequest ────────────
#[test]
fn update_status_request() {
let r: UpdateStatusRequest = serde_json::from_str(r#"{"status":"resolved"}"#).unwrap();
assert_eq!(r.status, "resolved");
}
#[test]
fn bulk_update_status_request() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":["a","b"],"status":"dismissed"}"#).unwrap();
assert_eq!(r.ids, vec!["a", "b"]);
assert_eq!(r.status, "dismissed");
}
// An empty id list deserializes fine; the bulk handler rejects it with 400.
#[test]
fn bulk_update_status_empty_ids() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":[],"status":"x"}"#).unwrap();
assert!(r.ids.is_empty());
}
// ── SbomDiffResult serialization ─────────────────────────────
#[test]
fn sbom_diff_result_serializes() {
let r = SbomDiffResult {
only_in_a: vec![SbomDiffEntry {
name: "pkg-a".to_string(),
version: "1.0".to_string(),
package_manager: "npm".to_string(),
}],
only_in_b: vec![],
version_changed: vec![SbomVersionDiff {
name: "shared".to_string(),
package_manager: "cargo".to_string(),
version_a: "0.1".to_string(),
version_b: "0.2".to_string(),
}],
common_count: 10,
};
let v = serde_json::to_value(&r).unwrap();
assert_eq!(v["only_in_a"].as_array().unwrap().len(), 1);
assert_eq!(v["only_in_b"].as_array().unwrap().len(), 0);
assert_eq!(v["version_changed"][0]["version_a"], "0.1");
assert_eq!(v["common_count"], 10);
}
// ── LicenseSummary ───────────────────────────────────────────
#[test]
fn license_summary_serializes() {
let ls = LicenseSummary {
license: "MIT".to_string(),
count: 42,
is_copyleft: false,
packages: vec!["serde".to_string()],
};
let v = serde_json::to_value(&ls).unwrap();
assert_eq!(v["license"], "MIT");
assert_eq!(v["is_copyleft"], false);
assert_eq!(v["count"], 42);
}
// ── Default helper functions ─────────────────────────────────
#[test]
fn default_page_returns_1() {
assert_eq!(default_page(), 1);
}
#[test]
fn default_limit_returns_50() {
assert_eq!(default_limit(), 50);
}
}

View File

@@ -0,0 +1,172 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::Finding;
/// GET findings with optional filtering, free-text search, dynamic sort,
/// and pagination.
///
/// Returns an `ApiResponse` whose `total` reflects the filtered count and
/// whose `page` echoes the requested page. Database errors while fetching
/// are logged and degrade to an empty result list rather than a 500.
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings(
Extension(agent): AgentExt,
Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> {
let db = &agent.db;
// Build the Mongo filter incrementally from whichever params are present.
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(severity) = &filter.severity {
query.insert("severity", severity);
}
if let Some(scan_type) = &filter.scan_type {
query.insert("scan_type", scan_type);
}
if let Some(status) = &filter.status {
query.insert("status", status);
}
// Text search across title, description, file_path, rule_id
if let Some(q) = &filter.q {
if !q.is_empty() {
// Case-insensitive regex match; NOTE(review): `q` is used verbatim as a
// regex pattern, so regex metacharacters in the search term are
// interpreted, not escaped — confirm this is intended.
let regex = doc! { "$regex": q, "$options": "i" };
query.insert(
"$or",
mongodb::bson::bson!([
{ "title": regex.clone() },
{ "description": regex.clone() },
{ "file_path": regex.clone() },
{ "rule_id": regex },
]),
);
}
}
// Dynamic sort
// NOTE(review): sort_by is user-controlled and used directly as the sort
// key; Mongo tolerates unknown fields (no error), but consider a whitelist.
let sort_field = filter.sort_by.as_deref().unwrap_or("created_at");
let sort_dir: i32 = match filter.sort_order.as_deref() {
Some("asc") => 1,
_ => -1,
};
let sort_doc = doc! { sort_field: sort_dir };
// saturating_sub keeps page=0 from underflowing; limit is cast u64 for skip math.
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.findings()
.count_documents(query.clone())
.await
.unwrap_or(0);
let findings = match db
.findings()
.find(query)
.sort(sort_doc)
.skip(skip)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(filter.page),
}))
}
/// Fetch a single finding by its MongoDB ObjectId.
///
/// Returns 400 for a malformed id, 404 when no document matches,
/// and 500 on a database error.
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
    let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(&id) else {
        return Err(StatusCode::BAD_REQUEST);
    };
    let lookup = agent.db.findings().find_one(doc! { "_id": oid }).await;
    let finding = match lookup {
        Ok(Some(found)) => found,
        Ok(None) => return Err(StatusCode::NOT_FOUND),
        Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
    };
    Ok(Json(ApiResponse {
        data: finding,
        total: None,
        page: None,
    }))
}
/// Update a single finding's workflow status and bump its `updated_at`.
///
/// Returns 400 for a malformed id, 500 on a database error. Note that a
/// well-formed id that matches no document still returns success — the
/// update simply modifies zero documents.
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}
/// Update the status of many findings at once.
///
/// Ids that fail ObjectId parsing are silently dropped; the request is
/// rejected with 400 only when *no* id parses. The response reports how
/// many documents were actually modified so callers can detect partial
/// application.
#[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status(
Extension(agent): AgentExt,
Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req
.ids
.iter()
.filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok())
.collect();
if oids.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
let result = agent
.db
.findings()
.update_many(
doc! { "_id": { "$in": oids } },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(
serde_json::json!({ "status": "updated", "modified_count": result.modified_count }),
))
}
/// Attach developer feedback text to a finding and bump its `updated_at`.
///
/// Returns 400 for a malformed id, 500 on a database error. As with
/// `update_finding_status`, an id that matches no document still succeeds.
#[tracing::instrument(skip_all)]
pub async fn update_finding_feedback(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}

View File

@@ -0,0 +1,84 @@
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
/// Liveness probe: always reports `{"status":"ok"}` with no dependency checks.
#[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "status": "ok" });
    Json(body)
}
/// Build the dashboard overview: collection-wide counters, per-severity
/// finding counts, and the 10 most recent scan runs (newest first).
///
/// Every count degrades to 0 on a database error rather than failing the
/// whole request; the recent-scans fetch likewise degrades to empty.
/// NOTE(review): the counts are issued sequentially — they could run
/// concurrently, but that is a behavioral/perf change left to a profile.
#[tracing::instrument(skip_all)]
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
let db = &agent.db;
let total_repositories = db
.repositories()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0);
// Severity breakdown — one count query per level.
let critical_findings = db
.findings()
.count_documents(doc! { "severity": "critical" })
.await
.unwrap_or(0);
let high_findings = db
.findings()
.count_documents(doc! { "severity": "high" })
.await
.unwrap_or(0);
let medium_findings = db
.findings()
.count_documents(doc! { "severity": "medium" })
.await
.unwrap_or(0);
let low_findings = db
.findings()
.count_documents(doc! { "severity": "low" })
.await
.unwrap_or(0);
let total_sbom_entries = db
.sbom_entries()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0);
let total_issues = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
// Latest 10 scan runs, most recently started first.
let recent_scans: Vec<ScanRun> = match db
.scan_runs()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.limit(10)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch recent scans: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: OverviewStats {
total_repositories,
total_findings,
critical_findings,
high_findings,
medium_findings,
low_findings,
total_sbom_entries,
total_cve_alerts,
total_issues,
recent_scans,
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,41 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::TrackerIssue;
/// List tracker issues, newest first, with pagination.
///
/// `total` reflects the whole collection; a fetch error degrades to an
/// empty page (logged) rather than a 500.
#[tracing::instrument(skip_all)]
pub async fn list_issues(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> {
let db = &agent.db;
// saturating_sub keeps page=0 from underflowing the skip computation.
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
let issues = match db
.tracker_issues()
.find(doc! {})
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch tracker issues: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: issues,
total: Some(total),
page: Some(params.page),
}))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,131 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
/// Shared agent state injected by axum's `Extension` layer.
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// Request body for exporting an encrypted pentest report.
#[derive(Deserialize)]
pub struct ExportBody {
/// Password used to encrypt the report archive; must be at least 8 characters.
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
///
/// Gathers the session, its target metadata, attack-chain nodes, and DAST
/// findings, then builds a password-encrypted report archive. The archive is
/// returned inline as a JSON body containing the base64 payload, its SHA-256
/// digest, and a suggested filename.
///
/// Errors: 400 for a malformed id or a password shorter than 8 characters,
/// 404 when the session does not exist, 500 on database or report-generation
/// failure. Missing target/nodes/findings degrade to placeholders or empty
/// lists rather than failing the export.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
// Reject weak encryption passwords up front, before touching the database.
if body.password.len() < 8 {
return Err((
StatusCode::BAD_REQUEST,
"Password must be at least 8 characters".to_string(),
));
}
// Fetch session
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
// Resolve target name
// Lookup errors and unparsable target ids both degrade to None here.
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
.flatten()
} else {
None
};
let target_name = target
.as_ref()
.map(|t| t.name.clone())
.unwrap_or_else(|| "Unknown Target".to_string());
let target_url = target
.as_ref()
.map(|t| t.base_url.clone())
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch DAST findings for this session
let findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Assemble the report context; requester fields come from the request body.
let ctx = crate::pentest::report::ReportContext {
session,
target_name,
target_url,
findings,
attack_chain: nodes,
requester_name: if body.requester_name.is_empty() {
"Unknown".to_string()
} else {
body.requester_name
},
requester_email: body.requester_email,
};
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
// Ship the encrypted archive inline as base64 with its integrity digest.
let response = serde_json::json!({
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
"sha256": report.sha256,
"filename": format!("pentest-report-{id}.zip"),
});
Ok(Json(response).into_response())
}

View File

@@ -0,0 +1,9 @@
//! Pentest HTTP handler modules, split by concern and re-exported as a
//! single flat namespace for router registration.
mod export;
mod session;
mod stats;
mod stream;
pub use export::*;
pub use session::*;
pub use stats::*;
pub use stream::*;

View File

@@ -2,20 +2,16 @@ use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::response::IntoResponse;
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -160,8 +156,7 @@ pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
@@ -210,13 +205,12 @@ pub async fn send_message(
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
@@ -261,106 +255,6 @@ pub async fn send_message(
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session(
@@ -375,7 +269,12 @@ pub async fn stop_session(
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running {
@@ -397,15 +296,30 @@ pub async fn stop_session(
}},
)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?;
let updated = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found after update".to_string()))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Session not found after update".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: updated,
@@ -420,9 +334,7 @@ pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
@@ -453,8 +365,7 @@ pub async fn get_messages(
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -487,95 +398,14 @@ pub async fn get_messages(
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -607,112 +437,3 @@ pub async fn get_session_findings(
page: Some(params.page),
}))
}
#[derive(Deserialize)]
pub struct ExportBody {
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
///
/// Validates the archive password, loads the session plus its target, attack
/// chain and DAST findings, renders an encrypted report and returns it as a
/// base64-encoded ZIP alongside its SHA-256 digest and a suggested filename.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
    Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
    let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
        .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
    // Reject weak archive passwords before doing any database work.
    if body.password.len() < 8 {
        return Err((
            StatusCode::BAD_REQUEST,
            "Password must be at least 8 characters".to_string(),
        ));
    }
    // Fetch session
    let session = agent
        .db
        .pentest_sessions()
        .find_one(doc! { "_id": oid })
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
        .ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
    // Resolve target name. A missing or unparsable target id degrades to
    // placeholder values below instead of failing the export.
    let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
        agent
            .db
            .dast_targets()
            .find_one(doc! { "_id": tid })
            .await
            .ok()
            .flatten()
    } else {
        None
    };
    let target_name = target
        .as_ref()
        .map(|t| t.name.clone())
        .unwrap_or_else(|| "Unknown Target".to_string());
    let target_url = target
        .as_ref()
        .map(|t| t.base_url.clone())
        .unwrap_or_default();
    // Fetch attack chain nodes (chronological order); query failures fall
    // back to an empty chain rather than aborting the export.
    let nodes: Vec<AttackChainNode> = match agent
        .db
        .attack_chain_nodes()
        .find(doc! { "session_id": &id })
        .sort(doc! { "started_at": 1 })
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(_) => Vec::new(),
    };
    // Fetch DAST findings for this session, most severe and newest first.
    let findings: Vec<DastFinding> = match agent
        .db
        .dast_findings()
        .find(doc! { "session_id": &id })
        .sort(doc! { "severity": -1, "created_at": -1 })
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(_) => Vec::new(),
    };
    // Assemble everything the report generator needs.
    let ctx = crate::pentest::report::ReportContext {
        session,
        target_name,
        target_url,
        findings,
        attack_chain: nodes,
        requester_name: if body.requester_name.is_empty() {
            "Unknown".to_string()
        } else {
            body.requester_name
        },
        requester_email: body.requester_email,
    };
    let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
    // The archive is returned inline as base64 JSON, not as a file download.
    let response = serde_json::json!({
        "archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
        "sha256": report.sha256,
        "filename": format!("pentest-report-{id}.zip"),
    });
    Ok(Json(response).into_response())
}

View File

@@ -0,0 +1,102 @@
use std::sync::Arc;
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, ApiResponse};
/// Shorthand for the shared agent state pulled from request extensions.
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
///
/// Counts running sessions, session-scoped DAST findings (total and per
/// severity) and tool invocation success across all sessions. Count failures
/// degrade to zero so the endpoint never errors on a flaky query.
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
    Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
    let db = &agent.db;
    // Filter shared by every DAST-finding count below: only findings that
    // belong to a pentest session (session_id present and non-null).
    let session_scoped = doc! { "session_id": { "$exists": true, "$ne": null } };
    let running_sessions = db
        .pentest_sessions()
        .count_documents(doc! { "status": "running" })
        .await
        .unwrap_or(0) as u32;
    // Count DAST findings from pentest sessions
    let total_vulnerabilities = db
        .dast_findings()
        .count_documents(session_scoped.clone())
        .await
        .unwrap_or(0) as u32;
    // Aggregate tool invocations from all sessions
    let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(_) => Vec::new(),
    };
    let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
    let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
    // Report 100% when nothing ran yet rather than dividing by zero.
    let tool_success_rate = if total_tool_invocations == 0 {
        100.0
    } else {
        (total_successes as f64 / total_tool_invocations as f64) * 100.0
    };
    // Severity distribution from pentest-related DAST findings: one count
    // per severity level, all sharing the session-scoped filter.
    let mut by_severity = [0u32; 5];
    for (slot, severity) in by_severity
        .iter_mut()
        .zip(["critical", "high", "medium", "low", "info"])
    {
        let mut filter = session_scoped.clone();
        filter.insert("severity", severity);
        *slot = db
            .dast_findings()
            .count_documents(filter)
            .await
            .unwrap_or(0) as u32;
    }
    let [critical, high, medium, low, info] = by_severity;
    Ok(Json(ApiResponse {
        data: PentestStats {
            running_sessions,
            total_vulnerabilities,
            total_tool_invocations,
            tool_success_rate,
            severity_distribution: SeverityDistribution {
                critical,
                high,
                medium,
                low,
                info,
            },
        },
        total: None,
        page: None,
    }))
}

View File

@@ -0,0 +1,116 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use futures_util::stream;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
/// Shorthand for the shared agent state pulled from request extensions.
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
) -> Result<
    Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>,
    StatusCode,
> {
    let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
    // Verify the session exists, and keep it for the trailing status event
    // instead of fetching the same document a second time.
    let session = agent
        .db
        .pentest_sessions()
        .find_one(doc! { "_id": oid })
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::NOT_FOUND)?;
    // Fetch recent messages for this session (oldest first, capped at 100).
    let messages: Vec<PentestMessage> = match agent
        .db
        .pentest_messages()
        .find(doc! { "session_id": &id })
        .sort(doc! { "created_at": 1 })
        .limit(100)
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(_) => Vec::new(),
    };
    // Fetch recent attack chain nodes (oldest first, capped at 100).
    let nodes: Vec<AttackChainNode> = match agent
        .db
        .attack_chain_nodes()
        .find(doc! { "session_id": &id })
        .sort(doc! { "started_at": 1 })
        .limit(100)
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(_) => Vec::new(),
    };
    // Build SSE events from stored data.
    let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
    for msg in &messages {
        let event_data = serde_json::json!({
            "type": "message",
            "role": msg.role,
            "content": msg.content,
            "created_at": msg.created_at.to_rfc3339(),
        });
        if let Ok(data) = serde_json::to_string(&event_data) {
            events.push(Ok(Event::default().event("message").data(data)));
        }
    }
    for node in &nodes {
        let event_data = serde_json::json!({
            "type": "tool_execution",
            "node_id": node.node_id,
            "tool_name": node.tool_name,
            "status": node.status,
            "findings_produced": node.findings_produced,
        });
        if let Ok(data) = serde_json::to_string(&event_data) {
            events.push(Ok(Event::default().event("tool").data(data)));
        }
    }
    // Trailing session status event, built from the session fetched above.
    let status_data = serde_json::json!({
        "type": "status",
        "status": session.status,
        "findings_count": session.findings_count,
        "tool_invocations": session.tool_invocations,
    });
    if let Ok(data) = serde_json::to_string(&status_data) {
        events.push(Ok(Event::default().event("status").data(data)));
    }
    Ok(Sse::new(stream::iter(events)))
}

View File

@@ -0,0 +1,241 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::*;
/// GET handler: list tracked repositories with simple pagination.
#[tracing::instrument(skip_all)]
pub async fn list_repositories(
    Extension(agent): AgentExt,
    Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> {
    let db = &agent.db;
    // Page numbers are 1-based; page 0 clamps to the first page.
    let offset = (params.page.saturating_sub(1)) * params.limit as u64;
    let total = db
        .repositories()
        .count_documents(doc! {})
        .await
        .unwrap_or(0);
    let query = db
        .repositories()
        .find(doc! {})
        .skip(offset)
        .limit(params.limit);
    // A failed query yields an empty page rather than an error response.
    let repositories = match query.await {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(err) => {
            tracing::warn!("Failed to fetch repositories: {err}");
            Vec::new()
        }
    };
    Ok(Json(ApiResponse {
        data: repositories,
        total: Some(total),
        page: Some(params.page),
    }))
}
/// POST handler: register a new repository for tracking.
///
/// Verifies git access with the supplied credentials before persisting;
/// unreachable repositories are rejected with 400.
#[tracing::instrument(skip_all)]
pub async fn add_repository(
    Extension(agent): AgentExt,
    Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
    // Validate repository access before saving
    let creds = crate::pipeline::git::RepoCredentials {
        ssh_key_path: Some(agent.config.ssh_key_path.clone()),
        auth_token: req.auth_token.clone(),
        auth_username: req.auth_username.clone(),
    };
    if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) {
        return Err((
            StatusCode::BAD_REQUEST,
            format!("Cannot access repository: {e}"),
        ));
    }
    // Copy optional settings from the request onto the new record.
    let mut repo = TrackedRepository::new(req.name, req.git_url);
    repo.default_branch = req.default_branch;
    repo.auth_token = req.auth_token;
    repo.auth_username = req.auth_username;
    repo.tracker_type = req.tracker_type;
    repo.tracker_owner = req.tracker_owner;
    repo.tracker_repo = req.tracker_repo;
    repo.tracker_token = req.tracker_token;
    repo.scan_schedule = req.scan_schedule;
    // NOTE(review): every insert error is reported as 409 "already exists",
    // including transient DB failures — consider distinguishing duplicate-key
    // errors from other write errors.
    agent
        .db
        .repositories()
        .insert_one(&repo)
        .await
        .map_err(|_| {
            (
                StatusCode::CONFLICT,
                "Repository already exists".to_string(),
            )
        })?;
    Ok(Json(ApiResponse {
        data: repo,
        total: None,
        page: None,
    }))
}
/// PATCH-style handler: update only the repository fields present in the body.
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
    Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
    // Always bump the modification timestamp, then add whatever was supplied.
    let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
    // Optional string fields map one-to-one onto document keys.
    let string_fields = [
        ("name", &req.name),
        ("default_branch", &req.default_branch),
        ("auth_token", &req.auth_token),
        ("auth_username", &req.auth_username),
        ("tracker_owner", &req.tracker_owner),
        ("tracker_repo", &req.tracker_repo),
        ("tracker_token", &req.tracker_token),
        ("scan_schedule", &req.scan_schedule),
    ];
    for (key, value) in string_fields {
        if let Some(v) = value {
            set_doc.insert(key, v);
        }
    }
    // The tracker type is an enum and is persisted via its string form.
    if let Some(tracker_type) = &req.tracker_type {
        set_doc.insert("tracker_type", tracker_type.to_string());
    }
    let outcome = agent
        .db
        .repositories()
        .update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
        .await
        .map_err(|err| {
            tracing::warn!("Failed to update repository: {err}");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    if outcome.matched_count == 0 {
        return Err(StatusCode::NOT_FOUND);
    }
    Ok(Json(serde_json::json!({ "status": "updated" })))
}
/// GET handler: return the agent's SSH public key for repo setup.
#[tracing::instrument(skip_all)]
pub async fn get_ssh_public_key(
    Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // The public half lives next to the private key with a ".pub" suffix.
    let path = format!("{}.pub", agent.config.ssh_key_path);
    match std::fs::read_to_string(&path) {
        Ok(key) => Ok(Json(serde_json::json!({ "public_key": key.trim() }))),
        Err(_) => Err(StatusCode::NOT_FOUND),
    }
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone();
tokio::spawn(async move {
if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
tracing::error!("Manual scan failed for {id}: {e}");
}
});
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
}
/// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
    let repo = agent
        .db
        .repositories()
        .find_one(doc! { "_id": oid })
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
        .ok_or(StatusCode::NOT_FOUND)?;
    // Repositories without an explicit tracker default to "gitea".
    let tracker_type = match &repo.tracker_type {
        Some(kind) => kind.to_string(),
        None => "gitea".to_string(),
    };
    Ok(Json(serde_json::json!({
        "webhook_secret": repo.webhook_secret,
        "tracker_type": tracker_type,
    })))
}
/// DELETE handler: remove a repository and every record derived from it.
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository(
    Extension(agent): AgentExt,
    Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
    let db = &agent.db;
    // Remove the repository record itself first; bail out if it is unknown.
    let outcome = db
        .repositories()
        .delete_one(doc! { "_id": oid })
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    if outcome.deleted_count == 0 {
        return Err(StatusCode::NOT_FOUND);
    }
    // Best-effort cascade: purge every collection keyed by this repo id.
    // Individual failures are ignored so a partial cleanup still succeeds.
    let repo_filter = doc! { "repo_id": &id };
    let _ = db.findings().delete_many(repo_filter.clone()).await;
    let _ = db.sbom_entries().delete_many(repo_filter.clone()).await;
    let _ = db.scan_runs().delete_many(repo_filter.clone()).await;
    let _ = db.cve_alerts().delete_many(repo_filter.clone()).await;
    let _ = db.tracker_issues().delete_many(repo_filter.clone()).await;
    let _ = db.graph_nodes().delete_many(repo_filter.clone()).await;
    let _ = db.graph_edges().delete_many(repo_filter.clone()).await;
    let _ = db.graph_builds().delete_many(repo_filter.clone()).await;
    let _ = db.impact_analyses().delete_many(repo_filter.clone()).await;
    let _ = db.code_embeddings().delete_many(repo_filter.clone()).await;
    let _ = db.embedding_builds().delete_many(repo_filter).await;
    Ok(Json(serde_json::json!({ "status": "deleted" })))
}

View File

@@ -0,0 +1,379 @@
use axum::extract::{Extension, Query};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::SbomEntry;
/// SPDX identifiers treated as copyleft when summarising licenses.
///
/// Matching in `license_summary` is substring-based and case-insensitive, so
/// both the bare ("GPL-3.0") and suffixed ("GPL-3.0-only") forms are listed.
const COPYLEFT_LICENSES: &[&str] = &[
    "GPL-2.0",
    "GPL-2.0-only",
    "GPL-2.0-or-later",
    "GPL-3.0",
    "GPL-3.0-only",
    "GPL-3.0-or-later",
    "AGPL-3.0",
    "AGPL-3.0-only",
    "AGPL-3.0-or-later",
    "LGPL-2.1",
    "LGPL-2.1-only",
    "LGPL-2.1-or-later",
    "LGPL-3.0",
    "LGPL-3.0-only",
    "LGPL-3.0-or-later",
    "MPL-2.0",
];
/// GET handler: distinct package-manager and license values for filter UIs.
#[tracing::instrument(skip_all)]
pub async fn sbom_filters(
    Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let db = &agent.db;
    // Distinct package managers, dropping placeholder values.
    let mut managers: Vec<String> = Vec::new();
    for value in db
        .sbom_entries()
        .distinct("package_manager", doc! {})
        .await
        .unwrap_or_default()
    {
        if let Some(text) = value.as_str() {
            if !text.is_empty() && text != "unknown" && text != "file" {
                managers.push(text.to_string());
            }
        }
    }
    // Distinct licenses, dropping empty strings.
    let mut licenses: Vec<String> = Vec::new();
    for value in db
        .sbom_entries()
        .distinct("license", doc! {})
        .await
        .unwrap_or_default()
    {
        if let Some(text) = value.as_str() {
            if !text.is_empty() {
                licenses.push(text.to_string());
            }
        }
    }
    Ok(Json(serde_json::json!({
        "package_managers": managers,
        "licenses": licenses,
    })))
}
/// GET handler: list SBOM entries with optional filters and pagination.
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom(
    Extension(agent): AgentExt,
    Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
    let db = &agent.db;
    // Assemble the Mongo filter from whichever query params were supplied.
    let mut query = doc! {};
    if let Some(repo) = &filter.repo_id {
        query.insert("repo_id", repo);
    }
    if let Some(manager) = &filter.package_manager {
        query.insert("package_manager", manager);
    }
    match &filter.q {
        // Case-insensitive substring match on the package name.
        Some(text) if !text.is_empty() => {
            query.insert("name", doc! { "$regex": text, "$options": "i" });
        }
        _ => {}
    }
    match filter.has_vulns {
        Some(true) => {
            query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
        }
        Some(false) => {
            query.insert("known_vulnerabilities", doc! { "$size": 0 });
        }
        None => {}
    }
    if let Some(license) = &filter.license {
        query.insert("license", license);
    }
    let offset = (filter.page.saturating_sub(1)) * filter.limit as u64;
    let total = db
        .sbom_entries()
        .count_documents(query.clone())
        .await
        .unwrap_or(0);
    // Alphabetical by name; a failed query yields an empty page.
    let entries = match db
        .sbom_entries()
        .find(query)
        .sort(doc! { "name": 1 })
        .skip(offset)
        .limit(filter.limit)
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(err) => {
            tracing::warn!("Failed to fetch SBOM entries: {err}");
            Vec::new()
        }
    };
    Ok(Json(ApiResponse {
        data: entries,
        total: Some(total),
        page: Some(filter.page),
    }))
}
/// GET handler: export a repository's SBOM as SPDX 2.3 or CycloneDX 1.5 JSON.
///
/// `params.format == "spdx"` selects SPDX; any other value produces
/// CycloneDX. The document is returned as a JSON attachment whose filename
/// encodes the repo id and the chosen format.
#[tracing::instrument(skip_all)]
pub async fn export_sbom(
    Extension(agent): AgentExt,
    Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
    let db = &agent.db;
    let entries: Vec<SbomEntry> = match db
        .sbom_entries()
        .find(doc! { "repo_id": &params.repo_id })
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(e) => {
            tracing::warn!("Failed to fetch SBOM entries for export: {e}");
            Vec::new()
        }
    };
    let body = if params.format == "spdx" {
        // SPDX 2.3 format
        let packages: Vec<serde_json::Value> = entries
            .iter()
            .enumerate()
            .map(|(i, e)| {
                serde_json::json!({
                    "SPDXID": format!("SPDXRef-Package-{i}"),
                    "name": e.name,
                    "versionInfo": e.version,
                    "downloadLocation": "NOASSERTION",
                    "licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
                    "externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
                        "referenceCategory": "PACKAGE-MANAGER",
                        "referenceType": "purl",
                        "referenceLocator": p,
                    })]).unwrap_or_default(),
                })
            })
            .collect();
        serde_json::json!({
            "spdxVersion": "SPDX-2.3",
            "dataLicense": "CC0-1.0",
            "SPDXID": "SPDXRef-DOCUMENT",
            "name": format!("sbom-{}", params.repo_id),
            "documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
            "packages": packages,
        })
    } else {
        // CycloneDX 1.5 format
        let components: Vec<serde_json::Value> = entries
            .iter()
            .map(|e| {
                let mut comp = serde_json::json!({
                    "type": "library",
                    "name": e.name,
                    "version": e.version,
                    "group": e.package_manager,
                });
                if let Some(purl) = &e.purl {
                    comp["purl"] = serde_json::Value::String(purl.clone());
                }
                if let Some(license) = &e.license {
                    comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
                }
                if !e.known_vulnerabilities.is_empty() {
                    comp["vulnerabilities"] = serde_json::json!(
                        e.known_vulnerabilities.iter().map(|v| serde_json::json!({
                            "id": v.id,
                            "source": { "name": v.source },
                            "ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
                        })).collect::<Vec<_>>()
                    );
                }
                comp
            })
            .collect();
        serde_json::json!({
            "bomFormat": "CycloneDX",
            "specVersion": "1.5",
            "version": 1,
            "metadata": {
                "component": {
                    "type": "application",
                    "name": format!("repo-{}", params.repo_id),
                }
            },
            "components": components,
        })
    };
    let json_str =
        serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let filename = if params.format == "spdx" {
        format!("sbom-{}-spdx.json", params.repo_id)
    } else {
        format!("sbom-{}-cyclonedx.json", params.repo_id)
    };
    // Fix: the computed filename was previously unused and the header carried
    // a hard-coded placeholder; embed the real filename in the disposition.
    let disposition = format!("attachment; filename=\"{filename}\"");
    Ok((
        [
            (
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ),
            (
                header::CONTENT_DISPOSITION,
                header::HeaderValue::from_str(&disposition)
                    .unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
            ),
        ],
        json_str,
    ))
}
/// GET handler: group SBOM entries by license and flag copyleft licenses.
///
/// Returns one summary per distinct license string (entries without a
/// license are grouped under "Unknown"), sorted by package count descending.
#[tracing::instrument(skip_all)]
pub async fn license_summary(
    Extension(agent): AgentExt,
    Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
    let db = &agent.db;
    let mut query = doc! {};
    if let Some(repo_id) = &params.repo_id {
        query.insert("repo_id", repo_id);
    }
    let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(e) => {
            tracing::warn!("Failed to fetch SBOM entries for license summary: {e}");
            Vec::new()
        }
    };
    // Group package names by license string.
    let mut license_map: std::collections::HashMap<String, Vec<String>> =
        std::collections::HashMap::new();
    for entry in &entries {
        let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
        license_map.entry(lic).or_default().push(entry.name.clone());
    }
    // Uppercase the copyleft list once instead of per license per candidate
    // inside the `any` closure.
    let copyleft_upper: Vec<String> = COPYLEFT_LICENSES
        .iter()
        .map(|c| c.to_uppercase())
        .collect();
    let mut summaries: Vec<LicenseSummary> = license_map
        .into_iter()
        .map(|(license, packages)| {
            // Case-insensitive substring match against the copyleft list.
            let license_upper = license.to_uppercase();
            let is_copyleft = copyleft_upper.iter().any(|c| license_upper.contains(c));
            LicenseSummary {
                license,
                count: packages.len() as u64,
                is_copyleft,
                packages,
            }
        })
        .collect();
    // Most common licenses first.
    summaries.sort_by(|a, b| b.count.cmp(&a.count));
    Ok(Json(ApiResponse {
        data: summaries,
        total: None,
        page: None,
    }))
}
/// GET handler: diff the SBOMs of two repositories.
///
/// Packages are keyed by `(name, package_manager)`. The result buckets each
/// package as only-in-A, only-in-B, present in both with differing versions,
/// or common (counted only).
#[tracing::instrument(skip_all)]
pub async fn sbom_diff(
    Extension(agent): AgentExt,
    Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
    let db = &agent.db;
    // Query failures degrade to an empty entry list for that side.
    let entries_a: Vec<SbomEntry> = match db
        .sbom_entries()
        .find(doc! { "repo_id": &params.repo_a })
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(e) => {
            tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}");
            Vec::new()
        }
    };
    let entries_b: Vec<SbomEntry> = match db
        .sbom_entries()
        .find(doc! { "repo_id": &params.repo_b })
        .await
    {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(e) => {
            tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}");
            Vec::new()
        }
    };
    // Build maps by (name, package_manager) -> version
    let map_a: std::collections::HashMap<(String, String), String> = entries_a
        .iter()
        .map(|e| {
            (
                (e.name.clone(), e.package_manager.clone()),
                e.version.clone(),
            )
        })
        .collect();
    let map_b: std::collections::HashMap<(String, String), String> = entries_b
        .iter()
        .map(|e| {
            (
                (e.name.clone(), e.package_manager.clone()),
                e.version.clone(),
            )
        })
        .collect();
    // Single pass over A classifies: missing from B, version drift, or common.
    let mut only_in_a = Vec::new();
    let mut version_changed = Vec::new();
    let mut common_count: u64 = 0;
    for (key, ver_a) in &map_a {
        match map_b.get(key) {
            None => only_in_a.push(SbomDiffEntry {
                name: key.0.clone(),
                version: ver_a.clone(),
                package_manager: key.1.clone(),
            }),
            Some(ver_b) if ver_a != ver_b => {
                version_changed.push(SbomVersionDiff {
                    name: key.0.clone(),
                    package_manager: key.1.clone(),
                    version_a: ver_a.clone(),
                    version_b: ver_b.clone(),
                });
            }
            Some(_) => common_count += 1,
        }
    }
    // Second pass over B finds packages absent from A.
    let only_in_b: Vec<SbomDiffEntry> = map_b
        .iter()
        .filter(|(key, _)| !map_a.contains_key(key))
        .map(|(key, ver)| SbomDiffEntry {
            name: key.0.clone(),
            version: ver.clone(),
            package_manager: key.1.clone(),
        })
        .collect();
    Ok(Json(ApiResponse {
        data: SbomDiffResult {
            only_in_a,
            only_in_b,
            version_changed,
            common_count,
        },
        total: None,
        page: None,
    }))
}

View File

@@ -0,0 +1,37 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
/// GET handler: list scan runs, newest first, with pagination.
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
    Extension(agent): AgentExt,
    Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> {
    let db = &agent.db;
    // Page numbers are 1-based; page 0 clamps to the first page.
    let offset = (params.page.saturating_sub(1)) * params.limit as u64;
    let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
    let query = db
        .scan_runs()
        .find(doc! {})
        .sort(doc! { "started_at": -1 })
        .skip(offset)
        .limit(params.limit);
    // A failed query yields an empty page rather than an error response.
    let runs = match query.await {
        Ok(cursor) => collect_cursor_async(cursor).await,
        Err(err) => {
            tracing::warn!("Failed to fetch scan runs: {err}");
            Vec::new()
        }
    };
    Ok(Json(ApiResponse {
        data: runs,
        total: Some(total),
        page: Some(params.page),
    }))
}

View File

@@ -136,7 +136,10 @@ pub fn build_router() -> Router {
"/api/v1/pentest/sessions/{id}/export",
post(handlers::pentest::export_session_report),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
.route(
"/api/v1/pentest/stats",
get(handlers::pentest::pentest_stats),
)
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",

View File

@@ -1,147 +1,17 @@
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use super::types::*;
use crate::error::AgentError;
#[derive(Clone)]
pub struct LlmClient {
base_url: String,
api_key: SecretString,
model: String,
embed_model: String,
http: reqwest::Client,
pub(crate) base_url: String,
pub(crate) api_key: SecretString,
pub(crate) model: String,
pub(crate) embed_model: String,
pub(crate) http: reqwest::Client,
}
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
struct ToolDefinitionPayload {
r#type: String,
function: ToolFunctionPayload,
}
#[derive(Serialize)]
struct ToolFunctionPayload {
name: String,
description: String,
parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct ChatCompletionResponse {
choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
struct ChatChoice {
message: ChatResponseMessage,
}
#[derive(Deserialize)]
struct ChatResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Deserialize)]
struct ToolCallResponse {
id: String,
function: ToolCallFunction,
}
#[derive(Deserialize)]
struct ToolCallFunction {
name: String,
arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
/// Tool calls with optional reasoning text from the LLM
ToolCalls { calls: Vec<LlmToolCall>, reasoning: String },
}
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Implementation ─────────────────────────────────────────────
impl LlmClient {
pub fn new(
base_url: String,
@@ -158,18 +28,14 @@ impl LlmClient {
}
}
pub fn embed_model(&self) -> &str {
&self.embed_model
}
fn chat_url(&self) -> String {
pub(crate) fn chat_url(&self) -> String {
format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
)
}
fn auth_header(&self) -> Option<String> {
pub(crate) fn auth_header(&self) -> Option<String> {
let key = self.api_key.expose_secret();
if key.is_empty() {
None
@@ -241,12 +107,12 @@ impl LlmClient {
tools: None,
};
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
self.send_chat_request(&request_body)
.await
.map(|resp| match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls { .. } => String::new(),
}
})
})
}
/// Chat with tool definitions — returns either content or tool calls.
@@ -292,7 +158,7 @@ impl LlmClient {
) -> Result<LlmResponse, AgentError> {
let mut req = self
.http
.post(&self.chat_url())
.post(self.chat_url())
.header("content-type", "application/json")
.json(request_body);
@@ -345,54 +211,7 @@ impl LlmClient {
}
// Otherwise return content
let content = choice
.message
.content
.clone()
.unwrap_or_default();
let content = choice.message.content.clone().unwrap_or_default();
Ok(LlmResponse::Content(content))
}
/// Generate embeddings for a batch of texts
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: self.embed_model.clone(),
input: texts,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"Embedding API returned {status}: {body}"
)));
}
let body: EmbeddingResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
let mut data = body.data;
data.sort_by_key(|d| d.index);
Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -0,0 +1,74 @@
use serde::{Deserialize, Serialize};
use super::client::LlmClient;
use crate::error::AgentError;
// ── Embedding types ────────────────────────────────────────────
/// Request payload for the OpenAI-compatible `/v1/embeddings` endpoint.
#[derive(Serialize)]
struct EmbeddingRequest {
    model: String,
    input: Vec<String>,
}
/// Top-level embeddings API response body.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}
/// One embedding vector with its position in the input batch.
#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f64>,
    // Index into the original input batch; used to restore request order.
    index: usize,
}
// ── Embedding implementation ───────────────────────────────────
impl LlmClient {
    /// Name of the model used for embedding requests.
    pub fn embed_model(&self) -> &str {
        &self.embed_model
    }
    /// Generate embeddings for a batch of texts
    ///
    /// Sends one request to the OpenAI-compatible `/v1/embeddings` endpoint
    /// and returns the vectors in the same order as `texts` (the response
    /// data is re-sorted by its `index` field before being returned).
    ///
    /// # Errors
    /// Returns `AgentError::Other` when the HTTP request fails, the API
    /// responds with a non-success status, or the body cannot be parsed.
    pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
        let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
        let request_body = EmbeddingRequest {
            model: self.embed_model.clone(),
            input: texts,
        };
        let mut req = self
            .http
            .post(&url)
            .header("content-type", "application/json")
            .json(&request_body);
        // Attach bearer auth only when an API key is configured.
        if let Some(auth) = self.auth_header() {
            req = req.header("Authorization", auth);
        }
        let resp = req
            .send()
            .await
            .map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(AgentError::Other(format!(
                "Embedding API returned {status}: {body}"
            )));
        }
        let body: EmbeddingResponse = resp
            .json()
            .await
            .map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
        // Providers may return items out of order; restore input order.
        let mut data = body.data;
        data.sort_by_key(|d| d.index);
        Ok(data.into_iter().map(|d| d.embedding).collect())
    }
}

View File

@@ -1,11 +1,16 @@
pub mod client;
#[allow(dead_code)]
pub mod descriptions;
pub mod embedding;
#[allow(dead_code)]
pub mod fixes;
#[allow(dead_code)]
pub mod pr_review;
pub mod review_prompts;
pub mod triage;
pub mod types;
pub use client::LlmClient;
pub use types::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};

View File

@@ -278,3 +278,220 @@ struct TriageResult {
/// Serde default: triage action used when the LLM response omits one.
fn default_action() -> String {
    String::from("confirm")
}
// Unit tests for the triage helpers: file-path classification, confidence
// adjustment, severity upgrade/downgrade, and `TriageResult` deserialization
// (including serde defaults and fence-stripped LLM output).
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::models::Severity;

    // ── classify_file_path ───────────────────────────────────────

    #[test]
    fn classify_none_path() {
        // A missing path cannot be classified.
        assert_eq!(classify_file_path(None), "unknown");
    }

    #[test]
    fn classify_production_path() {
        assert_eq!(classify_file_path(Some("src/main.rs")), "production");
        assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production");
    }

    #[test]
    fn classify_test_paths() {
        // Directory markers, filename suffixes, and fixture dirs all count.
        assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test");
        assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test");
        assert_eq!(classify_file_path(Some("foo_test.go")), "test");
        assert_eq!(classify_file_path(Some("bar.test.js")), "test");
        assert_eq!(classify_file_path(Some("baz.spec.ts")), "test");
        assert_eq!(
            classify_file_path(Some("data/fixtures/sample.json")),
            "test"
        );
        assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test");
    }

    #[test]
    fn classify_example_paths() {
        assert_eq!(
            classify_file_path(Some("docs/examples/basic.rs")),
            "example"
        );
        // /example matches because contains("/example")
        assert_eq!(classify_file_path(Some("src/example/main.py")), "example");
        assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example");
        assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example");
    }

    #[test]
    fn classify_generated_paths() {
        assert_eq!(
            classify_file_path(Some("src/generated/api.rs")),
            "generated"
        );
        assert_eq!(
            classify_file_path(Some("proto/gen/service.go")),
            "generated"
        );
        assert_eq!(classify_file_path(Some("api.generated.ts")), "generated");
        assert_eq!(classify_file_path(Some("service.pb.go")), "generated");
        assert_eq!(classify_file_path(Some("model_generated.rs")), "generated");
    }

    #[test]
    fn classify_vendored_paths() {
        // Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes)
        assert_eq!(
            classify_file_path(Some("src/vendor/lib/foo.go")),
            "vendored"
        );
        assert_eq!(
            classify_file_path(Some("src/node_modules/pkg/index.js")),
            "vendored"
        );
        assert_eq!(
            classify_file_path(Some("src/third_party/lib.c")),
            "vendored"
        );
    }

    #[test]
    fn classify_is_case_insensitive() {
        assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test");
        assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored");
        assert_eq!(
            classify_file_path(Some("src/GENERATED/foo.ts")),
            "generated"
        );
    }

    // ── adjust_confidence ────────────────────────────────────────
    // Confidence is reduced for non-production code so findings in
    // test/example/generated/vendored files rank lower.

    #[test]
    fn adjust_confidence_production() {
        // Production code keeps its raw confidence.
        assert_eq!(adjust_confidence(8.0, "production"), 8.0);
    }

    #[test]
    fn adjust_confidence_test() {
        assert_eq!(adjust_confidence(10.0, "test"), 5.0);
    }

    #[test]
    fn adjust_confidence_example() {
        assert_eq!(adjust_confidence(10.0, "example"), 6.0);
    }

    #[test]
    fn adjust_confidence_generated() {
        assert_eq!(adjust_confidence(10.0, "generated"), 3.0);
    }

    #[test]
    fn adjust_confidence_vendored() {
        assert_eq!(adjust_confidence(10.0, "vendored"), 4.0);
    }

    #[test]
    fn adjust_confidence_unknown_classification() {
        // Unrecognized classifications leave confidence untouched.
        assert_eq!(adjust_confidence(7.0, "unknown"), 7.0);
        assert_eq!(adjust_confidence(7.0, "something_else"), 7.0);
    }

    #[test]
    fn adjust_confidence_zero() {
        // Zero confidence stays zero regardless of classification.
        assert_eq!(adjust_confidence(0.0, "test"), 0.0);
        assert_eq!(adjust_confidence(0.0, "production"), 0.0);
    }

    // ── downgrade_severity ───────────────────────────────────────

    #[test]
    fn downgrade_severity_all_levels() {
        assert_eq!(downgrade_severity(&Severity::Critical), Severity::High);
        assert_eq!(downgrade_severity(&Severity::High), Severity::Medium);
        assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low);
        assert_eq!(downgrade_severity(&Severity::Low), Severity::Info);
        assert_eq!(downgrade_severity(&Severity::Info), Severity::Info);
    }

    #[test]
    fn downgrade_severity_info_is_floor() {
        // Downgrading Info twice should still be Info
        let s = downgrade_severity(&Severity::Info);
        assert_eq!(downgrade_severity(&s), Severity::Info);
    }

    // ── upgrade_severity ─────────────────────────────────────────

    #[test]
    fn upgrade_severity_all_levels() {
        assert_eq!(upgrade_severity(&Severity::Info), Severity::Low);
        assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium);
        assert_eq!(upgrade_severity(&Severity::Medium), Severity::High);
        assert_eq!(upgrade_severity(&Severity::High), Severity::Critical);
        assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical);
    }

    #[test]
    fn upgrade_severity_critical_is_ceiling() {
        let s = upgrade_severity(&Severity::Critical);
        assert_eq!(upgrade_severity(&s), Severity::Critical);
    }

    // ── upgrade/downgrade roundtrip ──────────────────────────────

    #[test]
    fn upgrade_then_downgrade_is_identity_for_middle_values() {
        // Holds only away from the Info/Critical saturation endpoints.
        for sev in [Severity::Low, Severity::Medium, Severity::High] {
            assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev);
        }
    }

    // ── TriageResult deserialization ─────────────────────────────

    #[test]
    fn triage_result_full() {
        let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#;
        let r: TriageResult = serde_json::from_str(json).unwrap();
        assert_eq!(r.action, "dismiss");
        assert_eq!(r.confidence, 8.5);
        assert_eq!(r.rationale, "false positive");
        assert_eq!(r.remediation.as_deref(), Some("remove code"));
    }

    #[test]
    fn triage_result_defaults() {
        // An empty object must parse via serde defaults ("confirm" action).
        let json = r#"{}"#;
        let r: TriageResult = serde_json::from_str(json).unwrap();
        assert_eq!(r.action, "confirm");
        assert_eq!(r.confidence, 0.0);
        assert_eq!(r.rationale, "");
        assert!(r.remediation.is_none());
    }

    #[test]
    fn triage_result_partial() {
        let json = r#"{"action":"downgrade","confidence":6.0}"#;
        let r: TriageResult = serde_json::from_str(json).unwrap();
        assert_eq!(r.action, "downgrade");
        assert_eq!(r.confidence, 6.0);
        assert_eq!(r.rationale, "");
        assert!(r.remediation.is_none());
    }

    #[test]
    fn triage_result_with_markdown_fences() {
        // Simulate LLM wrapping response in markdown code fences
        let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```";
        let cleaned = raw
            .trim()
            .trim_start_matches("```json")
            .trim_start_matches("```")
            .trim_end_matches("```")
            .trim();
        let r: TriageResult = serde_json::from_str(cleaned).unwrap();
        assert_eq!(r.action, "upgrade");
        assert_eq!(r.confidence, 9.0);
    }
}

View File

@@ -0,0 +1,369 @@
use serde::{Deserialize, Serialize};
// ── Request types ──────────────────────────────────────────────
/// A single message in an OpenAI-style chat conversation.
///
/// `content` is always serialized (as JSON `null` when `None`); the two
/// tool fields are omitted from the JSON entirely when absent.
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
    /// Message author role, e.g. "user" or "assistant".
    pub role: String,
    /// Text body; `None` for assistant messages that only carry tool calls.
    pub content: Option<String>,
    /// Tool invocations attached to an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallRequest>>,
    /// Id of the tool call this message responds to, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
/// Request payload for a chat-completions API call.
///
/// Optional fields are dropped from the serialized JSON when `None`.
#[derive(Serialize)]
pub(crate) struct ChatCompletionRequest {
    /// Model identifier to route the request to.
    pub(crate) model: String,
    /// Conversation history, in order.
    pub(crate) messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) max_tokens: Option<u32>,
    /// Tool definitions offered to the model for this request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tools: Option<Vec<ToolDefinitionPayload>>,
}
/// Wire format for one tool definition in the request.
#[derive(Serialize)]
pub(crate) struct ToolDefinitionPayload {
    /// Tool kind; `r#` escapes the keyword so it serializes as "type".
    pub(crate) r#type: String,
    pub(crate) function: ToolFunctionPayload,
}
/// Function portion of a serialized tool definition.
#[derive(Serialize)]
pub(crate) struct ToolFunctionPayload {
    pub(crate) name: String,
    pub(crate) description: String,
    /// Parameter schema as a JSON value (typically a JSON Schema object).
    pub(crate) parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
/// Top-level chat-completions API response.
#[derive(Deserialize)]
pub(crate) struct ChatCompletionResponse {
    /// Completion candidates; typically a single choice is returned.
    pub(crate) choices: Vec<ChatChoice>,
}
/// One completion candidate in the response.
#[derive(Deserialize)]
pub(crate) struct ChatChoice {
    pub(crate) message: ChatResponseMessage,
}
/// Assistant message inside a completion choice.
///
/// Both fields default to `None` when the server omits them.
#[derive(Deserialize)]
pub(crate) struct ChatResponseMessage {
    /// Plain-text reply, when the model answered with text.
    #[serde(default)]
    pub(crate) content: Option<String>,
    /// Tool invocations, when the model chose to call tools.
    #[serde(default)]
    pub(crate) tool_calls: Option<Vec<ToolCallResponse>>,
}
/// A single tool call entry in an API response.
#[derive(Deserialize)]
pub(crate) struct ToolCallResponse {
    /// Server-assigned id for this call.
    pub(crate) id: String,
    pub(crate) function: ToolCallFunction,
}
/// Name and raw arguments of a called function.
#[derive(Deserialize)]
pub(crate) struct ToolCallFunction {
    pub(crate) name: String,
    /// Arguments as a raw JSON-encoded string (not yet parsed).
    pub(crate) arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses to reference the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter schema as a JSON value (typically a JSON Schema object).
    pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
    /// Identifier for this call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments already parsed into a JSON value.
    pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    pub id: String,
    /// Call kind; serializes under the JSON key "type" (raw identifier escapes the keyword).
    pub r#type: String,
    pub function: ToolCallRequestFunction,
}
/// Function name and JSON-string arguments inside a `ToolCallRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
    pub name: String,
    /// Arguments as a raw JSON-encoded string.
    pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
    /// Plain text answer (may be empty).
    Content(String),
    /// Tool calls with optional reasoning text from the LLM
    ToolCalls {
        /// Calls the LLM wants executed.
        calls: Vec<LlmToolCall>,
        /// Free-text reasoning accompanying the calls; may be empty.
        reasoning: String,
    },
}
// Serialization tests for the LLM wire types: serde `skip_serializing_if`
// behaviour, raw-keyword field naming (`r#type` → "type"), defaulted
// response fields, and the `LlmResponse` enum variants.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── ChatMessage ──────────────────────────────────────────────

    #[test]
    fn chat_message_serializes_minimal() {
        let msg = ChatMessage {
            role: "user".to_string(),
            content: Some("hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        let v = serde_json::to_value(&msg).unwrap();
        assert_eq!(v["role"], "user");
        assert_eq!(v["content"], "hello");
        // None fields with skip_serializing_if should be absent
        assert!(v.get("tool_calls").is_none());
        assert!(v.get("tool_call_id").is_none());
    }

    #[test]
    fn chat_message_serializes_with_tool_calls() {
        let msg = ChatMessage {
            role: "assistant".to_string(),
            content: None,
            tool_calls: Some(vec![ToolCallRequest {
                id: "call_1".to_string(),
                r#type: "function".to_string(),
                function: ToolCallRequestFunction {
                    name: "get_weather".to_string(),
                    arguments: r#"{"city":"NYC"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
        };
        let v = serde_json::to_value(&msg).unwrap();
        assert!(v["tool_calls"].is_array());
        assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather");
    }

    #[test]
    fn chat_message_content_null_when_none() {
        // content has no skip attribute, so None serializes as JSON null.
        let msg = ChatMessage {
            role: "assistant".to_string(),
            content: None,
            tool_calls: None,
            tool_call_id: None,
        };
        let v = serde_json::to_value(&msg).unwrap();
        assert!(v["content"].is_null());
    }

    // ── ToolDefinition ───────────────────────────────────────────

    #[test]
    fn tool_definition_serializes() {
        let td = ToolDefinition {
            name: "search".to_string(),
            description: "Search the web".to_string(),
            parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
        };
        let v = serde_json::to_value(&td).unwrap();
        assert_eq!(v["name"], "search");
        assert_eq!(v["parameters"]["type"], "object");
    }

    #[test]
    fn tool_definition_empty_parameters() {
        let td = ToolDefinition {
            name: "noop".to_string(),
            description: "".to_string(),
            parameters: json!({}),
        };
        let v = serde_json::to_value(&td).unwrap();
        assert_eq!(v["parameters"], json!({}));
    }

    // ── LlmToolCall ──────────────────────────────────────────────

    #[test]
    fn llm_tool_call_roundtrip() {
        // Serialize → deserialize must preserve every field.
        let call = LlmToolCall {
            id: "tc_42".to_string(),
            name: "run_scan".to_string(),
            arguments: json!({"path": "/tmp", "verbose": true}),
        };
        let serialized = serde_json::to_string(&call).unwrap();
        let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.id, "tc_42");
        assert_eq!(deserialized.name, "run_scan");
        assert_eq!(deserialized.arguments["path"], "/tmp");
        assert_eq!(deserialized.arguments["verbose"], true);
    }

    #[test]
    fn llm_tool_call_empty_arguments() {
        let call = LlmToolCall {
            id: "tc_0".to_string(),
            name: "noop".to_string(),
            arguments: json!({}),
        };
        let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap();
        assert!(rt.arguments.as_object().unwrap().is_empty());
    }

    // ── ToolCallRequest / ToolCallRequestFunction ────────────────

    #[test]
    fn tool_call_request_roundtrip() {
        let req = ToolCallRequest {
            id: "call_abc".to_string(),
            r#type: "function".to_string(),
            function: ToolCallRequestFunction {
                name: "my_func".to_string(),
                arguments: r#"{"x":1}"#.to_string(),
            },
        };
        let json_str = serde_json::to_string(&req).unwrap();
        let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap();
        assert_eq!(back.id, "call_abc");
        assert_eq!(back.r#type, "function");
        assert_eq!(back.function.name, "my_func");
        assert_eq!(back.function.arguments, r#"{"x":1}"#);
    }

    #[test]
    fn tool_call_request_type_field_serializes_as_type() {
        let req = ToolCallRequest {
            id: "id".to_string(),
            r#type: "function".to_string(),
            function: ToolCallRequestFunction {
                name: "f".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let v = serde_json::to_value(&req).unwrap();
        // The field should be "type" in JSON, not "r#type"
        assert!(v.get("type").is_some());
        assert!(v.get("r#type").is_none());
    }

    // ── ChatCompletionRequest ────────────────────────────────────

    #[test]
    fn chat_completion_request_skips_none_fields() {
        let req = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            tools: None,
        };
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["model"], "gpt-4");
        assert!(v.get("temperature").is_none());
        assert!(v.get("max_tokens").is_none());
        assert!(v.get("tools").is_none());
    }

    #[test]
    fn chat_completion_request_includes_set_fields() {
        let req = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
            tools: Some(vec![]),
        };
        let v = serde_json::to_value(&req).unwrap();
        assert_eq!(v["temperature"], 0.7);
        assert_eq!(v["max_tokens"], 1024);
        assert!(v["tools"].is_array());
    }

    // ── ChatCompletionResponse deserialization ───────────────────

    #[test]
    fn chat_completion_response_deserializes_content() {
        let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
        assert!(resp.choices[0].message.tool_calls.is_none());
    }

    #[test]
    fn chat_completion_response_deserializes_tool_calls() {
        let json_str = r#"{
            "choices": [{
                "message": {
                    "tool_calls": [{
                        "id": "call_1",
                        "function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}
                    }]
                }
            }]
        }"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
        let tc = resp.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tc.len(), 1);
        assert_eq!(tc[0].id, "call_1");
        assert_eq!(tc[0].function.name, "search");
    }

    #[test]
    fn chat_completion_response_defaults_missing_fields() {
        // content and tool_calls are both missing — should default to None
        let json_str = r#"{"choices":[{"message":{}}]}"#;
        let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
        assert!(resp.choices[0].message.content.is_none());
        assert!(resp.choices[0].message.tool_calls.is_none());
    }

    // ── LlmResponse ─────────────────────────────────────────────

    #[test]
    fn llm_response_content_variant() {
        let resp = LlmResponse::Content("answer".to_string());
        match resp {
            LlmResponse::Content(s) => assert_eq!(s, "answer"),
            _ => panic!("expected Content variant"),
        }
    }

    #[test]
    fn llm_response_tool_calls_variant() {
        let resp = LlmResponse::ToolCalls {
            calls: vec![LlmToolCall {
                id: "1".to_string(),
                name: "f".to_string(),
                arguments: json!({}),
            }],
            reasoning: "because".to_string(),
        };
        match resp {
            LlmResponse::ToolCalls { calls, reasoning } => {
                assert_eq!(calls.len(), 1);
                assert_eq!(reasoning, "because");
            }
            _ => panic!("expected ToolCalls variant"),
        }
    }

    #[test]
    fn llm_response_empty_content() {
        let resp = LlmResponse::Content(String::new());
        match resp {
            LlmResponse::Content(s) => assert!(s.is_empty()),
            _ => panic!("expected Content variant"),
        }
    }
}

View File

@@ -0,0 +1,150 @@
use futures_util::StreamExt;
use mongodb::bson::doc;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::Finding;
use compliance_core::models::pentest::CodeContextHint;
use compliance_core::models::sbom::SbomEntry;
use super::orchestrator::PentestOrchestrator;
impl PentestOrchestrator {
    /// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
    /// for the repo linked to this DAST target.
    ///
    /// Returns `(sast_findings, sbom_entries, code_context)`; all three vectors
    /// are empty when the target has no linked `repo_id`.
    pub(crate) async fn gather_repo_context(
        &self,
        target: &DastTarget,
    ) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
        // No linked repository means there is no code context to gather.
        let Some(repo_id) = &target.repo_id else {
            return (Vec::new(), Vec::new(), Vec::new());
        };
        let sast_findings = self.fetch_sast_findings(repo_id).await;
        let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
        // Entry-point hints are cross-linked with the SAST findings by file path.
        let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
        tracing::info!(
            repo_id,
            sast_findings = sast_findings.len(),
            vulnerable_deps = sbom_entries.len(),
            code_hints = code_context.len(),
            "Gathered code-awareness context for pentest"
        );
        (sast_findings, sbom_entries, code_context)
    }

    /// Fetch open/triaged SAST findings for the repo (not false positives or resolved).
    ///
    /// Sorted by severity descending and capped at 100 documents. A failed
    /// query is logged and yields an empty list.
    async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
        let cursor = self
            .db
            .findings()
            .find(doc! {
                "repo_id": repo_id,
                "status": { "$in": ["open", "triaged"] },
            })
            .sort(doc! { "severity": -1 })
            .limit(100)
            .await;
        match cursor {
            Ok(mut c) => {
                let mut results = Vec::new();
                // NOTE(review): `while let Some(Ok(..))` stops at the first
                // item-level cursor error without logging it — confirm intended.
                while let Some(Ok(f)) = c.next().await {
                    results.push(f);
                }
                results
            }
            Err(e) => {
                tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
                Vec::new()
            }
        }
    }

    /// Fetch SBOM entries that have known vulnerabilities.
    ///
    /// Capped at 50 entries; a failed query is logged and yields an empty list.
    async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
        let cursor = self
            .db
            .sbom_entries()
            .find(doc! {
                "repo_id": repo_id,
                // Only entries with a non-empty vulnerability list.
                "known_vulnerabilities": { "$exists": true, "$ne": [] },
            })
            .limit(50)
            .await;
        match cursor {
            Ok(mut c) => {
                let mut results = Vec::new();
                while let Some(Ok(e)) = c.next().await {
                    results.push(e);
                }
                results
            }
            Err(e) => {
                tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
                Vec::new()
            }
        }
    }

    /// Build CodeContextHint objects from the code knowledge graph.
    /// Maps entry points to their source files and links SAST findings.
    ///
    /// At most 50 entry-point nodes are considered; a failed graph query
    /// silently yields no hints.
    async fn fetch_code_context(
        &self,
        repo_id: &str,
        sast_findings: &[Finding],
    ) -> Vec<CodeContextHint> {
        // Get entry point nodes from the code graph
        let cursor = self
            .db
            .graph_nodes()
            .find(doc! {
                "repo_id": repo_id,
                "is_entry_point": true,
            })
            .limit(50)
            .await;
        let nodes = match cursor {
            Ok(mut c) => {
                let mut results = Vec::new();
                while let Some(Ok(n)) = c.next().await {
                    results.push(n);
                }
                results
            }
            Err(_) => return Vec::new(),
        };
        // Build hints by matching graph nodes to SAST findings by file path
        nodes
            .into_iter()
            .map(|node| {
                // Find SAST findings in the same file
                let linked_vulns: Vec<String> = sast_findings
                    .iter()
                    .filter(|f| f.file_path.as_deref() == Some(&node.file_path))
                    .map(|f| {
                        format!(
                            "[{}] {}: {} (line {})",
                            f.severity,
                            f.scanner,
                            f.title,
                            f.line_number.unwrap_or(0)
                        )
                    })
                    .collect();
                CodeContextHint {
                    endpoint_pattern: node.qualified_name.clone(),
                    handler_function: node.name.clone(),
                    file_path: node.file_path.clone(),
                    code_snippet: String::new(), // Could fetch from embeddings
                    known_vulnerabilities: linked_vulns,
                }
            })
            .collect()
    }
}

View File

@@ -1,4 +1,6 @@
mod context;
pub mod orchestrator;
mod prompt_builder;
pub mod report;
pub use orchestrator::PentestOrchestrator;

View File

@@ -1,31 +1,27 @@
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use mongodb::bson::doc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
use crate::llm::{
ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
/// Maximum duration for a single pentest session before timeout
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
db: Database,
event_tx: broadcast::Sender<PentestEvent>,
pub(crate) tool_registry: ToolRegistry,
pub(crate) llm: Arc<LlmClient>,
pub(crate) db: Database,
pub(crate) event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
@@ -39,10 +35,12 @@ impl PentestOrchestrator {
}
}
#[allow(dead_code)]
pub fn subscribe(&self) -> broadcast::Receiver<PentestEvent> {
self.event_tx.subscribe()
}
#[allow(dead_code)]
pub fn event_sender(&self) -> broadcast::Sender<PentestEvent> {
self.event_tx.clone()
}
@@ -111,18 +109,20 @@ impl PentestOrchestrator {
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default();
// Gather code-awareness context from linked repo
let (sast_findings, sbom_entries, code_context) =
self.gather_repo_context(target).await;
let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await;
// Build system prompt with code context
let system_prompt = self
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
.build_system_prompt(
session,
target,
&sast_findings,
&sbom_entries,
&code_context,
)
.await;
// Build tool definitions for LLM
@@ -182,8 +182,7 @@ impl PentestOrchestrator {
match response {
LlmResponse::Content(content) => {
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let msg = PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
@@ -213,7 +212,10 @@ impl PentestOrchestrator {
}
break;
}
LlmResponse::ToolCalls { calls: tool_calls, reasoning } => {
LlmResponse::ToolCalls {
calls: tool_calls,
reasoning,
} => {
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
@@ -221,15 +223,18 @@ impl PentestOrchestrator {
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) },
content: if reasoning.is_empty() {
None
} else {
Some(reasoning.clone())
},
tool_calls: Some(tc_requests),
tool_call_id: None,
});
@@ -274,24 +279,30 @@ impl PentestOrchestrator {
let insert_result =
self.db.dast_findings().insert_one(&finding).await;
if let Ok(res) = &insert_result {
finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default());
}
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
finding_ids.push(
res.inserted_id
.as_object_id()
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
);
}
let _ = self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Compute risk score based on findings severity
let risk_score: Option<u8> = if findings_count > 0 {
Some(std::cmp::min(
100,
(findings_count as u8).saturating_mul(15).saturating_add(20),
(findings_count as u8)
.saturating_mul(15)
.saturating_add(20),
))
} else {
None
@@ -415,347 +426,4 @@ impl PentestOrchestrator {
Ok(())
}
// ── Code-Awareness: Gather context from linked repo ─────────
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
/// for the repo linked to this DAST target.
async fn gather_repo_context(
&self,
target: &DastTarget,
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
let Some(repo_id) = &target.repo_id else {
return (Vec::new(), Vec::new(), Vec::new());
};
let sast_findings = self.fetch_sast_findings(repo_id).await;
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
tracing::info!(
repo_id,
sast_findings = sast_findings.len(),
vulnerable_deps = sbom_entries.len(),
code_hints = code_context.len(),
"Gathered code-awareness context for pentest"
);
(sast_findings, sbom_entries, code_context)
}
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
let cursor = self
.db
.findings()
.find(doc! {
"repo_id": repo_id,
"status": { "$in": ["open", "triaged"] },
})
.sort(doc! { "severity": -1 })
.limit(100)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(f)) = c.next().await {
results.push(f);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
Vec::new()
}
}
}
/// Fetch SBOM entries that have known vulnerabilities
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
let cursor = self
.db
.sbom_entries()
.find(doc! {
"repo_id": repo_id,
"known_vulnerabilities": { "$exists": true, "$ne": [] },
})
.limit(50)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(e)) = c.next().await {
results.push(e);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
Vec::new()
}
}
}
/// Build CodeContextHint objects from the code knowledge graph.
/// Maps entry points to their source files and links SAST findings.
async fn fetch_code_context(
&self,
repo_id: &str,
sast_findings: &[Finding],
) -> Vec<CodeContextHint> {
// Get entry point nodes from the code graph
let cursor = self
.db
.graph_nodes()
.find(doc! {
"repo_id": repo_id,
"is_entry_point": true,
})
.limit(50)
.await;
let nodes = match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(n)) = c.next().await {
results.push(n);
}
results
}
Err(_) => return Vec::new(),
};
// Build hints by matching graph nodes to SAST findings by file path
nodes
.into_iter()
.map(|node| {
// Find SAST findings in the same file
let linked_vulns: Vec<String> = sast_findings
.iter()
.filter(|f| {
f.file_path.as_deref() == Some(&node.file_path)
})
.map(|f| {
format!(
"[{}] {}: {} (line {})",
f.severity,
f.scanner,
f.title,
f.line_number.unwrap_or(0)
)
})
.collect();
CodeContextHint {
endpoint_pattern: node.qualified_name.clone(),
handler_function: node.name.clone(),
file_path: node.file_path.clone(),
code_snippet: String::new(), // Could fetch from embeddings
known_vulnerabilities: linked_vulns,
}
})
.collect()
}
// ── System Prompt Builder ───────────────────────────────────
async fn build_system_prompt(
&self,
session: &PentestSession,
target: &DastTarget,
sast_findings: &[Finding],
sbom_entries: &[SbomEntry],
code_context: &[CodeContextHint],
) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let strategy_guidance = match session.strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
};
// Build SAST findings section
let sast_section = if sast_findings.is_empty() {
String::from("No SAST findings available for this target.")
} else {
let critical = sast_findings
.iter()
.filter(|f| f.severity == Severity::Critical)
.count();
let high = sast_findings
.iter()
.filter(|f| f.severity == Severity::High)
.count();
let mut section = format!(
"{} open findings ({} critical, {} high):\n",
sast_findings.len(),
critical,
high
);
// List the most important findings (critical/high first, up to 20)
for f in sast_findings.iter().take(20) {
let file_info = f
.file_path
.as_ref()
.map(|p| {
format!(
" in {}:{}",
p,
f.line_number.unwrap_or(0)
)
})
.unwrap_or_default();
let status_note = match f.status {
FindingStatus::Triaged => " [TRIAGED]",
_ => "",
};
section.push_str(&format!(
"- [{sev}] {title}{file}{status}\n",
sev = f.severity,
title = f.title,
file = file_info,
status = status_note,
));
if let Some(cwe) = &f.cwe {
section.push_str(&format!(" CWE: {cwe}\n"));
}
}
if sast_findings.len() > 20 {
section.push_str(&format!(
"... and {} more findings\n",
sast_findings.len() - 20
));
}
section
};
// Build SBOM/CVE section
let sbom_section = if sbom_entries.is_empty() {
String::from("No vulnerable dependencies identified.")
} else {
let mut section = format!(
"{} dependencies with known vulnerabilities:\n",
sbom_entries.len()
);
for entry in sbom_entries.iter().take(15) {
let cve_ids: Vec<&str> = entry
.known_vulnerabilities
.iter()
.map(|v| v.id.as_str())
.collect();
section.push_str(&format!(
"- {} {} ({}): {}\n",
entry.name,
entry.version,
entry.package_manager,
cve_ids.join(", ")
));
}
if sbom_entries.len() > 15 {
section.push_str(&format!(
"... and {} more vulnerable dependencies\n",
sbom_entries.len() - 15
));
}
section
};
// Build code context section
let code_section = if code_context.is_empty() {
String::from("No code knowledge graph available for this target.")
} else {
let with_vulns = code_context
.iter()
.filter(|c| !c.known_vulnerabilities.is_empty())
.count();
let mut section = format!(
"{} entry points identified ({} with linked SAST findings):\n",
code_context.len(),
with_vulns
);
for hint in code_context.iter().take(20) {
section.push_str(&format!(
"- {} ({})\n",
hint.endpoint_pattern, hint.file_path
));
for vuln in &hint.known_vulnerabilities {
section.push_str(&format!(" SAST: {vuln}\n"));
}
}
section
};
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
)
}
}

View File

@@ -0,0 +1,504 @@
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use super::orchestrator::PentestOrchestrator;
/// Return strategy guidance text for the given strategy.
///
/// The returned sentence is inserted verbatim into the pentest system
/// prompt's "## Strategy" section.
fn strategy_guidance(strategy: &PentestStrategy) -> &'static str {
    match strategy {
        PentestStrategy::Quick => {
            "Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
        }
        PentestStrategy::Comprehensive => {
            "Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
        }
        PentestStrategy::Targeted => {
            "Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
        }
        PentestStrategy::Aggressive => {
            "Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
        }
        PentestStrategy::Stealth => {
            "Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
        }
    }
}
/// Build the SAST findings section for the system prompt.
///
/// Emits a header with total/critical/high counts, then up to 20 finding
/// lines (location, triage marker, CWE where available) and a trailer
/// noting how many findings were omitted beyond the cap.
fn build_sast_section(sast_findings: &[Finding]) -> String {
    if sast_findings.is_empty() {
        return String::from("No SAST findings available for this target.");
    }
    // Tally critical/high severities in a single pass.
    let (critical, high) = sast_findings
        .iter()
        .fold((0usize, 0usize), |(crit, hi), f| match f.severity {
            Severity::Critical => (crit + 1, hi),
            Severity::High => (crit, hi + 1),
            _ => (crit, hi),
        });
    let mut out = format!(
        "{} open findings ({} critical, {} high):\n",
        sast_findings.len(),
        critical,
        high
    );
    // Only the first 20 findings are listed to keep the prompt bounded.
    for finding in sast_findings.iter().take(20) {
        // Missing line numbers render as 0, matching the original behavior.
        let location = match &finding.file_path {
            Some(path) => format!(" in {}:{}", path, finding.line_number.unwrap_or(0)),
            None => String::new(),
        };
        let triage_marker = if matches!(finding.status, FindingStatus::Triaged) {
            " [TRIAGED]"
        } else {
            ""
        };
        out.push_str(&format!(
            "- [{sev}] {title}{file}{status}\n",
            sev = finding.severity,
            title = finding.title,
            file = location,
            status = triage_marker,
        ));
        if let Some(cwe) = &finding.cwe {
            out.push_str(&format!(" CWE: {cwe}\n"));
        }
    }
    let omitted = sast_findings.len().saturating_sub(20);
    if omitted > 0 {
        out.push_str(&format!("... and {} more findings\n", omitted));
    }
    out
}
/// Build the SBOM/CVE section for the system prompt.
///
/// Lists up to 15 vulnerable dependencies with their advisory identifiers
/// and notes how many entries were omitted beyond that cap.
fn build_sbom_section(sbom_entries: &[SbomEntry]) -> String {
    if sbom_entries.is_empty() {
        return String::from("No vulnerable dependencies identified.");
    }
    let total = sbom_entries.len();
    let mut out = format!("{} dependencies with known vulnerabilities:\n", total);
    for entry in sbom_entries.iter().take(15) {
        // Comma-join all CVE/advisory ids attached to this package.
        let cves = entry
            .known_vulnerabilities
            .iter()
            .map(|v| v.id.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        out.push_str(&format!(
            "- {} {} ({}): {}\n",
            entry.name, entry.version, entry.package_manager, cves
        ));
    }
    if total > 15 {
        out.push_str(&format!(
            "... and {} more vulnerable dependencies\n",
            total - 15
        ));
    }
    out
}
/// Build the code context section for the system prompt.
///
/// Summarizes knowledge-graph entry points (up to 20), annotating each
/// with any SAST findings linked to that endpoint.
fn build_code_section(code_context: &[CodeContextHint]) -> String {
    if code_context.is_empty() {
        return String::from("No code knowledge graph available for this target.");
    }
    // How many entry points carry at least one linked SAST finding.
    let linked = code_context
        .iter()
        .filter(|hint| !hint.known_vulnerabilities.is_empty())
        .count();
    let mut out = format!(
        "{} entry points identified ({} with linked SAST findings):\n",
        code_context.len(),
        linked
    );
    for hint in code_context.iter().take(20) {
        out.push_str(&format!("- {} ({})\n", hint.endpoint_pattern, hint.file_path));
        for vuln in &hint.known_vulnerabilities {
            out.push_str(&format!(" SAST: {vuln}\n"));
        }
    }
    out
}
impl PentestOrchestrator {
    /// Assemble the full LLM system prompt for a pentest session.
    ///
    /// Combines target metadata, strategy guidance, the SAST/SBOM/code-graph
    /// context sections, and the registered tool list into a single markdown
    /// prompt. The raw-string template below is emitted verbatim (with the
    /// named placeholders substituted), so its exact text is load-bearing.
    // NOTE(review): no `.await` appears in this body — `async` seems to be
    // kept for signature compatibility with callers; confirm before changing.
    pub(crate) async fn build_system_prompt(
        &self,
        session: &PentestSession,
        target: &DastTarget,
        sast_findings: &[Finding],
        sbom_entries: &[SbomEntry],
        code_context: &[CodeContextHint],
    ) -> String {
        // Comma-separated list of every registered tool, surfaced verbatim.
        let tool_names = self.tool_registry.list_names().join(", ");
        let guidance = strategy_guidance(&session.strategy);
        // Context sections; each helper handles its own empty-input fallback.
        let sast_section = build_sast_section(sast_findings);
        let sbom_section = build_sbom_section(sbom_entries);
        let code_section = build_code_section(code_context);
        format!(
            r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
            target_name = target.name,
            base_url = target.base_url,
            target_type = target.target_type,
            rate_limit = target.rate_limit,
            allow_destructive = target.allow_destructive,
            repo_linked = target.repo_id.as_deref().unwrap_or("None"),
            strategy_guidance = guidance,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::models::finding::Severity;
    use compliance_core::models::sbom::VulnRef;
    use compliance_core::models::scan::ScanType;

    /// Build a minimal SAST finding populating exactly the fields the prompt
    /// builders read: severity, title, location, triage status, and CWE.
    fn make_finding(
        severity: Severity,
        title: &str,
        file_path: Option<&str>,
        line: Option<u32>,
        status: FindingStatus,
        cwe: Option<&str>,
    ) -> Finding {
        let mut f = Finding::new(
            "repo-1".into(),
            format!("fp-{title}"),
            "semgrep".into(),
            ScanType::Sast,
            title.into(),
            "desc".into(),
            severity,
        );
        f.file_path = file_path.map(|s| s.to_string());
        f.line_number = line;
        f.status = status;
        f.cwe = cwe.map(|s| s.to_string());
        f
    }

    /// Build an npm SBOM entry carrying the given advisory identifiers.
    fn make_sbom_entry(name: &str, version: &str, cves: &[&str]) -> SbomEntry {
        let mut entry = SbomEntry::new("repo-1".into(), name.into(), version.into(), "npm".into());
        entry.known_vulnerabilities = cves
            .iter()
            .map(|id| VulnRef {
                id: id.to_string(),
                source: "nvd".into(),
                severity: None,
                url: None,
            })
            .collect();
        entry
    }

    /// Build a knowledge-graph hint for one endpoint/file pair.
    fn make_code_hint(endpoint: &str, file: &str, vulns: Vec<String>) -> CodeContextHint {
        CodeContextHint {
            endpoint_pattern: endpoint.into(),
            handler_function: "handler".into(),
            file_path: file.into(),
            code_snippet: String::new(),
            known_vulnerabilities: vulns,
        }
    }

    // ── strategy_guidance ────────────────────────────────────────────
    #[test]
    fn strategy_guidance_quick() {
        let g = strategy_guidance(&PentestStrategy::Quick);
        assert!(g.contains("most common"));
        assert!(g.contains("quick recon"));
    }
    #[test]
    fn strategy_guidance_comprehensive() {
        let g = strategy_guidance(&PentestStrategy::Comprehensive);
        assert!(g.contains("thorough assessment"));
    }
    #[test]
    fn strategy_guidance_targeted() {
        let g = strategy_guidance(&PentestStrategy::Targeted);
        assert!(g.contains("SAST findings"));
        assert!(g.contains("known CVEs"));
    }
    #[test]
    fn strategy_guidance_aggressive() {
        let g = strategy_guidance(&PentestStrategy::Aggressive);
        assert!(g.contains("aggressively"));
        assert!(g.contains("full exploitation"));
    }
    #[test]
    fn strategy_guidance_stealth() {
        let g = strategy_guidance(&PentestStrategy::Stealth);
        assert!(g.contains("Minimize noise"));
        assert!(g.contains("passive analysis"));
    }

    // ── build_sast_section ───────────────────────────────────────────
    #[test]
    fn sast_section_empty() {
        let section = build_sast_section(&[]);
        assert_eq!(section, "No SAST findings available for this target.");
    }
    #[test]
    fn sast_section_single_critical() {
        let findings = vec![make_finding(
            Severity::Critical,
            "SQL Injection",
            Some("src/db.rs"),
            Some(42),
            FindingStatus::Open,
            Some("CWE-89"),
        )];
        let section = build_sast_section(&findings);
        assert!(section.contains("1 open findings (1 critical, 0 high)"));
        assert!(section.contains("[critical] SQL Injection in src/db.rs:42"));
        assert!(section.contains("CWE: CWE-89"));
    }
    #[test]
    fn sast_section_triaged_finding_shows_marker() {
        let findings = vec![make_finding(
            Severity::High,
            "XSS",
            None,
            None,
            FindingStatus::Triaged,
            None,
        )];
        let section = build_sast_section(&findings);
        assert!(section.contains("[TRIAGED]"));
    }
    #[test]
    fn sast_section_no_file_path_omits_location() {
        let findings = vec![make_finding(
            Severity::Medium,
            "Open Redirect",
            None,
            None,
            FindingStatus::Open,
            None,
        )];
        let section = build_sast_section(&findings);
        assert!(section.contains("- [medium] Open Redirect\n"));
        // No file path ⇒ the " in file:line" suffix must be absent entirely.
        assert!(!section.contains(" in "));
    }
    #[test]
    fn sast_section_counts_critical_and_high() {
        let findings = vec![
            make_finding(
                Severity::Critical,
                "F1",
                None,
                None,
                FindingStatus::Open,
                None,
            ),
            make_finding(
                Severity::Critical,
                "F2",
                None,
                None,
                FindingStatus::Open,
                None,
            ),
            make_finding(Severity::High, "F3", None, None, FindingStatus::Open, None),
            make_finding(
                Severity::Medium,
                "F4",
                None,
                None,
                FindingStatus::Open,
                None,
            ),
        ];
        let section = build_sast_section(&findings);
        assert!(section.contains("4 open findings (2 critical, 1 high)"));
    }
    #[test]
    fn sast_section_truncates_at_20() {
        let findings: Vec<Finding> = (0..25)
            .map(|i| {
                make_finding(
                    Severity::Low,
                    &format!("Finding {i}"),
                    None,
                    None,
                    FindingStatus::Open,
                    None,
                )
            })
            .collect();
        let section = build_sast_section(&findings);
        assert!(section.contains("... and 5 more findings"));
        // Should contain Finding 19 (the 20th) but not Finding 20 (the 21st)
        assert!(section.contains("Finding 19"));
        assert!(!section.contains("Finding 20"));
    }

    // ── build_sbom_section ───────────────────────────────────────────
    #[test]
    fn sbom_section_empty() {
        let section = build_sbom_section(&[]);
        assert_eq!(section, "No vulnerable dependencies identified.");
    }
    #[test]
    fn sbom_section_single_entry() {
        let entries = vec![make_sbom_entry("lodash", "4.17.20", &["CVE-2021-23337"])];
        let section = build_sbom_section(&entries);
        assert!(section.contains("1 dependencies with known vulnerabilities"));
        assert!(section.contains("- lodash 4.17.20 (npm): CVE-2021-23337"));
    }
    #[test]
    fn sbom_section_multiple_cves() {
        let entries = vec![make_sbom_entry(
            "openssl",
            "1.1.1",
            &["CVE-2022-0001", "CVE-2022-0002"],
        )];
        let section = build_sbom_section(&entries);
        // Multiple advisories for one package are comma-joined.
        assert!(section.contains("CVE-2022-0001, CVE-2022-0002"));
    }
    #[test]
    fn sbom_section_truncates_at_15() {
        let entries: Vec<SbomEntry> = (0..18)
            .map(|i| make_sbom_entry(&format!("pkg-{i}"), "1.0.0", &["CVE-2024-0001"]))
            .collect();
        let section = build_sbom_section(&entries);
        assert!(section.contains("... and 3 more vulnerable dependencies"));
        assert!(section.contains("pkg-14"));
        assert!(!section.contains("pkg-15"));
    }

    // ── build_code_section ───────────────────────────────────────────
    #[test]
    fn code_section_empty() {
        let section = build_code_section(&[]);
        assert_eq!(
            section,
            "No code knowledge graph available for this target."
        );
    }
    #[test]
    fn code_section_single_entry_no_vulns() {
        let hints = vec![make_code_hint("GET /api/users", "src/routes.rs", vec![])];
        let section = build_code_section(&hints);
        assert!(section.contains("1 entry points identified (0 with linked SAST findings)"));
        assert!(section.contains("- GET /api/users (src/routes.rs)"));
    }
    #[test]
    fn code_section_with_linked_vulns() {
        let hints = vec![make_code_hint(
            "POST /login",
            "src/auth.rs",
            vec!["[critical] semgrep: SQL Injection (line 15)".into()],
        )];
        let section = build_code_section(&hints);
        assert!(section.contains("1 entry points identified (1 with linked SAST findings)"));
        assert!(section.contains("SAST: [critical] semgrep: SQL Injection (line 15)"));
    }
    #[test]
    fn code_section_counts_entries_with_vulns() {
        let hints = vec![
            make_code_hint("GET /a", "a.rs", vec!["vuln1".into()]),
            make_code_hint("GET /b", "b.rs", vec![]),
            make_code_hint("GET /c", "c.rs", vec!["vuln2".into(), "vuln3".into()]),
        ];
        let section = build_code_section(&hints);
        assert!(section.contains("3 entry points identified (2 with linked SAST findings)"));
    }
    #[test]
    fn code_section_truncates_at_20() {
        let hints: Vec<CodeContextHint> = (0..25)
            .map(|i| make_code_hint(&format!("GET /ep{i}"), &format!("f{i}.rs"), vec![]))
            .collect();
        let section = build_code_section(&hints);
        assert!(section.contains("GET /ep19"));
        assert!(!section.contains("GET /ep20"));
    }
}

View File

@@ -0,0 +1,43 @@
use std::io::{Cursor, Write};
use zip::write::SimpleFileOptions;
use zip::AesMode;
use super::ReportContext;
/// Assemble the in-memory report archive: four entries, each Deflate-compressed
/// and AES-256 encrypted (WinZip AES format) with the supplied password.
pub(super) fn build_zip(
    ctx: &ReportContext,
    password: &str,
    html: &str,
    pdf: &[u8],
) -> Result<Vec<u8>, zip::result::ZipError> {
    let mut zip = zip::ZipWriter::new(Cursor::new(Vec::new()));
    // Shared per-entry options: compression + password-derived AES encryption.
    let options = SimpleFileOptions::default()
        .compression_method(zip::CompressionMethod::Deflated)
        .with_aes_encryption(AesMode::Aes256, password);

    // Serialization failures degrade to "[]" rather than aborting the archive.
    let findings_json =
        serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
    let chain_json =
        serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());

    // Entries written in the same order as before: PDF (primary), HTML
    // (fallback), then the raw JSON artifacts.
    let entries: [(&str, &[u8]); 4] = [
        ("report.pdf", pdf),
        ("report.html", html.as_bytes()),
        ("findings.json", findings_json.as_bytes()),
        ("attack-chain.json", chain_json.as_bytes()),
    ];
    for (name, bytes) in entries {
        zip.start_file(name, options)?;
        zip.write_all(bytes)?;
    }
    Ok(zip.finish()?.into_inner())
}

View File

@@ -1,193 +1,50 @@
use std::io::{Cursor, Write};
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
use sha2::{Digest, Sha256};
use zip::write::SimpleFileOptions;
use zip::AesMode;
use compliance_core::models::pentest::AttackChainNode;
/// Report archive with metadata
///
/// `sha256` is computed over the final `archive` bytes, so callers can
/// record and later verify the stored ZIP.
pub struct ReportArchive {
    /// The password-protected ZIP bytes
    pub archive: Vec<u8>,
    /// SHA-256 hex digest of the archive
    pub sha256: String,
}
use super::ReportContext;
/// Report context gathered from the database
///
/// All inputs the report generator needs, pre-fetched so the rendering and
/// archiving steps do no further database access.
pub struct ReportContext {
    /// The pentest session being reported on.
    pub session: PentestSession,
    /// Display name of the scanned target.
    pub target_name: String,
    /// Base URL of the scanned target.
    pub target_url: String,
    /// DAST findings; serialized verbatim into `findings.json`.
    pub findings: Vec<DastFinding>,
    /// Attack-chain timeline; serialized into `attack-chain.json`.
    pub attack_chain: Vec<AttackChainNode>,
    // NOTE(review): requester fields are presumably rendered into the report
    // header/cover — confirm against build_html_report.
    pub requester_name: String,
    pub requester_email: String,
}
/// Generate a password-protected ZIP archive containing the pentest report.
///
/// The archive contains:
/// - `report.pdf` — Professional pentest report (PDF)
/// - `report.html` — HTML source (fallback)
/// - `findings.json` — Raw findings data
/// - `attack-chain.json` — Attack chain timeline
///
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
///
/// # Errors
/// Returns a human-readable message when PDF conversion or ZIP assembly fails.
pub async fn generate_encrypted_report(
    ctx: &ReportContext,
    password: &str,
) -> Result<ReportArchive, String> {
    // Render the HTML once; it is both converted to PDF and shipped as-is.
    let html = build_html_report(ctx);
    // Convert HTML to PDF via headless Chrome
    let pdf_bytes = html_to_pdf(&html).await?;
    // Pack all artifacts into the encrypted ZIP.
    let zip_bytes = build_zip(ctx, password, &html, &pdf_bytes)
        .map_err(|e| format!("Failed to create archive: {e}"))?;
    // Digest the final archive bytes for integrity recording.
    let mut hasher = Sha256::new();
    hasher.update(&zip_bytes);
    let sha256 = hex::encode(hasher.finalize());
    Ok(ReportArchive { archive: zip_bytes, sha256 })
}
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
///
/// Writes the HTML to a temp file, invokes Chrome in headless
/// print-to-PDF mode, and reads the result back.
///
/// # Errors
/// Returns a human-readable message when Chrome is missing, fails to
/// launch, exits non-zero, or produces an empty/unreadable PDF.
async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
    let tmp_dir = std::env::temp_dir();
    let run_id = uuid::Uuid::new_v4().to_string();
    let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
    let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
    // Write HTML to temp file
    std::fs::write(&html_path, html)
        .map_err(|e| format!("Failed to write temp HTML: {e}"))?;
    // Run the conversion, then remove both temp files on EVERY path.
    // (Previously they leaked when the Chrome lookup or the PDF read failed.)
    let result = run_headless_chrome(&html_path, &pdf_path).await;
    let _ = std::fs::remove_file(&html_path);
    let _ = std::fs::remove_file(&pdf_path);
    let pdf_bytes = result?;
    if pdf_bytes.is_empty() {
        return Err("Chrome produced an empty PDF".to_string());
    }
    tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
    Ok(pdf_bytes)
}

/// Locate Chrome, print `html_path` to `pdf_path`, and read back the bytes.
/// Temp-file cleanup is the caller's responsibility.
async fn run_headless_chrome(
    html_path: &std::path::Path,
    pdf_path: &std::path::Path,
) -> Result<Vec<u8>, String> {
    // Find Chrome/Chromium binary
    let chrome_bin = find_chrome_binary()
        .ok_or_else(|| "Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports.".to_string())?;
    tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
    let html_url = format!("file://{}", html_path.display());
    let output = tokio::process::Command::new(&chrome_bin)
        .args([
            "--headless",
            "--disable-gpu",
            "--no-sandbox",
            "--disable-software-rasterizer",
            "--run-all-compositor-stages-before-draw",
            "--disable-dev-shm-usage",
            &format!("--print-to-pdf={}", pdf_path.display()),
            "--no-pdf-header-footer",
            &html_url,
        ])
        .output()
        .await
        .map_err(|e| format!("Failed to run Chrome: {e}"))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(format!("Chrome PDF generation failed: {stderr}"));
    }
    std::fs::read(pdf_path).map_err(|e| format!("Failed to read generated PDF: {e}"))
}
/// Search for Chrome/Chromium binary on the system.
///
/// Probes a fixed list of candidate names with `which`, returning the first
/// resolved absolute path, or `None` if no candidate resolves.
fn find_chrome_binary() -> Option<String> {
    [
        "google-chrome-stable",
        "google-chrome",
        "chromium-browser",
        "chromium",
    ]
    .iter()
    .find_map(|candidate| {
        let output = std::process::Command::new("which")
            .arg(candidate)
            .output()
            .ok()?;
        if !output.status.success() {
            return None;
        }
        let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
        if path.is_empty() {
            None
        } else {
            Some(path)
        }
    })
}
/// Write the four report artifacts into an in-memory ZIP whose entries are
/// Deflate-compressed and AES-256 encrypted with `password`.
fn build_zip(
    ctx: &ReportContext,
    password: &str,
    html: &str,
    pdf: &[u8],
) -> Result<Vec<u8>, zip::result::ZipError> {
    let buf = Cursor::new(Vec::new());
    let mut zip = zip::ZipWriter::new(buf);
    // One shared options value: compression + WinZip-AES entry encryption
    // derived from the supplied password.
    let options = SimpleFileOptions::default()
        .compression_method(zip::CompressionMethod::Deflated)
        .with_aes_encryption(AesMode::Aes256, password);
    // report.pdf (primary)
    zip.start_file("report.pdf", options.clone())?;
    zip.write_all(pdf)?;
    // report.html (fallback)
    zip.start_file("report.html", options.clone())?;
    zip.write_all(html.as_bytes())?;
    // findings.json
    // Serialization failure degrades to "[]" instead of aborting the archive.
    let findings_json =
        serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
    zip.start_file("findings.json", options.clone())?;
    zip.write_all(findings_json.as_bytes())?;
    // attack-chain.json
    let chain_json =
        serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());
    zip.start_file("attack-chain.json", options)?;
    zip.write_all(chain_json.as_bytes())?;
    let cursor = zip.finish()?;
    Ok(cursor.into_inner())
}
fn build_html_report(ctx: &ReportContext) -> String {
#[allow(clippy::format_in_format_args)]
pub(super) fn build_html_report(ctx: &ReportContext) -> String {
let session = &ctx.session;
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "-".to_string());
let date_str = session.started_at.format("%B %d, %Y at %H:%M UTC").to_string();
let date_str = session
.started_at
.format("%B %d, %Y at %H:%M UTC")
.to_string();
let date_short = session.started_at.format("%B %d, %Y").to_string();
let completed_str = session
.completed_at
.map(|d| d.format("%B %d, %Y at %H:%M UTC").to_string())
.unwrap_or_else(|| "In Progress".to_string());
let critical = ctx.findings.iter().filter(|f| f.severity.to_string() == "critical").count();
let high = ctx.findings.iter().filter(|f| f.severity.to_string() == "high").count();
let medium = ctx.findings.iter().filter(|f| f.severity.to_string() == "medium").count();
let low = ctx.findings.iter().filter(|f| f.severity.to_string() == "low").count();
let info = ctx.findings.iter().filter(|f| f.severity.to_string() == "info").count();
let critical = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "critical")
.count();
let high = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "high")
.count();
let medium = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "medium")
.count();
let low = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "low")
.count();
let info = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "info")
.count();
let exploitable = ctx.findings.iter().filter(|f| f.exploitable).count();
let total = ctx.findings.len();
@@ -212,10 +69,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
};
// Risk score 0-100
let risk_score: usize = std::cmp::min(
100,
critical * 25 + high * 15 + medium * 8 + low * 3 + info * 1,
);
let risk_score: usize =
std::cmp::min(100, critical * 25 + high * 15 + medium * 8 + low * 3 + info);
// Collect unique tool names used
let tool_names: Vec<String> = {
@@ -247,7 +102,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
if high > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-high" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(high_pct, 4), high
std::cmp::max(high_pct, 4),
high
));
}
if medium > 0 {
@@ -259,22 +115,38 @@ fn build_html_report(ctx: &ReportContext) -> String {
if low > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-low" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(low_pct, 4), low
std::cmp::max(low_pct, 4),
low
));
}
if info > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-info" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(info_pct, 4), info
std::cmp::max(info_pct, 4),
info
));
}
bar.push_str("</div>");
bar.push_str(r#"<div class="sev-bar-legend">"#);
if critical > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#); }
if high > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#); }
if medium > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#); }
if low > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#); }
if info > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#); }
if critical > 0 {
bar.push_str(
r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#,
);
}
if high > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#);
}
if medium > 0 {
bar.push_str(
r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#,
);
}
if low > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#);
}
if info > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#);
}
bar.push_str("</div>");
bar
} else {
@@ -322,7 +194,12 @@ fn build_html_report(ctx: &ReportContext) -> String {
let param_row = f
.parameter
.as_deref()
.map(|p| format!("<tr><td>Parameter</td><td><code>{}</code></td></tr>", html_escape(p)))
.map(|p| {
format!(
"<tr><td>Parameter</td><td><code>{}</code></td></tr>",
html_escape(p)
)
})
.unwrap_or_default();
let remediation = f
.remediation
@@ -332,7 +209,9 @@ fn build_html_report(ctx: &ReportContext) -> String {
let evidence_html = if f.evidence.is_empty() {
String::new()
} else {
let mut eh = String::from(r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#);
let mut eh = String::from(
r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#,
);
for ev in &f.evidence {
let payload_info = ev
.payload
@@ -346,7 +225,7 @@ fn build_html_report(ctx: &ReportContext) -> String {
ev.response_status,
ev.response_snippet
.as_deref()
.map(|s| html_escape(s))
.map(html_escape)
.unwrap_or_default(),
payload_info,
));
@@ -402,7 +281,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
let mut chain_html = String::new();
if !ctx.attack_chain.is_empty() {
// Compute phases via BFS from root nodes
let mut phase_map: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut phase_map: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
let mut queue: std::collections::VecDeque<String> = std::collections::VecDeque::new();
for node in &ctx.attack_chain {
@@ -438,7 +318,13 @@ fn build_html_report(ctx: &ReportContext) -> String {
// Group nodes by phase
let max_phase = phase_map.values().copied().max().unwrap_or(0);
let phase_labels = ["Reconnaissance", "Enumeration", "Exploitation", "Validation", "Post-Exploitation"];
let phase_labels = [
"Reconnaissance",
"Enumeration",
"Exploitation",
"Validation",
"Post-Exploitation",
];
for phase_idx in 0..=max_phase {
let phase_nodes: Vec<&AttackChainNode> = ctx
@@ -485,15 +371,28 @@ fn build_html_report(ctx: &ReportContext) -> String {
format!(
r#"<span class="step-findings">{} finding{}</span>"#,
node.findings_produced.len(),
if node.findings_produced.len() == 1 { "" } else { "s" },
if node.findings_produced.len() == 1 {
""
} else {
"s"
},
)
} else {
String::new()
};
let risk_badge = node.risk_score.map(|r| {
let risk_class = if r >= 70 { "risk-high" } else if r >= 40 { "risk-med" } else { "risk-low" };
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
}).unwrap_or_default();
let risk_badge = node
.risk_score
.map(|r| {
let risk_class = if r >= 70 {
"risk-high"
} else if r >= 40 {
"risk-med"
} else {
"risk-low"
};
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
})
.unwrap_or_default();
let reasoning_html = if node.llm_reasoning.is_empty() {
String::new()
@@ -547,10 +446,20 @@ fn build_html_report(ctx: &ReportContext) -> String {
let toc_findings_sub = if !ctx.findings.is_empty() {
let mut sub = String::new();
let mut fnum = 0usize;
for (si, &sev_key) in severity_order.iter().enumerate() {
let count = ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key).count();
if count == 0 { continue; }
for f in ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key) {
for &sev_key in severity_order.iter() {
let count = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == sev_key)
.count();
if count == 0 {
continue;
}
for f in ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == sev_key)
{
fnum += 1;
sub.push_str(&format!(
r#"<div class="toc-sub">F-{:03} — {}</div>"#,
@@ -1577,19 +1486,49 @@ table.tools-table td:first-child {{
fn tool_category(tool_name: &str) -> &'static str {
let name = tool_name.to_lowercase();
if name.contains("nmap") || name.contains("port") { return "Network Reconnaissance"; }
if name.contains("nikto") || name.contains("header") { return "Web Server Analysis"; }
if name.contains("zap") || name.contains("spider") || name.contains("crawl") { return "Web Application Scanning"; }
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { return "SQL Injection Testing"; }
if name.contains("xss") || name.contains("cross-site") { return "Cross-Site Scripting Testing"; }
if name.contains("dir") || name.contains("brute") || name.contains("fuzz") || name.contains("gobuster") { return "Directory Enumeration"; }
if name.contains("ssl") || name.contains("tls") || name.contains("cert") { return "SSL/TLS Analysis"; }
if name.contains("api") || name.contains("endpoint") { return "API Security Testing"; }
if name.contains("auth") || name.contains("login") || name.contains("credential") { return "Authentication Testing"; }
if name.contains("cors") { return "CORS Testing"; }
if name.contains("csrf") { return "CSRF Testing"; }
if name.contains("nuclei") || name.contains("template") { return "Vulnerability Scanning"; }
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { return "Technology Fingerprinting"; }
if name.contains("nmap") || name.contains("port") {
return "Network Reconnaissance";
}
if name.contains("nikto") || name.contains("header") {
return "Web Server Analysis";
}
if name.contains("zap") || name.contains("spider") || name.contains("crawl") {
return "Web Application Scanning";
}
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") {
return "SQL Injection Testing";
}
if name.contains("xss") || name.contains("cross-site") {
return "Cross-Site Scripting Testing";
}
if name.contains("dir")
|| name.contains("brute")
|| name.contains("fuzz")
|| name.contains("gobuster")
{
return "Directory Enumeration";
}
if name.contains("ssl") || name.contains("tls") || name.contains("cert") {
return "SSL/TLS Analysis";
}
if name.contains("api") || name.contains("endpoint") {
return "API Security Testing";
}
if name.contains("auth") || name.contains("login") || name.contains("credential") {
return "Authentication Testing";
}
if name.contains("cors") {
return "CORS Testing";
}
if name.contains("csrf") {
return "CSRF Testing";
}
if name.contains("nuclei") || name.contains("template") {
return "Vulnerability Scanning";
}
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") {
return "Technology Fingerprinting";
}
"Security Testing"
}
@@ -1599,3 +1538,314 @@ fn html_escape(s: &str) -> String {
.replace('>', "&gt;")
.replace('"', "&quot;")
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::dast::{DastFinding, DastVulnType};
use compliance_core::models::finding::Severity;
use compliance_core::models::pentest::{
AttackChainNode, AttackNodeStatus, PentestSession, PentestStrategy,
};
// ── html_escape ──────────────────────────────────────────────────
#[test]
fn html_escape_handles_ampersand() {
assert_eq!(html_escape("a & b"), "a &amp; b");
}
#[test]
fn html_escape_handles_angle_brackets() {
assert_eq!(html_escape("<script>"), "&lt;script&gt;");
}
#[test]
fn html_escape_handles_quotes() {
assert_eq!(html_escape(r#"key="val""#), "key=&quot;val&quot;");
}
#[test]
fn html_escape_handles_all_special_chars() {
assert_eq!(
html_escape(r#"<a href="x">&y</a>"#),
"&lt;a href=&quot;x&quot;&gt;&amp;y&lt;/a&gt;"
);
}
#[test]
fn html_escape_no_change_for_plain_text() {
assert_eq!(html_escape("hello world"), "hello world");
}
#[test]
fn html_escape_empty_string() {
assert_eq!(html_escape(""), "");
}
// ── tool_category ────────────────────────────────────────────────
#[test]
fn tool_category_nmap() {
assert_eq!(tool_category("nmap_scan"), "Network Reconnaissance");
}
#[test]
fn tool_category_port_scanner() {
assert_eq!(tool_category("port_scanner"), "Network Reconnaissance");
}
#[test]
fn tool_category_nikto() {
assert_eq!(tool_category("nikto"), "Web Server Analysis");
}
#[test]
fn tool_category_header_check() {
assert_eq!(
tool_category("security_header_check"),
"Web Server Analysis"
);
}
#[test]
fn tool_category_zap_spider() {
assert_eq!(tool_category("zap_spider"), "Web Application Scanning");
}
#[test]
fn tool_category_sqlmap() {
assert_eq!(tool_category("sqlmap"), "SQL Injection Testing");
}
#[test]
fn tool_category_xss_scanner() {
assert_eq!(tool_category("xss_scanner"), "Cross-Site Scripting Testing");
}
#[test]
fn tool_category_dir_bruteforce() {
assert_eq!(tool_category("dir_bruteforce"), "Directory Enumeration");
}
#[test]
fn tool_category_gobuster() {
assert_eq!(tool_category("gobuster"), "Directory Enumeration");
}
#[test]
fn tool_category_ssl_check() {
assert_eq!(tool_category("ssl_check"), "SSL/TLS Analysis");
}
#[test]
fn tool_category_tls_scan() {
assert_eq!(tool_category("tls_scan"), "SSL/TLS Analysis");
}
#[test]
fn tool_category_api_test() {
assert_eq!(tool_category("api_endpoint_test"), "API Security Testing");
}
#[test]
fn tool_category_auth_bypass() {
assert_eq!(tool_category("auth_bypass_check"), "Authentication Testing");
}
#[test]
fn tool_category_cors() {
assert_eq!(tool_category("cors_check"), "CORS Testing");
}
#[test]
fn tool_category_csrf() {
assert_eq!(tool_category("csrf_scanner"), "CSRF Testing");
}
#[test]
fn tool_category_nuclei() {
assert_eq!(tool_category("nuclei"), "Vulnerability Scanning");
}
#[test]
fn tool_category_whatweb() {
assert_eq!(tool_category("whatweb"), "Technology Fingerprinting");
}
#[test]
fn tool_category_unknown_defaults_to_security_testing() {
assert_eq!(tool_category("custom_tool"), "Security Testing");
}
#[test]
fn tool_category_is_case_insensitive() {
assert_eq!(tool_category("NMAP_Scanner"), "Network Reconnaissance");
assert_eq!(tool_category("SQLMap"), "SQL Injection Testing");
}
// ── build_html_report ────────────────────────────────────────────
/// Session fixture with fixed, non-zero stats (5 invocations, 4 successes,
/// 2 findings, 1 exploitable) so the report has numbers to render.
fn make_session(strategy: PentestStrategy) -> PentestSession {
    let mut s = PentestSession::new("target-1".into(), strategy);
    s.tool_invocations = 5;
    s.tool_successes = 4;
    s.findings_count = 2;
    s.exploitable_count = 1;
    s
}
/// Finding fixture: an XSS finding against `https://example.com/test` (GET),
/// with caller-chosen severity, title, and exploitability flag.
fn make_finding(severity: Severity, title: &str, exploitable: bool) -> DastFinding {
    let mut f = DastFinding::new(
        "run-1".into(),
        "target-1".into(),
        DastVulnType::Xss,
        title.into(),
        "description".into(),
        severity,
        "https://example.com/test".into(),
        "GET".into(),
    );
    f.exploitable = exploitable;
    f
}
/// Attack-chain node fixture for the given tool, already marked Completed.
fn make_attack_node(tool_name: &str) -> AttackChainNode {
    let mut node = AttackChainNode::new(
        "session-1".into(),
        "node-1".into(),
        tool_name.into(),
        serde_json::json!({}),
        "Testing this tool".into(),
    );
    node.status = AttackNodeStatus::Completed;
    node
}
/// Shared report context: "Test App" at https://example.com, requested by
/// Alice, built with the Comprehensive strategy session fixture.
fn make_report_context(
    findings: Vec<DastFinding>,
    chain: Vec<AttackChainNode>,
) -> ReportContext {
    ReportContext {
        session: make_session(PentestStrategy::Comprehensive),
        target_name: "Test App".into(),
        target_url: "https://example.com".into(),
        findings,
        attack_chain: chain,
        requester_name: "Alice".into(),
        requester_email: "alice@example.com".into(),
    }
}
#[test]
fn report_contains_target_info() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Test App"));
assert!(html.contains("https://example.com"));
}
#[test]
fn report_contains_requester_info() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Alice"));
assert!(html.contains("alice@example.com"));
}
#[test]
fn report_shows_informational_risk_when_no_findings() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("INFORMATIONAL"));
}
#[test]
fn report_shows_critical_risk_with_critical_finding() {
let findings = vec![make_finding(Severity::Critical, "Critical XSS", true)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("CRITICAL"));
}
#[test]
fn report_shows_high_risk_without_critical() {
let findings = vec![make_finding(Severity::High, "High SQLi", false)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// Should show HIGH, not CRITICAL
assert!(html.contains("HIGH"));
}
#[test]
fn report_shows_medium_risk_level() {
let findings = vec![make_finding(Severity::Medium, "Medium Issue", false)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("MEDIUM"));
}
#[test]
fn report_includes_finding_title() {
let findings = vec![make_finding(
Severity::High,
"Reflected XSS in /search",
true,
)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Reflected XSS in /search"));
}
#[test]
fn report_shows_exploitable_badge() {
let findings = vec![make_finding(Severity::Critical, "SQLi", true)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// The report should mark exploitable findings
assert!(html.contains("EXPLOITABLE"));
}
#[test]
fn report_includes_attack_chain_tool_names() {
let chain = vec![make_attack_node("nmap_scan"), make_attack_node("sqlmap")];
let ctx = make_report_context(vec![], chain);
let html = build_html_report(&ctx);
assert!(html.contains("nmap_scan"));
assert!(html.contains("sqlmap"));
}
#[test]
fn report_is_valid_html_structure() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("<!DOCTYPE html>") || html.contains("<html"));
assert!(html.contains("</html>"));
}
#[test]
fn report_strategy_appears() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
// PentestStrategy::Comprehensive => "comprehensive"
assert!(html.contains("comprehensive") || html.contains("Comprehensive"));
}
#[test]
fn report_finding_count_is_correct() {
let findings = vec![
make_finding(Severity::Critical, "F1", true),
make_finding(Severity::High, "F2", false),
make_finding(Severity::Low, "F3", false),
];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// The total count "3" should appear somewhere
assert!(
html.contains(">3<")
|| html.contains(">3 ")
|| html.contains("3 findings")
|| html.contains("3 Total")
);
}
}

View File

@@ -0,0 +1,58 @@
mod archive;
mod html;
mod pdf;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
use sha2::{Digest, Sha256};
/// Report archive with metadata.
///
/// Produced by [`generate_encrypted_report`]; the digest lets callers verify
/// integrity of the archive bytes.
pub struct ReportArchive {
    /// The password-protected ZIP bytes
    pub archive: Vec<u8>,
    /// SHA-256 hex digest of the archive
    pub sha256: String,
}
/// Report context gathered from the database
pub struct ReportContext {
    /// The pentest session the report describes.
    pub session: PentestSession,
    /// Human-readable name of the target application.
    pub target_name: String,
    /// Base URL that was tested.
    pub target_url: String,
    /// DAST findings to include in the report.
    pub findings: Vec<DastFinding>,
    /// Attack-chain timeline nodes for the session.
    pub attack_chain: Vec<AttackChainNode>,
    /// Name of the person who requested the pentest.
    pub requester_name: String,
    /// Email of the person who requested the pentest.
    pub requester_email: String,
}
/// Generate a password-protected ZIP archive containing the pentest report.
///
/// The archive contains:
/// - `report.pdf` — Professional pentest report (PDF)
/// - `report.html` — HTML source (fallback)
/// - `findings.json` — Raw findings data
/// - `attack-chain.json` — Attack chain timeline
///
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
pub async fn generate_encrypted_report(
    ctx: &ReportContext,
    password: &str,
) -> Result<ReportArchive, String> {
    // Render the HTML first; the PDF is then produced from it via headless Chrome.
    let html = html::build_html_report(ctx);
    let pdf_bytes = pdf::html_to_pdf(&html).await?;
    let archive = archive::build_zip(ctx, password, &html, &pdf_bytes)
        .map_err(|e| format!("Failed to create archive: {e}"))?;
    // One-shot digest is equivalent to Sha256::new() + update() + finalize().
    let sha256 = hex::encode(Sha256::digest(&archive));
    Ok(ReportArchive { archive, sha256 })
}

View File

@@ -0,0 +1,79 @@
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
///
/// Writes the HTML to a temp file, invokes Chrome with `--print-to-pdf`,
/// reads the result back, and removes both temp files on every path
/// (the original leaked them when spawning Chrome or reading the PDF failed).
pub(super) async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
    // Locate the browser up front so we don't create temp files on the
    // "not installed" path.
    let chrome_bin = find_chrome_binary().ok_or_else(|| {
        "Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports."
            .to_string()
    })?;
    let tmp_dir = std::env::temp_dir();
    let run_id = uuid::Uuid::new_v4().to_string();
    let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
    let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
    // Write HTML to temp file
    std::fs::write(&html_path, html).map_err(|e| format!("Failed to write temp HTML: {e}"))?;
    tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
    // Run Chrome and read the PDF back, then clean up unconditionally before
    // propagating any error.
    let result = run_chrome(&chrome_bin, &html_path, &pdf_path).await;
    let _ = std::fs::remove_file(&html_path);
    let _ = std::fs::remove_file(&pdf_path);
    let pdf_bytes = result?;
    if pdf_bytes.is_empty() {
        return Err("Chrome produced an empty PDF".to_string());
    }
    tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
    Ok(pdf_bytes)
}
/// Invoke headless Chrome to print `html_path` to `pdf_path` and return the
/// raw PDF bytes. Does not clean up the temp files; the caller owns them.
async fn run_chrome(
    chrome_bin: &str,
    html_path: &std::path::Path,
    pdf_path: &std::path::Path,
) -> Result<Vec<u8>, String> {
    let html_url = format!("file://{}", html_path.display());
    let output = tokio::process::Command::new(chrome_bin)
        .args([
            "--headless",
            "--disable-gpu",
            "--no-sandbox",
            "--disable-software-rasterizer",
            "--run-all-compositor-stages-before-draw",
            "--disable-dev-shm-usage",
            &format!("--print-to-pdf={}", pdf_path.display()),
            "--no-pdf-header-footer",
            &html_url,
        ])
        .output()
        .await
        .map_err(|e| format!("Failed to run Chrome: {e}"))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(format!("Chrome PDF generation failed: {stderr}"));
    }
    std::fs::read(pdf_path).map_err(|e| format!("Failed to read generated PDF: {e}"))
}
/// Search for Chrome/Chromium binary on the system.
///
/// Probes a fixed list of well-known binary names with `which` and returns
/// the first path that resolves, or `None` when none is installed.
fn find_chrome_binary() -> Option<String> {
    [
        "google-chrome-stable",
        "google-chrome",
        "chromium-browser",
        "chromium",
    ]
    .iter()
    .find_map(|name| {
        let output = std::process::Command::new("which").arg(name).output().ok()?;
        if !output.status.success() {
            return None;
        }
        let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
        if path.is_empty() {
            None
        } else {
            Some(path)
        }
    })
}

View File

@@ -8,3 +8,51 @@ pub fn compute_fingerprint(parts: &[&str]) -> String {
}
hex::encode(hasher.finalize())
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Shorthand for the function under test.
    fn fp(parts: &[&str]) -> String {
        compute_fingerprint(parts)
    }
    #[test]
    fn fingerprint_is_deterministic() {
        let parts = ["repo1", "rule-x", "src/main.rs", "42"];
        assert_eq!(fp(&parts), fp(&parts));
    }
    #[test]
    fn fingerprint_changes_with_different_input() {
        assert_ne!(
            fp(&["repo1", "rule-x", "src/main.rs", "42"]),
            fp(&["repo1", "rule-x", "src/main.rs", "43"])
        );
    }
    #[test]
    fn fingerprint_is_valid_hex_sha256() {
        let digest = fp(&["hello"]);
        assert_eq!(digest.len(), 64, "SHA-256 hex should be 64 chars");
        assert!(digest.chars().all(|c| c.is_ascii_hexdigit()));
    }
    #[test]
    fn fingerprint_empty_parts() {
        // Hashing zero parts must still yield a well-formed digest.
        assert_eq!(fp(&[]).len(), 64);
    }
    #[test]
    fn fingerprint_order_matters() {
        assert_ne!(fp(&["a", "b"]), fp(&["b", "a"]));
    }
    #[test]
    fn fingerprint_separator_prevents_collision() {
        // "ab" + "c" vs "a" + "bc" should differ because of the "|" separator
        assert_ne!(fp(&["ab", "c"]), fp(&["a", "bc"]));
    }
}

View File

@@ -129,3 +129,110 @@ struct GitleaksResult {
#[serde(rename = "Match")]
r#match: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    // --- is_allowlisted tests ---
    // Paths that are clearly examples or fixtures must be allowlisted so that
    // committed placeholder secrets don't raise findings.
    #[test]
    fn allowlisted_env_example_files() {
        assert!(is_allowlisted(".env.example"));
        assert!(is_allowlisted("config/.env.sample"));
        // Upper-case variants are expected to match too.
        assert!(is_allowlisted("deploy/.ENV.TEMPLATE"));
    }
    #[test]
    fn allowlisted_test_directories() {
        assert!(is_allowlisted("src/test/config.json"));
        assert!(is_allowlisted("src/tests/fixtures.rs"));
        assert!(is_allowlisted("data/fixtures/secret.txt"));
        assert!(is_allowlisted("pkg/testdata/key.pem"));
    }
    #[test]
    fn allowlisted_mock_files() {
        assert!(is_allowlisted("src/mock_service.py"));
        assert!(is_allowlisted("lib/MockAuth.java"));
    }
    #[test]
    fn allowlisted_test_suffixes() {
        assert!(is_allowlisted("auth_test.go"));
        assert!(is_allowlisted("auth.test.ts"));
        assert!(is_allowlisted("auth.test.js"));
        assert!(is_allowlisted("auth.spec.ts"));
        assert!(is_allowlisted("auth.spec.js"));
    }
    // Real source and config files must never be allowlisted.
    #[test]
    fn not_allowlisted_regular_files() {
        assert!(!is_allowlisted("src/main.rs"));
        assert!(!is_allowlisted("config/.env"));
        assert!(!is_allowlisted("lib/auth.ts"));
        assert!(!is_allowlisted("deploy/secrets.yaml"));
    }
    #[test]
    fn not_allowlisted_partial_matches() {
        // "test" as substring in a non-directory context should not match
        assert!(!is_allowlisted("src/attestation.rs"));
        assert!(!is_allowlisted("src/contest/data.json"));
    }
    // --- GitleaksResult deserialization tests ---
    // Gitleaks emits PascalCase keys; these verify the serde renames.
    #[test]
    fn deserialize_gitleaks_result() {
        let json = r#"{
            "Description": "AWS Access Key",
            "RuleID": "aws-access-key",
            "File": "src/config.rs",
            "StartLine": 10,
            "Match": "AKIAIOSFODNN7EXAMPLE"
        }"#;
        let result: GitleaksResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.description, "AWS Access Key");
        assert_eq!(result.rule_id, "aws-access-key");
        assert_eq!(result.file, "src/config.rs");
        assert_eq!(result.start_line, 10);
        assert_eq!(result.r#match, "AKIAIOSFODNN7EXAMPLE");
    }
    #[test]
    fn deserialize_gitleaks_result_array() {
        let json = r#"[
            {
                "Description": "Generic Secret",
                "RuleID": "generic-secret",
                "File": "app.py",
                "StartLine": 5,
                "Match": "password=hunter2"
            }
        ]"#;
        let results: Vec<GitleaksResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].rule_id, "generic-secret");
    }
    // NOTE(review): the two "severity_mapping" tests below are tautological —
    // each asserts that a string contains the substring that was just
    // formatted into it, so they never exercise the scan method's actual
    // severity mapping. Consider calling the real mapping logic instead.
    #[test]
    fn severity_mapping_private_key() {
        // Verify the severity logic from the scan method
        let rule_id = "some-private-key-rule";
        assert!(rule_id.contains("private-key"));
    }
    #[test]
    fn severity_mapping_token_password_secret() {
        for keyword in &["token", "password", "secret"] {
            let rule_id = format!("some-{}-rule", keyword);
            assert!(
                rule_id.contains("token")
                    || rule_id.contains("password")
                    || rule_id.contains("secret"),
                "Expected '{rule_id}' to match token/password/secret"
            );
        }
    }
}

View File

@@ -0,0 +1,106 @@
use compliance_core::models::Finding;
use super::orchestrator::{GraphContext, PipelineOrchestrator};
use crate::error::AgentError;
impl PipelineOrchestrator {
    /// Build the code knowledge graph for a repo and compute impact analyses
    ///
    /// Steps: build the graph from the checkout, apply community detection,
    /// replace any previously stored graph for this repo in MongoDB, then
    /// compute and persist an impact analysis for every finding that has a
    /// file path. Returns summary counts plus the computed impacts.
    pub(super) async fn build_code_graph(
        &self,
        repo_path: &std::path::Path,
        repo_id: &str,
        findings: &[Finding],
    ) -> Result<GraphContext, AgentError> {
        let graph_build_id = uuid::Uuid::new_v4().to_string();
        // 50_000 — presumably a node-count cap for the engine; TODO confirm.
        let engine = compliance_graph::GraphEngine::new(50_000);
        let (mut code_graph, build_run) =
            engine
                .build_graph(repo_path, repo_id, &graph_build_id)
                .map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
        // Apply community detection
        compliance_graph::graph::community::apply_communities(&mut code_graph);
        // Store graph in MongoDB
        let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
        // Delete-then-store gives "replace" semantics for this repo's graph.
        store
            .delete_repo_graph(repo_id)
            .await
            .map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
        store
            .store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
            .await
            .map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
        // Compute impact analysis for each finding
        let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
        let mut impacts = Vec::new();
        for finding in findings {
            // Findings without a file path cannot be located in the graph
            // and are skipped.
            if let Some(file_path) = &finding.file_path {
                let impact = analyzer.analyze(
                    repo_id,
                    &finding.fingerprint,
                    &graph_build_id,
                    file_path,
                    finding.line_number,
                );
                store
                    .store_impact(&impact)
                    .await
                    .map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
                impacts.push(impact);
            }
        }
        Ok(GraphContext {
            node_count: build_run.node_count,
            edge_count: build_run.edge_count,
            community_count: build_run.community_count,
            impacts,
        })
    }
    /// Trigger DAST scan if a target is configured for this repo
    ///
    /// Fire-and-forget: one task is spawned per configured target and the
    /// JoinHandles are dropped, so scan failures only surface in the logs.
    pub(super) async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
        use futures_util::TryStreamExt;
        let filter = mongodb::bson::doc! { "repo_id": repo_id };
        // NOTE(review): a cursor error here is silently treated as "no
        // targets" (unwrap_or_default) — consider logging it.
        let targets: Vec<compliance_core::models::DastTarget> =
            match self.db.dast_targets().find(filter).await {
                Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
                Err(_) => return,
            };
        if targets.is_empty() {
            tracing::info!("[{repo_id}] No DAST targets configured, skipping");
            return;
        }
        for target in targets {
            // Each spawned task needs its own DB handle and scan-run id.
            let db = self.db.clone();
            let scan_run_id = scan_run_id.to_string();
            tokio::spawn(async move {
                // 100 — presumably a concurrency/request limit; TODO confirm.
                let orchestrator = compliance_dast::DastOrchestrator::new(100);
                match orchestrator.run_scan(&target, Vec::new()).await {
                    Ok((mut scan_run, findings)) => {
                        // Link the DAST run back to the SAST scan that spawned it.
                        scan_run.sast_scan_run_id = Some(scan_run_id);
                        if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
                            tracing::error!("Failed to store DAST scan run: {e}");
                        }
                        for finding in &findings {
                            if let Err(e) = db.dast_findings().insert_one(finding).await {
                                tracing::error!("Failed to store DAST finding: {e}");
                            }
                        }
                        tracing::info!("DAST scan complete: {} findings", findings.len());
                    }
                    Err(e) => {
                        tracing::error!("DAST scan failed: {e}");
                    }
                }
            });
        }
    }
}

View File

@@ -0,0 +1,259 @@
use mongodb::bson::doc;
use compliance_core::models::*;
use super::orchestrator::{extract_base_url, PipelineOrchestrator};
use super::tracker_dispatch::TrackerDispatch;
use crate::error::AgentError;
use crate::trackers;
impl PipelineOrchestrator {
    /// Resolve the API token for a tracker: the per-repo token takes
    /// precedence, falling back to the globally configured secret (which is
    /// exposed only at this boundary). Returns `None` if neither is set.
    ///
    /// Extracted because the same precedence logic was duplicated for
    /// GitHub, GitLab, and Jira below.
    fn resolve_token(
        per_repo: Option<String>,
        global: Option<&secrecy::SecretString>,
    ) -> Option<String> {
        per_repo.or_else(|| {
            global.map(|t| {
                use secrecy::ExposeSecret;
                t.expose_secret().to_string()
            })
        })
    }
    /// Build an issue tracker client from a repository's tracker configuration.
    /// Returns `None` if the repo has no tracker configured.
    pub(super) fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
        let tracker_type = repo.tracker_type.as_ref()?;
        // Per-repo token takes precedence, fall back to global config
        match tracker_type {
            TrackerType::GitHub => {
                let token = Self::resolve_token(
                    repo.tracker_token.clone(),
                    self.config.github_token.as_ref(),
                )?;
                let secret = secrecy::SecretString::from(token);
                match trackers::github::GitHubTracker::new(&secret) {
                    Ok(t) => Some(TrackerDispatch::GitHub(t)),
                    Err(e) => {
                        tracing::warn!("Failed to build GitHub tracker: {e}");
                        None
                    }
                }
            }
            TrackerType::GitLab => {
                // Self-hosted GitLab URL may be configured; default to gitlab.com.
                let base_url = self
                    .config
                    .gitlab_url
                    .clone()
                    .unwrap_or_else(|| "https://gitlab.com".to_string());
                let token = Self::resolve_token(
                    repo.tracker_token.clone(),
                    self.config.gitlab_token.as_ref(),
                )?;
                let secret = secrecy::SecretString::from(token);
                Some(TrackerDispatch::GitLab(
                    trackers::gitlab::GitLabTracker::new(base_url, secret),
                ))
            }
            TrackerType::Gitea => {
                // Gitea has no global token fallback: the token must be set on
                // the repo and the API base URL is derived from its git URL.
                let token = repo.tracker_token.clone()?;
                let base_url = extract_base_url(&repo.git_url)?;
                let secret = secrecy::SecretString::from(token);
                Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
                    base_url, secret,
                )))
            }
            TrackerType::Jira => {
                // Jira needs the full global trio (URL, email, project key).
                let base_url = self.config.jira_url.clone()?;
                let email = self.config.jira_email.clone()?;
                let project_key = self.config.jira_project_key.clone()?;
                let token = Self::resolve_token(
                    repo.tracker_token.clone(),
                    self.config.jira_api_token.as_ref(),
                )?;
                let secret = secrecy::SecretString::from(token);
                Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
                    base_url,
                    email,
                    secret,
                    project_key,
                )))
            }
        }
    }
    /// Create tracker issues for new findings (severity >= Medium).
    /// Checks for duplicates via fingerprint search before creating.
    ///
    /// Missing tracker configuration (no tracker, no owner, no repo name) is
    /// not an error — the step is skipped with a log line. Per-finding search
    /// and creation failures are logged and do not abort the loop.
    #[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
    pub(super) async fn create_tracker_issues(
        &self,
        repo: &TrackedRepository,
        repo_id: &str,
        new_findings: &[Finding],
    ) -> Result<(), AgentError> {
        let tracker = match self.build_tracker(repo) {
            Some(t) => t,
            None => {
                tracing::info!("[{repo_id}] No issue tracker configured, skipping");
                return Ok(());
            }
        };
        let owner = match repo.tracker_owner.as_deref() {
            Some(o) => o,
            None => {
                tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
                return Ok(());
            }
        };
        let tracker_repo_name = match repo.tracker_repo.as_deref() {
            Some(r) => r,
            None => {
                tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
                return Ok(());
            }
        };
        // Only create issues for medium+ severity findings
        let actionable: Vec<&Finding> = new_findings
            .iter()
            .filter(|f| {
                matches!(
                    f.severity,
                    Severity::Medium | Severity::High | Severity::Critical
                )
            })
            .collect();
        if actionable.is_empty() {
            tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
            return Ok(());
        }
        tracing::info!(
            "[{repo_id}] Creating issues for {} findings via {}",
            actionable.len(),
            tracker.name()
        );
        let mut created = 0u32;
        for finding in actionable {
            let title = format!(
                "[{}] {}: {}",
                finding.severity, finding.scanner, finding.title
            );
            // Check if an issue already exists by fingerprint first, then by
            // title (search failures are treated as "not found").
            let mut found_existing = false;
            for search_term in [&finding.fingerprint, &title] {
                match tracker
                    .find_existing_issue(owner, tracker_repo_name, search_term)
                    .await
                {
                    Ok(Some(existing)) => {
                        tracing::debug!(
                            "[{repo_id}] Issue already exists for '{}': {}",
                            search_term,
                            existing.external_url
                        );
                        found_existing = true;
                        break;
                    }
                    Ok(None) => {}
                    Err(e) => {
                        tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
                    }
                }
            }
            if found_existing {
                continue;
            }
            let body = format_issue_body(finding);
            let labels = vec![
                format!("severity:{}", finding.severity),
                format!("scanner:{}", finding.scanner),
                "compliance-scanner".to_string(),
            ];
            match tracker
                .create_issue(owner, tracker_repo_name, &title, &body, &labels)
                .await
            {
                Ok(mut issue) => {
                    issue.finding_id = finding
                        .id
                        .as_ref()
                        .map(|id| id.to_hex())
                        .unwrap_or_default();
                    // Update the finding with the issue URL (best-effort).
                    if let Some(finding_id) = &finding.id {
                        let _ = self
                            .db
                            .findings()
                            .update_one(
                                doc! { "_id": finding_id },
                                doc! { "$set": { "tracker_issue_url": &issue.external_url } },
                            )
                            .await;
                    }
                    // Store the tracker issue record
                    if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
                        tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
                    }
                    created += 1;
                }
                Err(e) => {
                    tracing::warn!(
                        "[{repo_id}] Failed to create issue for {}: {e}",
                        finding.fingerprint
                    );
                }
            }
        }
        tracing::info!("[{repo_id}] Created {created} tracker issues");
        Ok(())
    }
}
/// Format a finding into a markdown issue body for the tracker.
///
/// Sections appear in a fixed order — header, metadata, description,
/// location, code snippet, remediation, suggested fix, fingerprint footer —
/// and optional sections are omitted when the finding lacks them.
pub(super) fn format_issue_body(finding: &Finding) -> String {
    use std::fmt::Write;
    let mut body = String::new();
    // write! into a String is infallible, so the Results are discarded.
    let _ = write!(body, "## {} Finding\n\n", finding.severity);
    let _ = write!(body, "**Scanner:** {}\n", finding.scanner);
    let _ = write!(body, "**Severity:** {}\n", finding.severity);
    if let Some(rule) = &finding.rule_id {
        let _ = write!(body, "**Rule:** {}\n", rule);
    }
    if let Some(cwe) = &finding.cwe {
        let _ = write!(body, "**CWE:** {}\n", cwe);
    }
    let _ = write!(body, "\n### Description\n\n{}\n", finding.description);
    if let Some(file_path) = &finding.file_path {
        let _ = write!(body, "\n### Location\n\n**File:** `{}`", file_path);
        if let Some(line) = finding.line_number {
            let _ = write!(body, " (line {})", line);
        }
        body.push('\n');
    }
    if let Some(snippet) = &finding.code_snippet {
        let _ = write!(body, "\n### Code\n\n```\n{}\n```\n", snippet);
    }
    if let Some(remediation) = &finding.remediation {
        let _ = write!(body, "\n### Remediation\n\n{}\n", remediation);
    }
    if let Some(fix) = &finding.suggested_fix {
        let _ = write!(body, "\n### Suggested Fix\n\n```\n{}\n```\n", fix);
    }
    let _ = write!(
        body,
        "\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
        finding.fingerprint
    );
    body
}

View File

@@ -1,366 +0,0 @@
use std::path::Path;
use std::time::Duration;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::traits::{ScanOutput, Scanner};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
/// Timeout for each individual lint command
const LINT_TIMEOUT: Duration = Duration::from_secs(120);
/// Multi-language lint scanner: detects which ecosystems a repo uses
/// (Rust, JS, Python) and runs the matching linter for each.
pub struct LintScanner;
impl Scanner for LintScanner {
    fn name(&self) -> &str {
        "lint"
    }
    fn scan_type(&self) -> ScanType {
        ScanType::Lint
    }
    /// Run every applicable linter and merge their findings.
    ///
    /// Each linter is best-effort: a failure is logged and the remaining
    /// linters still run, so one broken toolchain can't sink the whole scan.
    #[tracing::instrument(skip_all)]
    async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
        let mut all_findings = Vec::new();
        // Detect which languages are present and run appropriate linters
        if has_rust_project(repo_path) {
            match run_clippy(repo_path, repo_id).await {
                Ok(findings) => all_findings.extend(findings),
                Err(e) => tracing::warn!("Clippy failed: {e}"),
            }
        }
        if has_js_project(repo_path) {
            match run_eslint(repo_path, repo_id).await {
                Ok(findings) => all_findings.extend(findings),
                Err(e) => tracing::warn!("ESLint failed: {e}"),
            }
        }
        if has_python_project(repo_path) {
            match run_ruff(repo_path, repo_id).await {
                Ok(findings) => all_findings.extend(findings),
                Err(e) => tracing::warn!("Ruff failed: {e}"),
            }
        }
        // Linting only produces findings, never SBOM entries.
        Ok(ScanOutput {
            findings: all_findings,
            sbom_entries: Vec::new(),
        })
    }
}
/// A repo is treated as a Rust project when a top-level `Cargo.toml` exists.
fn has_rust_project(repo_path: &Path) -> bool {
    repo_path.join("Cargo.toml").exists()
}
/// A repo is treated as a JS project only when ESLint is installed locally,
/// so the linter never has to download anything at scan time.
fn has_js_project(repo_path: &Path) -> bool {
    let manifest = repo_path.join("package.json");
    let local_eslint = repo_path.join("node_modules/.bin/eslint");
    manifest.exists() && local_eslint.exists()
}
/// A repo is treated as a Python project when any common manifest exists.
fn has_python_project(repo_path: &Path) -> bool {
    ["pyproject.toml", "setup.py", "requirements.txt"]
        .iter()
        .any(|f| repo_path.join(f).exists())
}
/// Run a command with a timeout, returning its output or an error.
///
/// On timeout the `wait_with_output` future (which owns the child) is
/// dropped and a `TimedOut` scanner error is returned.
async fn run_with_timeout(
    child: tokio::process::Child,
    scanner_name: &str,
) -> Result<std::process::Output, CoreError> {
    let result = tokio::time::timeout(LINT_TIMEOUT, child.wait_with_output()).await;
    match result {
        Ok(Ok(output)) => Ok(output),
        // The child ran but waiting on it failed.
        Ok(Err(e)) => Err(CoreError::Scanner {
            scanner: scanner_name.to_string(),
            source: Box::new(e),
        }),
        Err(_) => {
            // NOTE(review): dropping a tokio Child does NOT kill the process
            // unless the Command set `kill_on_drop(true)` — none of the spawn
            // sites here do — so a timed-out linter likely keeps running
            // detached. Consider adding `.kill_on_drop(true)` at spawn sites.
            Err(CoreError::Scanner {
                scanner: scanner_name.to_string(),
                source: Box::new(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
                )),
            })
        }
    }
}
// ── Clippy ──────────────────────────────────────────────
/// Run `cargo clippy` with JSON message output and convert warnings/errors
/// that carry a lint code into findings.
async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
    let child = Command::new("cargo")
        .args([
            "clippy",
            "--message-format=json",
            "--quiet",
            "--",
            "-W",
            "clippy::all",
        ])
        .current_dir(repo_path)
        // Clears any configured rustc wrapper (e.g. sccache) — presumably so
        // cached crates still re-emit diagnostics; TODO confirm intent.
        .env("RUSTC_WRAPPER", "")
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| CoreError::Scanner {
            scanner: "clippy".to_string(),
            source: Box::new(e),
        })?;
    let output = run_with_timeout(child, "clippy").await?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let mut findings = Vec::new();
    // cargo emits one JSON object per line; skip anything unparsable.
    for line in stdout.lines() {
        let msg: serde_json::Value = match serde_json::from_str(line) {
            Ok(v) => v,
            Err(_) => continue,
        };
        // Only "compiler-message" records carry diagnostics.
        if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
            continue;
        }
        let message = match msg.get("message") {
            Some(m) => m,
            None => continue,
        };
        // Keep only warnings and errors; notes/helps are skipped.
        let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
        if level != "warning" && level != "error" {
            continue;
        }
        let text = message
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        let code = message
            .get("code")
            .and_then(|v| v.get("code"))
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        // Skip summary lines and diagnostics that have no lint code.
        if text.starts_with("aborting due to") || code.is_empty() {
            continue;
        }
        let (file_path, line_number) = extract_primary_span(message);
        // Compile errors rank High; lint warnings rank Low.
        let severity = if level == "error" {
            Severity::High
        } else {
            Severity::Low
        };
        // Fingerprint on (repo, tool, lint code, file, line) for dedup.
        let fingerprint = dedup::compute_fingerprint(&[
            repo_id,
            "clippy",
            &code,
            &file_path,
            &line_number.to_string(),
        ]);
        let mut finding = Finding::new(
            repo_id.to_string(),
            fingerprint,
            "clippy".to_string(),
            ScanType::Lint,
            format!("[clippy] {text}"),
            text,
            severity,
        );
        finding.rule_id = Some(code);
        if !file_path.is_empty() {
            finding.file_path = Some(file_path);
        }
        if line_number > 0 {
            finding.line_number = Some(line_number);
        }
        findings.push(finding);
    }
    Ok(findings)
}
/// Return the `(file, line)` of the primary span in a compiler message,
/// or `("", 0)` when the message has no spans or no primary span.
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
    let primary = message
        .get("spans")
        .and_then(|v| v.as_array())
        .into_iter()
        .flatten()
        .find(|span| span.get("is_primary").and_then(|v| v.as_bool()) == Some(true));
    match primary {
        Some(span) => {
            let file = span
                .get("file_name")
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
            (file, line)
        }
        None => (String::new(), 0),
    }
}
// ── ESLint ──────────────────────────────────────────────
/// Run the project-local ESLint and convert its JSON output into findings.
async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
    // Use the project-local eslint binary directly, not npx (which can hang downloading)
    let eslint_bin = repo_path.join("node_modules/.bin/eslint");
    let child = Command::new(eslint_bin)
        .args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
        .current_dir(repo_path)
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| CoreError::Scanner {
            scanner: "eslint".to_string(),
            source: Box::new(e),
        })?;
    let output = run_with_timeout(child, "eslint").await?;
    if output.stdout.is_empty() {
        return Ok(Vec::new());
    }
    // NOTE(review): a malformed JSON payload silently becomes "no findings"
    // here (unwrap_or_default) — consider logging the parse error.
    let results: Vec<EslintFileResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
    let mut findings = Vec::new();
    for file_result in results {
        for msg in file_result.messages {
            // Severity 2 ("error" in ESLint's JSON output) ranks Medium;
            // everything else ranks Low.
            let severity = match msg.severity {
                2 => Severity::Medium,
                _ => Severity::Low,
            };
            let rule_id = msg.rule_id.unwrap_or_default();
            // Fingerprint on (repo, tool, rule, file, line) for dedup.
            let fingerprint = dedup::compute_fingerprint(&[
                repo_id,
                "eslint",
                &rule_id,
                &file_result.file_path,
                &msg.line.to_string(),
            ]);
            let mut finding = Finding::new(
                repo_id.to_string(),
                fingerprint,
                "eslint".to_string(),
                ScanType::Lint,
                format!("[eslint] {}", msg.message),
                msg.message,
                severity,
            );
            finding.rule_id = Some(rule_id);
            finding.file_path = Some(file_result.file_path.clone());
            finding.line_number = Some(msg.line);
            findings.push(finding);
        }
    }
    Ok(findings)
}
/// One file's results in ESLint's `--format json` output.
#[derive(serde::Deserialize)]
struct EslintFileResult {
    #[serde(rename = "filePath")]
    file_path: String,
    messages: Vec<EslintMessage>,
}
/// A single ESLint diagnostic within a file.
#[derive(serde::Deserialize)]
struct EslintMessage {
    #[serde(rename = "ruleId")]
    rule_id: Option<String>,
    // 2 is treated as error severity above; other values rank Low.
    severity: u8,
    message: String,
    line: u32,
}
// ── Ruff ────────────────────────────────────────────────
/// Run Ruff over the repo and convert its JSON diagnostics into findings.
async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
    // --exit-zero: diagnostics must not make the process exit non-zero.
    let child = Command::new("ruff")
        .args(["check", ".", "--output-format", "json", "--exit-zero"])
        .current_dir(repo_path)
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| CoreError::Scanner {
            scanner: "ruff".to_string(),
            source: Box::new(e),
        })?;
    let output = run_with_timeout(child, "ruff").await?;
    if output.stdout.is_empty() {
        return Ok(Vec::new());
    }
    // NOTE(review): a malformed JSON payload silently becomes "no findings"
    // here (unwrap_or_default) — consider logging the parse error.
    let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
    let findings = results
        .into_iter()
        .map(|r| {
            // Codes starting with 'E' or 'F' rank Medium (presumably the
            // pycodestyle-error / pyflakes families — TODO confirm); others Low.
            let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
                Severity::Medium
            } else {
                Severity::Low
            };
            // Fingerprint on (repo, tool, code, file, row) for dedup.
            let fingerprint = dedup::compute_fingerprint(&[
                repo_id,
                "ruff",
                &r.code,
                &r.filename,
                &r.location.row.to_string(),
            ]);
            let mut finding = Finding::new(
                repo_id.to_string(),
                fingerprint,
                "ruff".to_string(),
                ScanType::Lint,
                format!("[ruff] {}: {}", r.code, r.message),
                r.message,
                severity,
            );
            finding.rule_id = Some(r.code);
            finding.file_path = Some(r.filename);
            finding.line_number = Some(r.location.row);
            finding
        })
        .collect();
    Ok(findings)
}
/// One diagnostic in Ruff's `--output-format json` output.
#[derive(serde::Deserialize)]
struct RuffResult {
    code: String,
    message: String,
    filename: String,
    location: RuffLocation,
}
/// Source position of a Ruff diagnostic (only the row is used).
#[derive(serde::Deserialize)]
struct RuffLocation {
    row: u32,
}

View File

@@ -0,0 +1,251 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
/// Run `cargo clippy` with JSON message output and convert warnings/errors
/// that carry a lint code into findings.
pub(super) async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
    let child = Command::new("cargo")
        .args([
            "clippy",
            "--message-format=json",
            "--quiet",
            "--",
            "-W",
            "clippy::all",
        ])
        .current_dir(repo_path)
        // Clears any configured rustc wrapper (e.g. sccache) — presumably so
        // cached crates still re-emit diagnostics; TODO confirm intent.
        .env("RUSTC_WRAPPER", "")
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| CoreError::Scanner {
            scanner: "clippy".to_string(),
            source: Box::new(e),
        })?;
    let output = run_with_timeout(child, "clippy").await?;
    let stdout = String::from_utf8_lossy(&output.stdout);
    let mut findings = Vec::new();
    // cargo emits one JSON object per line; skip anything unparsable.
    for line in stdout.lines() {
        let msg: serde_json::Value = match serde_json::from_str(line) {
            Ok(v) => v,
            Err(_) => continue,
        };
        // Only "compiler-message" records carry diagnostics.
        if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
            continue;
        }
        let message = match msg.get("message") {
            Some(m) => m,
            None => continue,
        };
        // Keep only warnings and errors; notes/helps are skipped.
        let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
        if level != "warning" && level != "error" {
            continue;
        }
        let text = message
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        let code = message
            .get("code")
            .and_then(|v| v.get("code"))
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        // Skip summary lines and diagnostics that have no lint code.
        if text.starts_with("aborting due to") || code.is_empty() {
            continue;
        }
        let (file_path, line_number) = extract_primary_span(message);
        // Compile errors rank High; lint warnings rank Low.
        let severity = if level == "error" {
            Severity::High
        } else {
            Severity::Low
        };
        // Fingerprint on (repo, tool, lint code, file, line) for dedup.
        let fingerprint = dedup::compute_fingerprint(&[
            repo_id,
            "clippy",
            &code,
            &file_path,
            &line_number.to_string(),
        ]);
        let mut finding = Finding::new(
            repo_id.to_string(),
            fingerprint,
            "clippy".to_string(),
            ScanType::Lint,
            format!("[clippy] {text}"),
            text,
            severity,
        );
        finding.rule_id = Some(code);
        if !file_path.is_empty() {
            finding.file_path = Some(file_path);
        }
        if line_number > 0 {
            finding.line_number = Some(line_number);
        }
        findings.push(finding);
    }
    Ok(findings)
}
/// Locate the diagnostic's primary span and return its `(file_name, line_start)`.
/// Falls back to `("", 0)` when no span is marked `is_primary`.
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
    let primary = message
        .get("spans")
        .and_then(|v| v.as_array())
        .and_then(|spans| {
            spans
                .iter()
                .find(|s| s.get("is_primary").and_then(|v| v.as_bool()) == Some(true))
        });
    match primary {
        Some(span) => {
            let file = span
                .get("file_name")
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
            (file, line)
        }
        None => (String::new(), 0),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A span flagged is_primary is returned as (file, line).
    #[test]
    fn extract_primary_span_with_primary() {
        let msg = serde_json::json!({
            "spans": [
                {
                    "file_name": "src/lib.rs",
                    "line_start": 42,
                    "is_primary": true
                }
            ]
        });
        let (file, line) = extract_primary_span(&msg);
        assert_eq!(file, "src/lib.rs");
        assert_eq!(line, 42);
    }

    // Spans exist but none is primary -> fallback ("", 0).
    #[test]
    fn extract_primary_span_no_primary() {
        let msg = serde_json::json!({
            "spans": [
                {
                    "file_name": "src/lib.rs",
                    "line_start": 42,
                    "is_primary": false
                }
            ]
        });
        let (file, line) = extract_primary_span(&msg);
        assert_eq!(file, "");
        assert_eq!(line, 0);
    }

    // With several spans, only the primary one is reported.
    #[test]
    fn extract_primary_span_multiple_spans() {
        let msg = serde_json::json!({
            "spans": [
                {
                    "file_name": "src/other.rs",
                    "line_start": 10,
                    "is_primary": false
                },
                {
                    "file_name": "src/main.rs",
                    "line_start": 99,
                    "is_primary": true
                }
            ]
        });
        let (file, line) = extract_primary_span(&msg);
        assert_eq!(file, "src/main.rs");
        assert_eq!(line, 99);
    }

    // Missing "spans" key entirely -> fallback.
    #[test]
    fn extract_primary_span_no_spans() {
        let msg = serde_json::json!({});
        let (file, line) = extract_primary_span(&msg);
        assert_eq!(file, "");
        assert_eq!(line, 0);
    }

    // Empty "spans" array -> fallback.
    #[test]
    fn extract_primary_span_empty_spans() {
        let msg = serde_json::json!({ "spans": [] });
        let (file, line) = extract_primary_span(&msg);
        assert_eq!(file, "");
        assert_eq!(line, 0);
    }

    // End-to-end parse of one cargo "compiler-message" JSON line, exercising
    // every field run_clippy reads (reason, level, message, code, span).
    #[test]
    fn parse_clippy_compiler_message_line() {
        let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable","code":{"code":"unused_variables"},"spans":[{"file_name":"src/main.rs","line_start":5,"is_primary":true}]}}"#;
        let msg: serde_json::Value = serde_json::from_str(line).unwrap();
        assert_eq!(
            msg.get("reason").and_then(|v| v.as_str()),
            Some("compiler-message")
        );
        let message = msg.get("message").unwrap();
        assert_eq!(
            message.get("level").and_then(|v| v.as_str()),
            Some("warning")
        );
        assert_eq!(
            message.get("message").and_then(|v| v.as_str()),
            Some("unused variable")
        );
        assert_eq!(
            message
                .get("code")
                .and_then(|v| v.get("code"))
                .and_then(|v| v.as_str()),
            Some("unused_variables")
        );
        let (file, line_num) = extract_primary_span(message);
        assert_eq!(file, "src/main.rs");
        assert_eq!(line_num, 5);
    }

    // Records with a different "reason" (build scripts etc.) are filtered out.
    #[test]
    fn skip_non_compiler_message() {
        let line = r#"{"reason":"build-script-executed","package_id":"foo 0.1.0"}"#;
        let msg: serde_json::Value = serde_json::from_str(line).unwrap();
        assert_ne!(
            msg.get("reason").and_then(|v| v.as_str()),
            Some("compiler-message")
        );
    }

    // Documents the prefix check used to drop "aborting due to ..." summaries.
    #[test]
    fn skip_aborting_message() {
        let text = "aborting due to 3 previous errors";
        assert!(text.starts_with("aborting due to"));
    }
}

View File

@@ -0,0 +1,183 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
/// Run the repo-local eslint binary and convert its JSON report to findings.
pub(super) async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
    // Invoke the project-local eslint binary directly rather than npx,
    // which can hang while downloading packages.
    let eslint_bin = repo_path.join("node_modules/.bin/eslint");
    let child = Command::new(eslint_bin)
        .args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
        .current_dir(repo_path)
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()
        .map_err(|e| CoreError::Scanner {
            scanner: "eslint".to_string(),
            source: Box::new(e),
        })?;
    let output = run_with_timeout(child, "eslint").await?;
    if output.stdout.is_empty() {
        return Ok(Vec::new());
    }
    let file_results: Vec<EslintFileResult> =
        serde_json::from_slice(&output.stdout).unwrap_or_default();
    let mut findings = Vec::new();
    for file in file_results {
        for message in file.messages {
            // eslint severity 2 = error -> Medium; 1 (warning) or other -> Low.
            let severity = if message.severity == 2 {
                Severity::Medium
            } else {
                Severity::Low
            };
            // ruleId is null for parse errors; store "" in that case.
            let rule = message.rule_id.unwrap_or_default();
            let fingerprint = dedup::compute_fingerprint(&[
                repo_id,
                "eslint",
                &rule,
                &file.file_path,
                &message.line.to_string(),
            ]);
            let mut finding = Finding::new(
                repo_id.to_string(),
                fingerprint,
                "eslint".to_string(),
                ScanType::Lint,
                format!("[eslint] {}", message.message),
                message.message,
                severity,
            );
            finding.rule_id = Some(rule);
            finding.file_path = Some(file.file_path.clone());
            finding.line_number = Some(message.line);
            findings.push(finding);
        }
    }
    Ok(findings)
}
/// Per-file entry in eslint's `--format json` report.
#[derive(serde::Deserialize)]
struct EslintFileResult {
    #[serde(rename = "filePath")]
    file_path: String,
    // May be empty for files with no violations.
    messages: Vec<EslintMessage>,
}
/// One lint message within a file's eslint report.
#[derive(serde::Deserialize)]
struct EslintMessage {
    // None when eslint reports a parse error rather than a rule violation.
    #[serde(rename = "ruleId")]
    rule_id: Option<String>,
    // eslint convention: 1 = warning, 2 = error.
    severity: u8,
    message: String,
    line: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: two messages across one file, all fields deserialized.
    #[test]
    fn deserialize_eslint_output() {
        let json = r#"[
            {
                "filePath": "/home/user/project/src/app.js",
                "messages": [
                    {
                        "ruleId": "no-unused-vars",
                        "severity": 2,
                        "message": "'x' is defined but never used.",
                        "line": 10
                    },
                    {
                        "ruleId": "semi",
                        "severity": 1,
                        "message": "Missing semicolon.",
                        "line": 15
                    }
                ]
            }
        ]"#;
        let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].file_path, "/home/user/project/src/app.js");
        assert_eq!(results[0].messages.len(), 2);
        assert_eq!(
            results[0].messages[0].rule_id,
            Some("no-unused-vars".to_string())
        );
        assert_eq!(results[0].messages[0].severity, 2);
        assert_eq!(results[0].messages[0].line, 10);
        assert_eq!(results[0].messages[1].severity, 1);
    }

    // Parse errors carry "ruleId": null -> Option must be None.
    #[test]
    fn deserialize_eslint_null_rule_id() {
        let json = r#"[
            {
                "filePath": "src/index.js",
                "messages": [
                    {
                        "ruleId": null,
                        "severity": 2,
                        "message": "Parsing error: Unexpected token",
                        "line": 1
                    }
                ]
            }
        ]"#;
        let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results[0].messages[0].rule_id, None);
    }

    // Clean files appear with an empty messages array.
    #[test]
    fn deserialize_eslint_empty_messages() {
        let json = r#"[{"filePath": "src/clean.js", "messages": []}]"#;
        let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results[0].messages.len(), 0);
    }

    // A wholly empty report deserializes to an empty Vec.
    #[test]
    fn deserialize_eslint_empty_array() {
        let json = "[]";
        let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
        assert!(results.is_empty());
    }

    // Documents the severity mapping used by run_eslint.
    #[test]
    fn eslint_severity_mapping() {
        // severity 2 = error -> Medium, anything else -> Low
        assert_eq!(
            match 2u8 {
                2 => "Medium",
                _ => "Low",
            },
            "Medium"
        );
        assert_eq!(
            match 1u8 {
                2 => "Medium",
                _ => "Low",
            },
            "Low"
        );
        assert_eq!(
            match 0u8 {
                2 => "Medium",
                _ => "Low",
            },
            "Low"
        );
    }
}

View File

@@ -0,0 +1,97 @@
mod clippy;
mod eslint;
mod ruff;
use std::path::Path;
use std::time::Duration;
use compliance_core::models::ScanType;
use compliance_core::traits::{ScanOutput, Scanner};
use compliance_core::CoreError;
/// Wall-clock budget granted to each individual lint command.
pub(crate) const LINT_TIMEOUT: Duration = Duration::from_secs(120);

/// Aggregating scanner that detects which languages a repo contains and
/// runs the matching linters (clippy, eslint, ruff).
pub struct LintScanner;

impl Scanner for LintScanner {
    fn name(&self) -> &str {
        "lint"
    }

    fn scan_type(&self) -> ScanType {
        ScanType::Lint
    }

    /// Run every applicable linter and collect their findings. A linter
    /// failure is logged and skipped so one broken tool doesn't sink the
    /// whole lint phase; SBOM entries are never produced here.
    #[tracing::instrument(skip_all)]
    async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
        let mut collected = Vec::new();
        if has_rust_project(repo_path) {
            match clippy::run_clippy(repo_path, repo_id).await {
                Ok(batch) => collected.extend(batch),
                Err(e) => tracing::warn!("Clippy failed: {e}"),
            }
        }
        if has_js_project(repo_path) {
            match eslint::run_eslint(repo_path, repo_id).await {
                Ok(batch) => collected.extend(batch),
                Err(e) => tracing::warn!("ESLint failed: {e}"),
            }
        }
        if has_python_project(repo_path) {
            match ruff::run_ruff(repo_path, repo_id).await {
                Ok(batch) => collected.extend(batch),
                Err(e) => tracing::warn!("Ruff failed: {e}"),
            }
        }
        Ok(ScanOutput {
            findings: collected,
            sbom_entries: Vec::new(),
        })
    }
}
/// A repo counts as a Rust project when a top-level `Cargo.toml` is present.
fn has_rust_project(repo_path: &Path) -> bool {
    let manifest = repo_path.join("Cargo.toml");
    manifest.exists()
}
/// A repo counts as a lintable JS project only when both a `package.json`
/// and a locally installed eslint binary exist — without the binary there
/// is nothing for `run_eslint` to invoke.
fn has_js_project(repo_path: &Path) -> bool {
    let has_manifest = repo_path.join("package.json").exists();
    let has_eslint = repo_path.join("node_modules/.bin/eslint").exists();
    has_manifest && has_eslint
}
/// A repo counts as a Python project when any common project marker file
/// (`pyproject.toml`, `setup.py`, or `requirements.txt`) is present.
fn has_python_project(repo_path: &Path) -> bool {
    ["pyproject.toml", "setup.py", "requirements.txt"]
        .iter()
        .any(|marker| repo_path.join(marker).exists())
}
/// Run a spawned lint process under [`LINT_TIMEOUT`], returning its captured
/// output or a `CoreError::Scanner` on I/O failure or timeout.
///
/// stdout and stderr are drained concurrently so a chatty child can never
/// deadlock on a full pipe. On timeout the child is killed explicitly:
/// tokio's `Child` does NOT terminate the process on drop unless
/// `kill_on_drop(true)` was set at spawn time (none of the lint spawn sites
/// set it), so the previous drop-based approach leaked runaway linters.
pub(crate) async fn run_with_timeout(
    mut child: tokio::process::Child,
    scanner_name: &str,
) -> Result<std::process::Output, CoreError> {
    // Reader tasks take ownership of the pipes and buffer them to EOF.
    let stdout_pipe = child.stdout.take();
    let stderr_pipe = child.stderr.take();
    let stdout_task = tokio::spawn(async move {
        let mut buf = Vec::new();
        if let Some(mut pipe) = stdout_pipe {
            let _ = tokio::io::AsyncReadExt::read_to_end(&mut pipe, &mut buf).await;
        }
        buf
    });
    let stderr_task = tokio::spawn(async move {
        let mut buf = Vec::new();
        if let Some(mut pipe) = stderr_pipe {
            let _ = tokio::io::AsyncReadExt::read_to_end(&mut pipe, &mut buf).await;
        }
        buf
    });
    let status = match tokio::time::timeout(LINT_TIMEOUT, child.wait()).await {
        Ok(Ok(status)) => status,
        Ok(Err(e)) => {
            return Err(CoreError::Scanner {
                scanner: scanner_name.to_string(),
                source: Box::new(e),
            })
        }
        Err(_) => {
            // Timed out: terminate the child so it doesn't keep running
            // after we give up on it. Failure to kill is best-effort.
            let _ = child.kill().await;
            return Err(CoreError::Scanner {
                scanner: scanner_name.to_string(),
                source: Box::new(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
                )),
            });
        }
    };
    // The readers complete once the pipes hit EOF (the child exited above).
    let stdout = stdout_task.await.unwrap_or_default();
    let stderr = stderr_task.await.unwrap_or_default();
    Ok(std::process::Output {
        status,
        stdout,
        stderr,
    })
}

View File

@@ -0,0 +1,150 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
pub(super) async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
let child = Command::new("ruff")
.args(["check", ".", "--output-format", "json", "--exit-zero"])
.current_dir(repo_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "ruff".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "ruff").await?;
if output.stdout.is_empty() {
return Ok(Vec::new());
}
let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
let findings = results
.into_iter()
.map(|r| {
let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
Severity::Medium
} else {
Severity::Low
};
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"ruff",
&r.code,
&r.filename,
&r.location.row.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"ruff".to_string(),
ScanType::Lint,
format!("[ruff] {}: {}", r.code, r.message),
r.message,
severity,
);
finding.rule_id = Some(r.code);
finding.file_path = Some(r.filename);
finding.line_number = Some(r.location.row);
finding
})
.collect();
Ok(findings)
}
/// One diagnostic record in ruff's `--output-format json` output.
/// Extra JSON keys (e.g. `end_location`, `fix`, `noqa_row`) are ignored.
#[derive(serde::Deserialize)]
struct RuffResult {
    // Rule code, e.g. "E501" or "F401".
    code: String,
    // Human-readable description of the violation.
    message: String,
    // Path of the offending file as reported by ruff.
    filename: String,
    // Start position of the violation.
    location: RuffLocation,
}
/// Source position within a file; only the 1-based line (`row`) is used.
#[derive(serde::Deserialize)]
struct RuffLocation {
    row: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: two records with all consumed fields populated.
    #[test]
    fn deserialize_ruff_output() {
        let json = r#"[
            {
                "code": "E501",
                "message": "Line too long (120 > 79 characters)",
                "filename": "src/main.py",
                "location": {"row": 42}
            },
            {
                "code": "F401",
                "message": "`os` imported but unused",
                "filename": "src/utils.py",
                "location": {"row": 1}
            }
        ]"#;
        let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].code, "E501");
        assert_eq!(results[0].filename, "src/main.py");
        assert_eq!(results[0].location.row, 42);
        assert_eq!(results[1].code, "F401");
        assert_eq!(results[1].location.row, 1);
    }

    // A clean run produces an empty JSON array.
    #[test]
    fn deserialize_ruff_empty() {
        let json = "[]";
        let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
        assert!(results.is_empty());
    }

    // Documents the severity rule: E*/F* codes map to Medium.
    #[test]
    fn ruff_severity_e_and_f_are_medium() {
        for code in &["E501", "E302", "F401", "F811"] {
            let is_medium = code.starts_with('E') || code.starts_with('F');
            assert!(is_medium, "Expected {code} to be Medium severity");
        }
    }

    // All other rule families map to Low.
    #[test]
    fn ruff_severity_others_are_low() {
        for code in &["W291", "I001", "D100", "C901", "N801"] {
            let is_medium = code.starts_with('E') || code.starts_with('F');
            assert!(!is_medium, "Expected {code} to be Low severity");
        }
    }

    // Unknown fields in ruff's output must not break deserialization.
    #[test]
    fn deserialize_ruff_with_extra_fields() {
        // Ruff output may contain additional fields we don't use
        let json = r#"[{
            "code": "W291",
            "message": "Trailing whitespace",
            "filename": "app.py",
            "location": {"row": 3, "column": 10},
            "end_location": {"row": 3, "column": 11},
            "fix": null,
            "noqa_row": 3
        }]"#;
        let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].code, "W291");
    }
}

View File

@@ -3,8 +3,12 @@ pub mod cve;
pub mod dedup;
pub mod git;
pub mod gitleaks;
mod graph_build;
mod issue_creation;
pub mod lint;
pub mod orchestrator;
pub mod patterns;
mod pr_review;
pub mod sbom;
pub mod semgrep;
mod tracker_dispatch;

View File

@@ -4,7 +4,6 @@ use mongodb::bson::doc;
use tracing::Instrument;
use compliance_core::models::*;
use compliance_core::traits::issue_tracker::IssueTracker;
use compliance_core::traits::Scanner;
use compliance_core::AgentConfig;
@@ -19,84 +18,6 @@ use crate::pipeline::lint::LintScanner;
use crate::pipeline::patterns::{GdprPatternScanner, OAuthPatternScanner};
use crate::pipeline::sbom::SbomScanner;
use crate::pipeline::semgrep::SemgrepScanner;
use crate::trackers;
/// Enum dispatch for issue trackers (async trait methods aren't
/// dyn-compatible, so `Box<dyn IssueTracker>` is not an option).
/// Each variant wraps the concrete client for one tracker backend.
enum TrackerDispatch {
    GitHub(trackers::github::GitHubTracker),
    GitLab(trackers::gitlab::GitLabTracker),
    Gitea(trackers::gitea::GiteaTracker),
    Jira(trackers::jira::JiraTracker),
}
impl TrackerDispatch {
    /// Tracker backend name, delegated to the wrapped implementation.
    fn name(&self) -> &str {
        match self {
            Self::GitHub(t) => t.name(),
            Self::GitLab(t) => t.name(),
            Self::Gitea(t) => t.name(),
            Self::Jira(t) => t.name(),
        }
    }

    /// Create a new issue in the tracker; delegates to the wrapped client.
    async fn create_issue(
        &self,
        owner: &str,
        repo: &str,
        title: &str,
        body: &str,
        labels: &[String],
    ) -> Result<TrackerIssue, compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
        }
    }

    /// Look up an existing issue by search term (callers pass a finding
    /// fingerprint or issue title); delegates to the wrapped client.
    async fn find_existing_issue(
        &self,
        owner: &str,
        repo: &str,
        fingerprint: &str,
    ) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
        }
    }

    /// Post a PR/MR review with inline comments; delegates to the wrapped
    /// client.
    async fn create_pr_review(
        &self,
        owner: &str,
        repo: &str,
        pr_number: u64,
        body: &str,
        comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
    ) -> Result<(), compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::GitLab(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::Gitea(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::Jira(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
        }
    }
}
/// Context from graph analysis passed to LLM triage for enhanced filtering
#[derive(Debug)]
@@ -109,10 +30,10 @@ pub struct GraphContext {
}
pub struct PipelineOrchestrator {
config: AgentConfig,
db: Database,
llm: Arc<LlmClient>,
http: reqwest::Client,
pub(super) config: AgentConfig,
pub(super) db: Database,
pub(super) llm: Arc<LlmClient>,
pub(super) http: reqwest::Client,
}
impl PipelineOrchestrator {
@@ -460,446 +381,7 @@ impl PipelineOrchestrator {
Ok(new_count)
}
/// Build the code knowledge graph for a repo and compute impact analyses.
///
/// Steps: build the graph (node cap 50_000), apply community detection,
/// replace the repo's previously stored graph in MongoDB, then compute and
/// persist an impact analysis for each finding that has a file path.
/// Findings without a file path get no impact entry.
async fn build_code_graph(
    &self,
    repo_path: &std::path::Path,
    repo_id: &str,
    findings: &[Finding],
) -> Result<GraphContext, AgentError> {
    // Fresh id ties this build's nodes/edges/impacts together.
    let graph_build_id = uuid::Uuid::new_v4().to_string();
    let engine = compliance_graph::GraphEngine::new(50_000);
    let (mut code_graph, build_run) =
        engine
            .build_graph(repo_path, repo_id, &graph_build_id)
            .map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
    // Apply community detection
    compliance_graph::graph::community::apply_communities(&mut code_graph);
    // Store graph in MongoDB — old graph data for this repo is deleted
    // first so the stored graph always reflects a single build.
    let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
    store
        .delete_repo_graph(repo_id)
        .await
        .map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
    store
        .store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
        .await
        .map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
    // Compute impact analysis for each finding that can be located in a file.
    let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
    let mut impacts = Vec::new();
    for finding in findings {
        if let Some(file_path) = &finding.file_path {
            let impact = analyzer.analyze(
                repo_id,
                &finding.fingerprint,
                &graph_build_id,
                file_path,
                finding.line_number,
            );
            store
                .store_impact(&impact)
                .await
                .map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
            impacts.push(impact);
        }
    }
    Ok(GraphContext {
        node_count: build_run.node_count,
        edge_count: build_run.edge_count,
        community_count: build_run.community_count,
        impacts,
    })
}
/// Trigger DAST scan if a target is configured for this repo.
///
/// Best-effort and non-blocking: a DB lookup failure returns silently, and
/// each configured target's scan runs in a detached `tokio::spawn` task
/// (the JoinHandle is dropped), so this method returns before the scans
/// finish. Results and errors are recorded from inside the task.
async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
    use futures_util::TryStreamExt;
    let filter = mongodb::bson::doc! { "repo_id": repo_id };
    let targets: Vec<compliance_core::models::DastTarget> =
        match self.db.dast_targets().find(filter).await {
            Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
            Err(_) => return,
        };
    if targets.is_empty() {
        tracing::info!("[{repo_id}] No DAST targets configured, skipping");
        return;
    }
    for target in targets {
        // Clone handles so the spawned task owns everything it touches.
        let db = self.db.clone();
        let scan_run_id = scan_run_id.to_string();
        tokio::spawn(async move {
            let orchestrator = compliance_dast::DastOrchestrator::new(100);
            match orchestrator.run_scan(&target, Vec::new()).await {
                Ok((mut scan_run, findings)) => {
                    // Link the DAST run back to the SAST scan that triggered it.
                    scan_run.sast_scan_run_id = Some(scan_run_id);
                    if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
                        tracing::error!("Failed to store DAST scan run: {e}");
                    }
                    for finding in &findings {
                        if let Err(e) = db.dast_findings().insert_one(finding).await {
                            tracing::error!("Failed to store DAST finding: {e}");
                        }
                    }
                    tracing::info!("DAST scan complete: {} findings", findings.len());
                }
                Err(e) => {
                    tracing::error!("DAST scan failed: {e}");
                }
            }
        });
    }
}
/// Build an issue tracker client from a repository's tracker configuration.
/// Returns `None` if the repo has no tracker configured, if a required
/// token/URL is missing, or if client construction fails.
fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
    let tracker_type = repo.tracker_type.as_ref()?;
    // Per-repo token takes precedence, fall back to global config
    match tracker_type {
        TrackerType::GitHub => {
            let token = repo.tracker_token.clone().or_else(|| {
                self.config.github_token.as_ref().map(|t| {
                    use secrecy::ExposeSecret;
                    t.expose_secret().to_string()
                })
            })?;
            let secret = secrecy::SecretString::from(token);
            // GitHubTracker::new is fallible; a failure is logged, not fatal.
            match trackers::github::GitHubTracker::new(&secret) {
                Ok(t) => Some(TrackerDispatch::GitHub(t)),
                Err(e) => {
                    tracing::warn!("Failed to build GitHub tracker: {e}");
                    None
                }
            }
        }
        TrackerType::GitLab => {
            // Self-hosted URL from config, defaulting to gitlab.com.
            let base_url = self
                .config
                .gitlab_url
                .clone()
                .unwrap_or_else(|| "https://gitlab.com".to_string());
            let token = repo.tracker_token.clone().or_else(|| {
                self.config.gitlab_token.as_ref().map(|t| {
                    use secrecy::ExposeSecret;
                    t.expose_secret().to_string()
                })
            })?;
            let secret = secrecy::SecretString::from(token);
            Some(TrackerDispatch::GitLab(
                trackers::gitlab::GitLabTracker::new(base_url, secret),
            ))
        }
        TrackerType::Gitea => {
            // Gitea has no global-config fallback: the per-repo token is
            // required, and the API base URL is derived from the git URL.
            let token = repo.tracker_token.clone()?;
            let base_url = extract_base_url(&repo.git_url)?;
            let secret = secrecy::SecretString::from(token);
            Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
                base_url, secret,
            )))
        }
        TrackerType::Jira => {
            // Jira needs URL, email and project key from global config.
            let base_url = self.config.jira_url.clone()?;
            let email = self.config.jira_email.clone()?;
            let project_key = self.config.jira_project_key.clone()?;
            let token = repo.tracker_token.clone().or_else(|| {
                self.config.jira_api_token.as_ref().map(|t| {
                    use secrecy::ExposeSecret;
                    t.expose_secret().to_string()
                })
            })?;
            let secret = secrecy::SecretString::from(token);
            Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
                base_url,
                email,
                secret,
                project_key,
            )))
        }
    }
}
/// Create tracker issues for new findings (severity >= Medium).
/// Checks for duplicates via fingerprint search before creating.
///
/// Best-effort throughout: a missing tracker/owner/repo configuration or
/// an API failure on one finding is logged and does not abort the batch;
/// the method always returns `Ok(())`.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
async fn create_tracker_issues(
    &self,
    repo: &TrackedRepository,
    repo_id: &str,
    new_findings: &[Finding],
) -> Result<(), AgentError> {
    let tracker = match self.build_tracker(repo) {
        Some(t) => t,
        None => {
            tracing::info!("[{repo_id}] No issue tracker configured, skipping");
            return Ok(());
        }
    };
    let owner = match repo.tracker_owner.as_deref() {
        Some(o) => o,
        None => {
            tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
            return Ok(());
        }
    };
    let tracker_repo_name = match repo.tracker_repo.as_deref() {
        Some(r) => r,
        None => {
            tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
            return Ok(());
        }
    };
    // Only create issues for medium+ severity findings
    let actionable: Vec<&Finding> = new_findings
        .iter()
        .filter(|f| {
            matches!(
                f.severity,
                Severity::Medium | Severity::High | Severity::Critical
            )
        })
        .collect();
    if actionable.is_empty() {
        tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
        return Ok(());
    }
    tracing::info!(
        "[{repo_id}] Creating issues for {} findings via {}",
        actionable.len(),
        tracker.name()
    );
    let mut created = 0u32;
    for finding in actionable {
        let title = format!(
            "[{}] {}: {}",
            finding.severity, finding.scanner, finding.title
        );
        // Check if an issue already exists by fingerprint first, then by title.
        // A failed search is treated as "not found" so we fall through to
        // creation rather than silently dropping the finding.
        let mut found_existing = false;
        for search_term in [&finding.fingerprint, &title] {
            match tracker
                .find_existing_issue(owner, tracker_repo_name, search_term)
                .await
            {
                Ok(Some(existing)) => {
                    tracing::debug!(
                        "[{repo_id}] Issue already exists for '{}': {}",
                        search_term,
                        existing.external_url
                    );
                    found_existing = true;
                    break;
                }
                Ok(None) => {}
                Err(e) => {
                    tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
                }
            }
        }
        if found_existing {
            continue;
        }
        let body = format_issue_body(finding);
        let labels = vec![
            format!("severity:{}", finding.severity),
            format!("scanner:{}", finding.scanner),
            "compliance-scanner".to_string(),
        ];
        match tracker
            .create_issue(owner, tracker_repo_name, &title, &body, &labels)
            .await
        {
            Ok(mut issue) => {
                // Cross-link the tracker issue back to the finding's DB id.
                issue.finding_id = finding
                    .id
                    .as_ref()
                    .map(|id| id.to_hex())
                    .unwrap_or_default();
                // Update the finding with the issue URL (best-effort).
                if let Some(finding_id) = &finding.id {
                    let _ = self
                        .db
                        .findings()
                        .update_one(
                            doc! { "_id": finding_id },
                            doc! { "$set": { "tracker_issue_url": &issue.external_url } },
                        )
                        .await;
                }
                // Store the tracker issue record
                if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
                    tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
                }
                created += 1;
            }
            Err(e) => {
                tracing::warn!(
                    "[{repo_id}] Failed to create issue for {}: {e}",
                    finding.fingerprint
                );
            }
        }
    }
    tracing::info!("[{repo_id}] Created {created} tracker issues");
    Ok(())
}
/// Run an incremental scan on a PR diff and post review comments.
///
/// Flow: clone/fetch the repo, diff `base_sha..head_sha`, run semgrep over
/// the whole repo but keep only findings in changed files, add LLM
/// code-review findings on the diff, then post either a "clean" review or a
/// summary plus one inline comment per located finding. Posting failures
/// are logged, not returned; the method only errors on git/clone failures.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
pub async fn run_pr_review(
    &self,
    repo: &TrackedRepository,
    repo_id: &str,
    pr_number: u64,
    base_sha: &str,
    head_sha: &str,
) -> Result<(), AgentError> {
    let tracker = match self.build_tracker(repo) {
        Some(t) => t,
        None => {
            tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
            return Ok(());
        }
    };
    let owner = repo.tracker_owner.as_deref().unwrap_or("");
    let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
    if owner.is_empty() || tracker_repo_name.is_empty() {
        tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
        return Ok(());
    }
    // Clone/fetch the repo
    let creds = GitOps::make_repo_credentials(&self.config, repo);
    let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
    let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
    // Get diff between base and head
    let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
    if diff_files.is_empty() {
        tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
        return Ok(());
    }
    // Run semgrep on the full repo but we'll filter findings to changed files
    let changed_paths: std::collections::HashSet<String> =
        diff_files.iter().map(|f| f.path.clone()).collect();
    let mut pr_findings: Vec<Finding> = Vec::new();
    // SAST scan (semgrep) — a scanner failure degrades to LLM-review-only.
    match SemgrepScanner.scan(&repo_path, repo_id).await {
        Ok(output) => {
            for f in output.findings {
                if let Some(fp) = &f.file_path {
                    if changed_paths.contains(fp.as_str()) {
                        pr_findings.push(f);
                    }
                }
            }
        }
        Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
    }
    // LLM code review on the diff
    let reviewer = CodeReviewScanner::new(self.llm.clone());
    let review_output = reviewer
        .review_diff(&repo_path, repo_id, base_sha, head_sha)
        .await;
    pr_findings.extend(review_output.findings);
    if pr_findings.is_empty() {
        // Post a clean review
        if let Err(e) = tracker
            .create_pr_review(
                owner,
                tracker_repo_name,
                pr_number,
                "Compliance scan: no issues found in this PR.",
                Vec::new(),
            )
            .await
        {
            tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
        }
        return Ok(());
    }
    // Build review comments from findings — only findings with both a file
    // path and a line number can be attached inline.
    let mut review_comments = Vec::new();
    for finding in &pr_findings {
        if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
            let comment_body = format!(
                "**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
                finding.severity,
                finding.title,
                finding.description,
                finding.scanner,
                finding
                    .cwe
                    .as_deref()
                    .map(|c| format!("CWE: {c}"))
                    .unwrap_or_default(),
            );
            review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
                path: path.clone(),
                line,
                body: comment_body,
            });
        }
    }
    // Top-level summary lists every finding, located or not.
    let summary = format!(
        "Compliance scan found **{}** issue(s) in this PR:\n\n{}",
        pr_findings.len(),
        pr_findings
            .iter()
            .map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
            .collect::<Vec<_>>()
            .join("\n"),
    );
    if let Err(e) = tracker
        .create_pr_review(
            owner,
            tracker_repo_name,
            pr_number,
            &summary,
            review_comments,
        )
        .await
    {
        tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
    } else {
        tracing::info!(
            "[{repo_id}] Posted PR review on #{pr_number} with {} findings",
            pr_findings.len()
        );
    }
    Ok(())
}
async fn update_phase(&self, scan_run_id: &str, phase: &str) {
pub(super) async fn update_phase(&self, scan_run_id: &str, phase: &str) {
if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(scan_run_id) {
let _ = self
.db
@@ -917,9 +399,9 @@ impl PipelineOrchestrator {
}
/// Extract the scheme + host from a git URL.
/// e.g. "https://gitea.example.com/owner/repo.git" "https://gitea.example.com"
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" "https://gitea.example.com"
fn extract_base_url(git_url: &str) -> Option<String> {
/// e.g. "https://gitea.example.com/owner/repo.git" -> "https://gitea.example.com"
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" -> "https://gitea.example.com"
pub(super) fn extract_base_url(git_url: &str) -> Option<String> {
if let Some(rest) = git_url.strip_prefix("https://") {
let host = rest.split('/').next()?;
Some(format!("https://{host}"))
@@ -927,7 +409,7 @@ fn extract_base_url(git_url: &str) -> Option<String> {
let host = rest.split('/').next()?;
Some(format!("http://{host}"))
} else if let Some(rest) = git_url.strip_prefix("ssh://") {
// ssh://git@host:port/path extract host
// ssh://git@host:port/path -> extract host
let after_at = rest.find('@').map(|i| &rest[i + 1..]).unwrap_or(rest);
let host = after_at.split(&[':', '/'][..]).next()?;
Some(format!("https://{host}"))
@@ -940,48 +422,3 @@ fn extract_base_url(git_url: &str) -> Option<String> {
None
}
}
/// Render a finding as a markdown body for an issue tracker: header,
/// metadata, description, then optional location/code/remediation/fix
/// sections, ending with the dedup fingerprint.
fn format_issue_body(finding: &Finding) -> String {
    let mut out = format!("## {} Finding\n\n", finding.severity);
    out += &format!("**Scanner:** {}\n", finding.scanner);
    out += &format!("**Severity:** {}\n", finding.severity);
    if let Some(rule) = &finding.rule_id {
        out += &format!("**Rule:** {}\n", rule);
    }
    if let Some(cwe) = &finding.cwe {
        out += &format!("**CWE:** {}\n", cwe);
    }
    out += &format!("\n### Description\n\n{}\n", finding.description);
    if let Some(path) = &finding.file_path {
        out += &format!("\n### Location\n\n**File:** `{}`", path);
        if let Some(line) = finding.line_number {
            out += &format!(" (line {})", line);
        }
        out.push('\n');
    }
    if let Some(snippet) = &finding.code_snippet {
        out += &format!("\n### Code\n\n```\n{}\n```\n", snippet);
    }
    if let Some(remediation) = &finding.remediation {
        out += &format!("\n### Remediation\n\n{}\n", remediation);
    }
    if let Some(fix) = &finding.suggested_fix {
        out += &format!("\n### Suggested Fix\n\n```\n{}\n```\n", fix);
    }
    out += &format!(
        "\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
        finding.fingerprint
    );
    out
}

View File

@@ -256,3 +256,159 @@ fn walkdir(path: &Path) -> Result<Vec<walkdir::DirEntry>, CoreError> {
Ok(entries)
}
#[cfg(test)]
mod tests {
use super::*;
// --- compile_regex tests ---
// A valid pattern compiles and matches with word boundaries intact.
#[test]
fn compile_regex_valid_pattern() {
    let re = compile_regex(r"\bfoo\b");
    assert!(re.is_match("hello foo bar"));
    assert!(!re.is_match("foobar"));
}

// Invalid patterns fall back to a regex that matches nothing but "".
#[test]
fn compile_regex_invalid_pattern_returns_fallback() {
    // An invalid regex should return the fallback "^$" that only matches empty strings
    let re = compile_regex(r"[invalid");
    assert!(re.is_match(""));
    assert!(!re.is_match("anything"));
}

// --- GDPR pattern tests ---

// Logging calls that reference PII-like identifiers must be flagged.
#[test]
fn gdpr_pii_logging_matches() {
    let scanner = GdprPatternScanner::new();
    let pattern = &scanner.patterns[0]; // gdpr-pii-logging
    // Regex: (log|print|console\.|logger\.|tracing::)\s*[\.(].*\b(pii_keyword)\b
    assert!(pattern.pattern.is_match("console.log(email)"));
    assert!(pattern.pattern.is_match("console.log(user.ssn)"));
    assert!(pattern.pattern.is_match("print(phone_number)"));
    assert!(pattern.pattern.is_match("tracing::(ip_addr)"));
    assert!(pattern.pattern.is_match("log.debug(credit_card)"));
}

// Ordinary logging and plain assignments must not trip the PII rule.
#[test]
fn gdpr_pii_logging_no_false_positive() {
    let scanner = GdprPatternScanner::new();
    let pattern = &scanner.patterns[0];
    // Regular logging without PII fields should not match
    assert!(!pattern
        .pattern
        .is_match("logger.info(\"request completed\")"));
    assert!(!pattern.pattern.is_match("let email = user.email;"));
}

// Phrases about collecting/storing personal data trigger the consent rule.
#[test]
fn gdpr_no_consent_matches() {
    let scanner = GdprPatternScanner::new();
    let pattern = &scanner.patterns[1]; // gdpr-no-consent
    assert!(pattern.pattern.is_match("collect personal data"));
    assert!(pattern.pattern.is_match("store user_data in db"));
    assert!(pattern.pattern.is_match("save pii to disk"));
}

// User model declarations (Rust struct / Python class) are detected.
#[test]
fn gdpr_user_model_matches() {
    let scanner = GdprPatternScanner::new();
    let pattern = &scanner.patterns[2]; // gdpr-no-delete-endpoint
    assert!(pattern.pattern.is_match("struct User {"));
    assert!(pattern.pattern.is_match("class User(Model):"));
}

// Hard-coded retention/TTL/expiry values are detected.
#[test]
fn gdpr_hardcoded_retention_matches() {
    let scanner = GdprPatternScanner::new();
    let pattern = &scanner.patterns[3]; // gdpr-hardcoded-retention
    assert!(pattern.pattern.is_match("retention = 30"));
    assert!(pattern.pattern.is_match("ttl: 3600"));
    assert!(pattern.pattern.is_match("expire = 86400"));
}

// --- OAuth pattern tests ---

// The deprecated implicit grant (token response type) is flagged.
#[test]
fn oauth_implicit_grant_matches() {
    let scanner = OAuthPatternScanner::new();
    let pattern = &scanner.patterns[0]; // oauth-implicit-grant
    assert!(pattern.pattern.is_match("response_type = \"token\""));
    assert!(pattern.pattern.is_match("grant_type: implicit"));
    assert!(pattern.pattern.is_match("response_type='token'"));
}

// The recommended authorization-code flow must not trip the implicit rule.
#[test]
fn oauth_implicit_grant_no_false_positive() {
    let scanner = OAuthPatternScanner::new();
    let pattern = &scanner.patterns[0];
    assert!(!pattern.pattern.is_match("response_type = \"code\""));
    assert!(!pattern.pattern.is_match("grant_type: authorization_code"));
}

// Authorization-code usage is detected by the missing-PKCE rule.
#[test]
fn oauth_authorization_code_matches() {
    let scanner = OAuthPatternScanner::new();
    let pattern = &scanner.patterns[1]; // oauth-missing-pkce
    assert!(pattern.pattern.is_match("uses authorization_code flow"));
    assert!(pattern.pattern.is_match("authorization code grant"));
}
#[test]
fn oauth_token_localstorage_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
assert!(pattern
.pattern
.is_match("localStorage.setItem('access_token', tok)"));
assert!(pattern
.pattern
.is_match("localStorage.getItem(\"refresh_token\")"));
}
#[test]
fn oauth_token_localstorage_no_false_positive() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2];
assert!(!pattern
.pattern
.is_match("localStorage.setItem('theme', 'dark')"));
assert!(!pattern
.pattern
.is_match("sessionStorage.setItem('token', t)"));
}
#[test]
fn oauth_token_url_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[3]; // oauth-token-url
assert!(pattern.pattern.is_match("access_token = build_url(query)"));
assert!(pattern.pattern.is_match("bearer = url.param"));
}
// --- Pattern rule file extension filtering ---
#[test]
fn gdpr_patterns_cover_common_languages() {
let scanner = GdprPatternScanner::new();
for pattern in &scanner.patterns {
assert!(
pattern.file_extensions.contains(&"rs".to_string()),
"Pattern {} should cover .rs files",
pattern.id
);
}
}
#[test]
fn oauth_localstorage_only_js_ts() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
assert!(pattern.file_extensions.contains(&"js".to_string()));
assert!(pattern.file_extensions.contains(&"ts".to_string()));
assert!(!pattern.file_extensions.contains(&"rs".to_string()));
assert!(!pattern.file_extensions.contains(&"py".to_string()));
}
}

View File

@@ -0,0 +1,146 @@
use compliance_core::models::*;
use super::orchestrator::PipelineOrchestrator;
use crate::error::AgentError;
use crate::pipeline::code_review::CodeReviewScanner;
use crate::pipeline::git::GitOps;
use crate::pipeline::semgrep::SemgrepScanner;
use compliance_core::traits::Scanner;
impl PipelineOrchestrator {
    /// Run an incremental scan on a PR diff and post review comments.
    ///
    /// Only files changed between `base_sha` and `head_sha` are considered:
    /// semgrep runs over the whole checkout but its findings are filtered to
    /// the changed paths, and an LLM code review is run on the diff itself.
    /// The combined findings are posted to the configured tracker as a
    /// single PR review; when nothing is found, a "no issues" review is
    /// posted instead.
    ///
    /// Missing tracker configuration (no tracker, or unset owner/repo) is
    /// logged and treated as success so a mis-configured repository does
    /// not fail the pipeline.
    ///
    /// # Errors
    ///
    /// Returns `AgentError` when the repository cannot be cloned/fetched or
    /// the diff cannot be computed. Scanner and tracker-posting failures
    /// are only logged.
    #[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
    pub async fn run_pr_review(
        &self,
        repo: &TrackedRepository,
        repo_id: &str,
        pr_number: u64,
        base_sha: &str,
        head_sha: &str,
    ) -> Result<(), AgentError> {
        let tracker = match self.build_tracker(repo) {
            Some(t) => t,
            None => {
                tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
                return Ok(());
            }
        };
        let owner = repo.tracker_owner.as_deref().unwrap_or("");
        let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
        if owner.is_empty() || tracker_repo_name.is_empty() {
            tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
            return Ok(());
        }
        // Clone/fetch the repo
        let creds = GitOps::make_repo_credentials(&self.config, repo);
        let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
        let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
        // Get diff between base and head
        let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
        if diff_files.is_empty() {
            tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
            return Ok(());
        }
        // Run semgrep on the full repo but we'll filter findings to changed files
        let changed_paths: std::collections::HashSet<String> =
            diff_files.iter().map(|f| f.path.clone()).collect();
        let mut pr_findings: Vec<Finding> = Vec::new();
        // SAST scan (semgrep)
        match SemgrepScanner.scan(&repo_path, repo_id).await {
            Ok(output) => {
                for f in output.findings {
                    if let Some(fp) = &f.file_path {
                        if changed_paths.contains(fp.as_str()) {
                            pr_findings.push(f);
                        }
                    }
                }
            }
            Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
        }
        // LLM code review on the diff
        let reviewer = CodeReviewScanner::new(self.llm.clone());
        let review_output = reviewer
            .review_diff(&repo_path, repo_id, base_sha, head_sha)
            .await;
        pr_findings.extend(review_output.findings);
        if pr_findings.is_empty() {
            // Post a clean review
            if let Err(e) = tracker
                .create_pr_review(
                    owner,
                    tracker_repo_name,
                    pr_number,
                    "Compliance scan: no issues found in this PR.",
                    Vec::new(),
                )
                .await
            {
                tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
            }
            return Ok(());
        }
        // Build review comments from findings
        // Only findings with both a file path and a line number can be
        // attached as inline comments; the rest appear in the summary only.
        let mut review_comments = Vec::new();
        for finding in &pr_findings {
            if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
                let comment_body = format!(
                    "**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
                    finding.severity,
                    finding.title,
                    finding.description,
                    finding.scanner,
                    finding
                        .cwe
                        .as_deref()
                        .map(|c| format!("CWE: {c}"))
                        .unwrap_or_default(),
                );
                review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
                    path: path.clone(),
                    line,
                    body: comment_body,
                });
            }
        }
        let summary = format!(
            "Compliance scan found **{}** issue(s) in this PR:\n\n{}",
            pr_findings.len(),
            pr_findings
                .iter()
                .map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
                .collect::<Vec<_>>()
                .join("\n"),
        );
        if let Err(e) = tracker
            .create_pr_review(
                owner,
                tracker_repo_name,
                pr_number,
                &summary,
                review_comments,
            )
            .await
        {
            tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
        } else {
            tracing::info!(
                "[{repo_id}] Posted PR review on #{pr_number} with {} findings",
                pr_findings.len()
            );
        }
        Ok(())
    }
}

View File

@@ -0,0 +1,72 @@
use std::path::Path;
use compliance_core::CoreError;
/// A single vulnerability reported by `cargo audit`, reduced to the
/// fields the SBOM pipeline attaches to matching entries.
pub(super) struct AuditVuln {
    // Name of the affected crate, matched against `SbomEntry::name`.
    pub package: String,
    // Advisory identifier as reported by cargo-audit.
    pub id: String,
    // Advisory URL as reported by cargo-audit.
    pub url: String,
}
/// Run `cargo audit --json` in `repo_path` and parse the reported
/// vulnerabilities.
///
/// Returns an empty list when the repository has no `Cargo.lock`
/// (nothing to audit). The tool's exit status is deliberately ignored:
/// cargo-audit exits non-zero when it finds vulnerabilities, so only the
/// JSON payload matters. A missing or malformed JSON report is treated
/// as "no vulnerabilities" but logged, so a broken cargo-audit install
/// degrades gracefully instead of failing (or silently skipping) the scan.
///
/// # Errors
///
/// Returns `CoreError::Scanner` when the `cargo` binary cannot be spawned.
#[tracing::instrument(skip_all)]
pub(super) async fn run_cargo_audit(
    repo_path: &Path,
    _repo_id: &str,
) -> Result<Vec<AuditVuln>, CoreError> {
    let cargo_lock = repo_path.join("Cargo.lock");
    if !cargo_lock.exists() {
        // Not a Rust repo (or no lockfile committed) — nothing to audit.
        return Ok(Vec::new());
    }
    let output = tokio::process::Command::new("cargo")
        .args(["audit", "--json"])
        .current_dir(repo_path)
        // Clear any compiler wrapper (e.g. sccache) so the subcommand runs cleanly.
        .env("RUSTC_WRAPPER", "")
        .output()
        .await
        .map_err(|e| CoreError::Scanner {
            scanner: "cargo-audit".to_string(),
            source: Box::new(e),
        })?;
    let result: CargoAuditOutput = match serde_json::from_slice(&output.stdout) {
        Ok(parsed) => parsed,
        Err(e) => {
            // Best-effort: keep the scan alive, but surface the broken
            // report instead of silently dropping it.
            tracing::warn!("cargo-audit output could not be parsed, treating as empty: {e}");
            CargoAuditOutput {
                vulnerabilities: CargoAuditVulns { list: Vec::new() },
            }
        }
    };
    let vulns = result
        .vulnerabilities
        .list
        .into_iter()
        .map(|v| AuditVuln {
            package: v.advisory.package,
            id: v.advisory.id,
            url: v.advisory.url,
        })
        .collect();
    Ok(vulns)
}
// Cargo audit types
//
// Minimal mirror of the `cargo audit --json` report: only the fields this
// module reads are declared; serde ignores everything else.
#[derive(serde::Deserialize)]
struct CargoAuditOutput {
    vulnerabilities: CargoAuditVulns,
}
#[derive(serde::Deserialize)]
struct CargoAuditVulns {
    list: Vec<CargoAuditEntry>,
}
#[derive(serde::Deserialize)]
struct CargoAuditEntry {
    advisory: CargoAuditAdvisory,
}
#[derive(serde::Deserialize)]
struct CargoAuditAdvisory {
    id: String,
    package: String,
    url: String,
}

View File

@@ -1,3 +1,6 @@
mod cargo_audit;
mod syft;
use std::path::Path;
use compliance_core::models::{SbomEntry, ScanType, VulnRef};
@@ -23,7 +26,7 @@ impl Scanner for SbomScanner {
generate_lockfiles(repo_path).await;
// Run syft for SBOM generation
match run_syft(repo_path, repo_id).await {
match syft::run_syft(repo_path, repo_id).await {
Ok(syft_entries) => entries.extend(syft_entries),
Err(e) => tracing::warn!("syft failed: {e}"),
}
@@ -32,7 +35,7 @@ impl Scanner for SbomScanner {
enrich_cargo_licenses(repo_path, &mut entries).await;
// Run cargo-audit for Rust-specific vulns
match run_cargo_audit(repo_path, repo_id).await {
match cargo_audit::run_cargo_audit(repo_path, repo_id).await {
Ok(vulns) => merge_audit_vulns(&mut entries, vulns),
Err(e) => tracing::warn!("cargo-audit skipped: {e}"),
}
@@ -186,95 +189,7 @@ async fn enrich_cargo_licenses(repo_path: &Path, entries: &mut [SbomEntry]) {
}
}
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
let output = tokio::process::Command::new("syft")
.arg(repo_path)
.args(["-o", "cyclonedx-json"])
// Enable remote license lookups for all ecosystems
.env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVA_USE_NETWORK", "true")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "syft".to_string(),
source: Box::new(e),
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(CoreError::Scanner {
scanner: "syft".to_string(),
source: format!("syft exited with {}: {stderr}", output.status).into(),
});
}
let cdx: CycloneDxBom = serde_json::from_slice(&output.stdout)?;
let entries = cdx
.components
.unwrap_or_default()
.into_iter()
.map(|c| {
let package_manager = c
.purl
.as_deref()
.and_then(extract_ecosystem_from_purl)
.unwrap_or_else(|| "unknown".to_string());
let mut entry = SbomEntry::new(
repo_id.to_string(),
c.name,
c.version.unwrap_or_else(|| "unknown".to_string()),
package_manager,
);
entry.purl = c.purl;
entry.license = c.licenses.and_then(|ls| extract_license(&ls));
entry
})
.collect();
Ok(entries)
}
#[tracing::instrument(skip_all)]
async fn run_cargo_audit(repo_path: &Path, _repo_id: &str) -> Result<Vec<AuditVuln>, CoreError> {
let cargo_lock = repo_path.join("Cargo.lock");
if !cargo_lock.exists() {
return Ok(Vec::new());
}
let output = tokio::process::Command::new("cargo")
.args(["audit", "--json"])
.current_dir(repo_path)
.env("RUSTC_WRAPPER", "")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "cargo-audit".to_string(),
source: Box::new(e),
})?;
let result: CargoAuditOutput =
serde_json::from_slice(&output.stdout).unwrap_or_else(|_| CargoAuditOutput {
vulnerabilities: CargoAuditVulns { list: Vec::new() },
});
let vulns = result
.vulnerabilities
.list
.into_iter()
.map(|v| AuditVuln {
package: v.advisory.package,
id: v.advisory.id,
url: v.advisory.url,
})
.collect();
Ok(vulns)
}
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<cargo_audit::AuditVuln>) {
for vuln in vulns {
if let Some(entry) = entries.iter_mut().find(|e| e.name == vuln.package) {
entry.known_vulnerabilities.push(VulnRef {
@@ -287,65 +202,6 @@ fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
}
}
// CycloneDX JSON types
#[derive(serde::Deserialize)]
struct CycloneDxBom {
components: Option<Vec<CdxComponent>>,
}
#[derive(serde::Deserialize)]
struct CdxComponent {
name: String,
version: Option<String>,
#[serde(rename = "type")]
#[allow(dead_code)]
component_type: Option<String>,
purl: Option<String>,
licenses: Option<Vec<CdxLicenseWrapper>>,
}
#[derive(serde::Deserialize)]
struct CdxLicenseWrapper {
license: Option<CdxLicense>,
/// SPDX license expression (e.g. "MIT OR Apache-2.0")
expression: Option<String>,
}
#[derive(serde::Deserialize)]
struct CdxLicense {
id: Option<String>,
name: Option<String>,
}
// Cargo audit types
#[derive(serde::Deserialize)]
struct CargoAuditOutput {
vulnerabilities: CargoAuditVulns,
}
#[derive(serde::Deserialize)]
struct CargoAuditVulns {
list: Vec<CargoAuditEntry>,
}
#[derive(serde::Deserialize)]
struct CargoAuditEntry {
advisory: CargoAuditAdvisory,
}
#[derive(serde::Deserialize)]
struct CargoAuditAdvisory {
id: String,
package: String,
url: String,
}
struct AuditVuln {
package: String,
id: String,
url: String,
}
// Cargo metadata types
#[derive(serde::Deserialize)]
struct CargoMetadata {
@@ -358,49 +214,3 @@ struct CargoPackage {
version: String,
license: Option<String>,
}
/// Extract the best license string from CycloneDX license entries.
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
for entry in entries {
if let Some(ref expr) = entry.expression {
if !expr.is_empty() {
return Some(expr.clone());
}
}
}
// Second pass: collect license.id or license.name from all entries
let parts: Vec<String> = entries
.iter()
.filter_map(|e| {
e.license.as_ref().and_then(|lic| {
lic.id
.clone()
.or_else(|| lic.name.clone())
.filter(|s| !s.is_empty())
})
})
.collect();
if parts.is_empty() {
return None;
}
Some(parts.join(" OR "))
}
/// Extract the ecosystem/package-manager from a PURL string.
/// e.g. "pkg:npm/lodash@4.17.21" → "npm", "pkg:cargo/serde@1.0" → "cargo"
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
let rest = purl.strip_prefix("pkg:")?;
let ecosystem = rest.split('/').next()?;
if ecosystem.is_empty() {
return None;
}
// Normalise common PURL types to user-friendly names
let normalised = match ecosystem {
"golang" => "go",
"pypi" => "pip",
_ => ecosystem,
};
Some(normalised.to_string())
}

View File

@@ -0,0 +1,355 @@
use std::path::Path;
use compliance_core::models::SbomEntry;
use compliance_core::CoreError;
/// Run syft over `repo_path` and convert its CycloneDX JSON output into
/// `SbomEntry` records belonging to `repo_id`.
///
/// # Errors
///
/// Returns `CoreError::Scanner` when the syft binary cannot be spawned or
/// exits unsuccessfully, and propagates JSON errors when its output is not
/// a valid CycloneDX document.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub(super) async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
    let cmd_output = tokio::process::Command::new("syft")
        .arg(repo_path)
        .args(["-o", "cyclonedx-json"])
        // Enable remote license lookups for all ecosystems
        .env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
        .env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
        .env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
        .env("SYFT_JAVA_USE_NETWORK", "true")
        .output()
        .await
        .map_err(|e| CoreError::Scanner {
            scanner: "syft".to_string(),
            source: Box::new(e),
        })?;
    if !cmd_output.status.success() {
        let stderr = String::from_utf8_lossy(&cmd_output.stderr);
        return Err(CoreError::Scanner {
            scanner: "syft".to_string(),
            source: format!("syft exited with {}: {stderr}", cmd_output.status).into(),
        });
    }
    let bom: CycloneDxBom = serde_json::from_slice(&cmd_output.stdout)?;
    // Translate each CycloneDX component into an SbomEntry.
    let mut entries = Vec::new();
    for component in bom.components.unwrap_or_default() {
        let package_manager = component
            .purl
            .as_deref()
            .and_then(extract_ecosystem_from_purl)
            .unwrap_or_else(|| "unknown".to_string());
        let version = component.version.unwrap_or_else(|| "unknown".to_string());
        let mut entry = SbomEntry::new(
            repo_id.to_string(),
            component.name,
            version,
            package_manager,
        );
        entry.purl = component.purl;
        entry.license = component.licenses.and_then(|ls| extract_license(&ls));
        entries.push(entry);
    }
    Ok(entries)
}
// CycloneDX JSON types
//
// Minimal mirror of the CycloneDX BOM document emitted by
// `syft -o cyclonedx-json`; only the fields this module reads are declared.
#[derive(serde::Deserialize)]
struct CycloneDxBom {
    components: Option<Vec<CdxComponent>>,
}
#[derive(serde::Deserialize)]
struct CdxComponent {
    name: String,
    version: Option<String>,
    // Deserialized for schema completeness; currently unused.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    component_type: Option<String>,
    purl: Option<String>,
    licenses: Option<Vec<CdxLicenseWrapper>>,
}
#[derive(serde::Deserialize)]
struct CdxLicenseWrapper {
    license: Option<CdxLicense>,
    /// SPDX license expression (e.g. "MIT OR Apache-2.0")
    expression: Option<String>,
}
#[derive(serde::Deserialize)]
struct CdxLicense {
    id: Option<String>,
    name: Option<String>,
}
/// Extract the best license string from CycloneDX license entries.
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
for entry in entries {
if let Some(ref expr) = entry.expression {
if !expr.is_empty() {
return Some(expr.clone());
}
}
}
// Second pass: collect license.id or license.name from all entries
let parts: Vec<String> = entries
.iter()
.filter_map(|e| {
e.license.as_ref().and_then(|lic| {
lic.id
.clone()
.or_else(|| lic.name.clone())
.filter(|s| !s.is_empty())
})
})
.collect();
if parts.is_empty() {
return None;
}
Some(parts.join(" OR "))
}
/// Extract the ecosystem/package-manager from a PURL string.
///
/// The ecosystem is the first path segment after the `pkg:` scheme, e.g.
/// "pkg:npm/lodash@4.17.21" -> "npm", "pkg:cargo/serde@1.0" -> "cargo".
/// A couple of PURL types are normalised to the user-facing tool name
/// ("golang" -> "go", "pypi" -> "pip"). Returns `None` for strings that
/// are not purls or have an empty ecosystem segment.
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
    let without_scheme = purl.strip_prefix("pkg:")?;
    match without_scheme.split('/').next() {
        None | Some("") => None,
        // Normalise common PURL types to user-friendly names
        Some("golang") => Some("go".to_string()),
        Some("pypi") => Some("pip".to_string()),
        Some(other) => Some(other.to_string()),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers in this module: PURL ecosystem
    //! extraction, CycloneDX license selection, and BOM deserialization.
    use super::*;
    // --- extract_ecosystem_from_purl tests ---
    #[test]
    fn purl_npm() {
        assert_eq!(
            extract_ecosystem_from_purl("pkg:npm/lodash@4.17.21"),
            Some("npm".to_string())
        );
    }
    #[test]
    fn purl_cargo() {
        assert_eq!(
            extract_ecosystem_from_purl("pkg:cargo/serde@1.0.197"),
            Some("cargo".to_string())
        );
    }
    #[test]
    fn purl_golang_normalised() {
        // "golang" is mapped to the user-facing tool name "go".
        assert_eq!(
            extract_ecosystem_from_purl("pkg:golang/github.com/gin-gonic/gin@1.9.1"),
            Some("go".to_string())
        );
    }
    #[test]
    fn purl_pypi_normalised() {
        // "pypi" is mapped to the user-facing tool name "pip".
        assert_eq!(
            extract_ecosystem_from_purl("pkg:pypi/requests@2.31.0"),
            Some("pip".to_string())
        );
    }
    #[test]
    fn purl_maven() {
        assert_eq!(
            extract_ecosystem_from_purl("pkg:maven/org.apache.commons/commons-lang3@3.14.0"),
            Some("maven".to_string())
        );
    }
    #[test]
    fn purl_missing_prefix() {
        assert_eq!(extract_ecosystem_from_purl("npm/lodash@4.17.21"), None);
    }
    #[test]
    fn purl_empty_ecosystem() {
        assert_eq!(extract_ecosystem_from_purl("pkg:/lodash@4.17.21"), None);
    }
    #[test]
    fn purl_empty_string() {
        assert_eq!(extract_ecosystem_from_purl(""), None);
    }
    #[test]
    fn purl_just_prefix() {
        assert_eq!(extract_ecosystem_from_purl("pkg:"), None);
    }
    // --- extract_license tests ---
    #[test]
    fn license_from_expression() {
        let entries = vec![CdxLicenseWrapper {
            license: None,
            expression: Some("MIT OR Apache-2.0".to_string()),
        }];
        assert_eq!(
            extract_license(&entries),
            Some("MIT OR Apache-2.0".to_string())
        );
    }
    #[test]
    fn license_from_id() {
        let entries = vec![CdxLicenseWrapper {
            license: Some(CdxLicense {
                id: Some("MIT".to_string()),
                name: None,
            }),
            expression: None,
        }];
        assert_eq!(extract_license(&entries), Some("MIT".to_string()));
    }
    #[test]
    fn license_from_name_fallback() {
        // Without an SPDX id, the human-readable name is used instead.
        let entries = vec![CdxLicenseWrapper {
            license: Some(CdxLicense {
                id: None,
                name: Some("MIT License".to_string()),
            }),
            expression: None,
        }];
        assert_eq!(extract_license(&entries), Some("MIT License".to_string()));
    }
    #[test]
    fn license_expression_preferred_over_id() {
        let entries = vec![
            CdxLicenseWrapper {
                license: Some(CdxLicense {
                    id: Some("MIT".to_string()),
                    name: None,
                }),
                expression: None,
            },
            CdxLicenseWrapper {
                license: None,
                expression: Some("MIT AND Apache-2.0".to_string()),
            },
        ];
        // Expression should be preferred (first pass finds it)
        assert_eq!(
            extract_license(&entries),
            Some("MIT AND Apache-2.0".to_string())
        );
    }
    #[test]
    fn license_multiple_ids_joined() {
        let entries = vec![
            CdxLicenseWrapper {
                license: Some(CdxLicense {
                    id: Some("MIT".to_string()),
                    name: None,
                }),
                expression: None,
            },
            CdxLicenseWrapper {
                license: Some(CdxLicense {
                    id: Some("Apache-2.0".to_string()),
                    name: None,
                }),
                expression: None,
            },
        ];
        assert_eq!(
            extract_license(&entries),
            Some("MIT OR Apache-2.0".to_string())
        );
    }
    #[test]
    fn license_empty_entries() {
        let entries: Vec<CdxLicenseWrapper> = vec![];
        assert_eq!(extract_license(&entries), None);
    }
    #[test]
    fn license_all_empty_strings() {
        // Empty strings in every field count as "no license information".
        let entries = vec![CdxLicenseWrapper {
            license: Some(CdxLicense {
                id: Some(String::new()),
                name: Some(String::new()),
            }),
            expression: Some(String::new()),
        }];
        assert_eq!(extract_license(&entries), None);
    }
    #[test]
    fn license_none_fields() {
        let entries = vec![CdxLicenseWrapper {
            license: None,
            expression: None,
        }];
        assert_eq!(extract_license(&entries), None);
    }
    // --- CycloneDX deserialization tests ---
    #[test]
    fn deserialize_cyclonedx_bom() {
        let json = r#"{
            "components": [
                {
                    "name": "serde",
                    "version": "1.0.197",
                    "type": "library",
                    "purl": "pkg:cargo/serde@1.0.197",
                    "licenses": [
                        {"expression": "MIT OR Apache-2.0"}
                    ]
                }
            ]
        }"#;
        let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
        let components = bom.components.unwrap();
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].name, "serde");
        assert_eq!(components[0].version, Some("1.0.197".to_string()));
        assert_eq!(
            components[0].purl,
            Some("pkg:cargo/serde@1.0.197".to_string())
        );
    }
    #[test]
    fn deserialize_cyclonedx_no_components() {
        let json = r#"{}"#;
        let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
        assert!(bom.components.is_none());
    }
    #[test]
    fn deserialize_cyclonedx_minimal_component() {
        // Only "name" is mandatory; all other component fields are optional.
        let json = r#"{"components": [{"name": "foo"}]}"#;
        let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
        let c = &bom.components.unwrap()[0];
        assert_eq!(c.name, "foo");
        assert!(c.version.is_none());
        assert!(c.purl.is_none());
        assert!(c.licenses.is_none());
    }
}

View File

@@ -108,3 +108,124 @@ struct SemgrepExtra {
#[serde(default)]
metadata: Option<serde_json::Value>,
}
#[cfg(test)]
mod tests {
    //! Unit tests for deserializing `semgrep --json` output.
    use super::*;
    #[test]
    fn deserialize_semgrep_output() {
        let json = r#"{
            "results": [
                {
                    "check_id": "python.lang.security.audit.exec-detected",
                    "path": "src/main.py",
                    "start": {"line": 15},
                    "extra": {
                        "message": "Detected use of exec()",
                        "severity": "ERROR",
                        "lines": "exec(user_input)",
                        "metadata": {"cwe": "CWE-78"}
                    }
                }
            ]
        }"#;
        let output: SemgrepOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.results.len(), 1);
        let r = &output.results[0];
        assert_eq!(r.check_id, "python.lang.security.audit.exec-detected");
        assert_eq!(r.path, "src/main.py");
        assert_eq!(r.start.line, 15);
        assert_eq!(r.extra.message, "Detected use of exec()");
        assert_eq!(r.extra.severity, "ERROR");
        assert_eq!(r.extra.lines, "exec(user_input)");
        assert_eq!(
            r.extra
                .metadata
                .as_ref()
                .unwrap()
                .get("cwe")
                .unwrap()
                .as_str(),
            Some("CWE-78")
        );
    }
    #[test]
    fn deserialize_semgrep_empty_results() {
        let json = r#"{"results": []}"#;
        let output: SemgrepOutput = serde_json::from_str(json).unwrap();
        assert!(output.results.is_empty());
    }
    #[test]
    fn deserialize_semgrep_no_metadata() {
        // "metadata" is optional in semgrep output and must default to None.
        let json = r#"{
            "results": [
                {
                    "check_id": "rule-1",
                    "path": "app.py",
                    "start": {"line": 1},
                    "extra": {
                        "message": "found something",
                        "severity": "WARNING",
                        "lines": "import os"
                    }
                }
            ]
        }"#;
        let output: SemgrepOutput = serde_json::from_str(json).unwrap();
        assert!(output.results[0].extra.metadata.is_none());
    }
    // NOTE(review): this test re-implements the severity mapping inline
    // instead of calling the scanner's real mapping code, so it asserts its
    // own copy of the logic and cannot catch a regression in the scanner —
    // consider exercising the actual mapping function here.
    #[test]
    fn semgrep_severity_mapping() {
        let cases = vec![
            ("ERROR", "High"),
            ("WARNING", "Medium"),
            ("INFO", "Low"),
            ("UNKNOWN", "Info"),
        ];
        for (input, expected) in cases {
            let result = match input {
                "ERROR" => "High",
                "WARNING" => "Medium",
                "INFO" => "Low",
                _ => "Info",
            };
            assert_eq!(result, expected, "Severity for '{input}'");
        }
    }
    #[test]
    fn deserialize_semgrep_multiple_results() {
        let json = r#"{
            "results": [
                {
                    "check_id": "rule-a",
                    "path": "a.py",
                    "start": {"line": 1},
                    "extra": {
                        "message": "msg a",
                        "severity": "ERROR",
                        "lines": "line a"
                    }
                },
                {
                    "check_id": "rule-b",
                    "path": "b.py",
                    "start": {"line": 99},
                    "extra": {
                        "message": "msg b",
                        "severity": "INFO",
                        "lines": "line b"
                    }
                }
            ]
        }"#;
        let output: SemgrepOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.results.len(), 2);
        assert_eq!(output.results[1].start.line, 99);
    }
}

View File

@@ -0,0 +1,81 @@
use compliance_core::models::TrackerIssue;
use compliance_core::traits::issue_tracker::IssueTracker;
use crate::trackers;
/// Enum dispatch for issue trackers (async traits aren't dyn-compatible).
///
/// Each variant wraps one concrete tracker implementation; the methods on
/// this type forward to the wrapped tracker's `IssueTracker` impl.
pub(crate) enum TrackerDispatch {
    GitHub(trackers::github::GitHubTracker),
    GitLab(trackers::gitlab::GitLabTracker),
    Gitea(trackers::gitea::GiteaTracker),
    Jira(trackers::jira::JiraTracker),
}
impl TrackerDispatch {
    /// Name of the underlying tracker implementation.
    pub(crate) fn name(&self) -> &str {
        match self {
            Self::GitHub(t) => t.name(),
            Self::GitLab(t) => t.name(),
            Self::Gitea(t) => t.name(),
            Self::Jira(t) => t.name(),
        }
    }
    /// Create an issue on the wrapped tracker; forwards to
    /// `IssueTracker::create_issue`.
    pub(crate) async fn create_issue(
        &self,
        owner: &str,
        repo: &str,
        title: &str,
        body: &str,
        labels: &[String],
    ) -> Result<TrackerIssue, compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
            Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
        }
    }
    /// Look up an already-filed issue by finding fingerprint; forwards to
    /// `IssueTracker::find_existing_issue`.
    pub(crate) async fn find_existing_issue(
        &self,
        owner: &str,
        repo: &str,
        fingerprint: &str,
    ) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
            Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
        }
    }
    /// Post a PR review (summary plus inline comments); forwards to
    /// `IssueTracker::create_pr_review`.
    pub(crate) async fn create_pr_review(
        &self,
        owner: &str,
        repo: &str,
        pr_number: u64,
        body: &str,
        comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
    ) -> Result<(), compliance_core::error::CoreError> {
        match self {
            Self::GitHub(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::GitLab(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::Gitea(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
            Self::Jira(t) => {
                t.create_pr_review(owner, repo, pr_number, body, comments)
                    .await
            }
        }
    }
}

View File

@@ -0,0 +1,3 @@
// Shared test helpers for compliance-agent integration tests.
//
// Add database mocks, fixtures, and test utilities here.

View File

@@ -0,0 +1,4 @@
// Integration tests for the compliance-agent crate.
//
// Add tests that exercise the full pipeline, API handlers,
// and cross-module interactions here.

View File

@@ -250,7 +250,11 @@ pub enum PentestEvent {
findings_count: u32,
},
/// A new finding was discovered
Finding { finding_id: String, title: String, severity: String },
Finding {
finding_id: String,
title: String,
severity: String,
},
/// Assistant message (streaming text)
Message { content: String },
/// Session completed

View File

@@ -0,0 +1,475 @@
use compliance_core::models::*;
// ─── Severity ───
#[test]
fn severity_display_all_variants() {
    assert_eq!(Severity::Info.to_string(), "info");
    assert_eq!(Severity::Low.to_string(), "low");
    assert_eq!(Severity::Medium.to_string(), "medium");
    assert_eq!(Severity::High.to_string(), "high");
    assert_eq!(Severity::Critical.to_string(), "critical");
}
#[test]
fn severity_ordering() {
    // Severity must order from least to most severe.
    assert!(Severity::Info < Severity::Low);
    assert!(Severity::Low < Severity::Medium);
    assert!(Severity::Medium < Severity::High);
    assert!(Severity::High < Severity::Critical);
}
#[test]
fn severity_serde_roundtrip() {
    for sev in [
        Severity::Info,
        Severity::Low,
        Severity::Medium,
        Severity::High,
        Severity::Critical,
    ] {
        let json = serde_json::to_string(&sev).unwrap();
        let back: Severity = serde_json::from_str(&json).unwrap();
        assert_eq!(sev, back);
    }
}
#[test]
fn severity_deserialize_lowercase() {
    let s: Severity = serde_json::from_str(r#""critical""#).unwrap();
    assert_eq!(s, Severity::Critical);
}
// ─── FindingStatus ───
#[test]
fn finding_status_display_all_variants() {
    assert_eq!(FindingStatus::Open.to_string(), "open");
    assert_eq!(FindingStatus::Triaged.to_string(), "triaged");
    assert_eq!(FindingStatus::FalsePositive.to_string(), "false_positive");
    assert_eq!(FindingStatus::Resolved.to_string(), "resolved");
    assert_eq!(FindingStatus::Ignored.to_string(), "ignored");
}
#[test]
fn finding_status_serde_roundtrip() {
    for status in [
        FindingStatus::Open,
        FindingStatus::Triaged,
        FindingStatus::FalsePositive,
        FindingStatus::Resolved,
        FindingStatus::Ignored,
    ] {
        let json = serde_json::to_string(&status).unwrap();
        let back: FindingStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(status, back);
    }
}
// ─── Finding ───
#[test]
fn finding_new_defaults() {
    // A freshly constructed Finding keeps the required fields and leaves
    // every optional field unset, with status starting at Open.
    let f = Finding::new(
        "repo1".into(),
        "fp123".into(),
        "semgrep".into(),
        ScanType::Sast,
        "Test title".into(),
        "Test desc".into(),
        Severity::High,
    );
    assert_eq!(f.repo_id, "repo1");
    assert_eq!(f.fingerprint, "fp123");
    assert_eq!(f.scanner, "semgrep");
    assert_eq!(f.scan_type, ScanType::Sast);
    assert_eq!(f.severity, Severity::High);
    assert_eq!(f.status, FindingStatus::Open);
    assert!(f.id.is_none());
    assert!(f.rule_id.is_none());
    assert!(f.confidence.is_none());
    assert!(f.file_path.is_none());
    assert!(f.remediation.is_none());
    assert!(f.suggested_fix.is_none());
    assert!(f.triage_action.is_none());
    assert!(f.developer_feedback.is_none());
}
// ─── ScanType ───
#[test]
fn scan_type_display_all_variants() {
    let cases = vec![
        (ScanType::Sast, "sast"),
        (ScanType::Sbom, "sbom"),
        (ScanType::Cve, "cve"),
        (ScanType::Gdpr, "gdpr"),
        (ScanType::OAuth, "oauth"),
        (ScanType::Graph, "graph"),
        (ScanType::Dast, "dast"),
        (ScanType::SecretDetection, "secret_detection"),
        (ScanType::Lint, "lint"),
        (ScanType::CodeReview, "code_review"),
    ];
    for (variant, expected) in cases {
        assert_eq!(variant.to_string(), expected);
    }
}
#[test]
fn scan_type_serde_roundtrip() {
    for st in [
        ScanType::Sast,
        ScanType::SecretDetection,
        ScanType::CodeReview,
    ] {
        let json = serde_json::to_string(&st).unwrap();
        let back: ScanType = serde_json::from_str(&json).unwrap();
        assert_eq!(st, back);
    }
}
// ─── ScanRun ───
#[test]
fn scan_run_new_defaults() {
    // New runs start in Running status at the ChangeDetection phase.
    let sr = ScanRun::new("repo1".into(), ScanTrigger::Manual);
    assert_eq!(sr.repo_id, "repo1");
    assert_eq!(sr.trigger, ScanTrigger::Manual);
    assert_eq!(sr.status, ScanRunStatus::Running);
    assert_eq!(sr.current_phase, ScanPhase::ChangeDetection);
    assert!(sr.phases_completed.is_empty());
    assert_eq!(sr.new_findings_count, 0);
    assert!(sr.error_message.is_none());
    assert!(sr.completed_at.is_none());
}
// ─── PentestStatus ───
#[test]
fn pentest_status_display() {
assert_eq!(pentest::PentestStatus::Running.to_string(), "running");
assert_eq!(pentest::PentestStatus::Paused.to_string(), "paused");
assert_eq!(pentest::PentestStatus::Completed.to_string(), "completed");
assert_eq!(pentest::PentestStatus::Failed.to_string(), "failed");
}
// ─── PentestStrategy ───
#[test]
fn pentest_strategy_display() {
assert_eq!(pentest::PentestStrategy::Quick.to_string(), "quick");
assert_eq!(
pentest::PentestStrategy::Comprehensive.to_string(),
"comprehensive"
);
assert_eq!(pentest::PentestStrategy::Targeted.to_string(), "targeted");
assert_eq!(
pentest::PentestStrategy::Aggressive.to_string(),
"aggressive"
);
assert_eq!(pentest::PentestStrategy::Stealth.to_string(), "stealth");
}
// ─── PentestSession ───
#[test]
fn pentest_session_new_defaults() {
    // A new session starts Running with zeroed counters and no completion
    // time or linked repository.
    let session = pentest::PentestSession::new("target1".into(), pentest::PentestStrategy::Quick);
    assert_eq!(session.target_id, "target1");
    assert_eq!(session.status, pentest::PentestStatus::Running);
    assert_eq!(session.strategy, pentest::PentestStrategy::Quick);
    assert_eq!(session.tool_invocations, 0);
    assert_eq!(session.tool_successes, 0);
    assert_eq!(session.findings_count, 0);
    assert!(session.completed_at.is_none());
    assert!(session.repo_id.is_none());
}
#[test]
fn pentest_session_success_rate_zero_invocations() {
    // With nothing invoked yet the rate is reported as a perfect 100%.
    let session =
        pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Comprehensive);
    assert_eq!(session.success_rate(), 100.0);
}
#[test]
fn pentest_session_success_rate_calculation() {
    // 7 successes out of 10 invocations is a 70% success rate.
    let mut session =
        pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Comprehensive);
    session.tool_invocations = 10;
    session.tool_successes = 7;
    assert!((session.success_rate() - 70.0).abs() < f64::EPSILON);
}
#[test]
fn pentest_session_success_rate_all_success() {
    // Every invocation succeeding yields exactly 100%.
    let mut session = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
    session.tool_invocations = 5;
    session.tool_successes = 5;
    assert_eq!(session.success_rate(), 100.0);
}
#[test]
fn pentest_session_success_rate_none_success() {
    // No successes at all yields exactly 0%.
    let mut session = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
    session.tool_invocations = 3;
    session.tool_successes = 0;
    assert_eq!(session.success_rate(), 0.0);
}
// ─── PentestMessage factories ───
#[test]
fn pentest_message_user() {
    // The user factory sets the role, carries session id and content
    // through verbatim, and leaves attack metadata unset.
    let m = pentest::PentestMessage::user("sess1".into(), "hello".into());
    assert_eq!(m.role, "user");
    assert_eq!(m.session_id, "sess1");
    assert_eq!(m.content, "hello");
    assert!(m.attack_node_id.is_none());
    assert!(m.tool_calls.is_none());
}
#[test]
fn pentest_message_assistant() {
    // Previously only `role` was asserted; also pin session id and content
    // so a regression in argument handling would be caught, mirroring the
    // coverage of `pentest_message_user`.
    let m = pentest::PentestMessage::assistant("sess1".into(), "response".into());
    assert_eq!(m.role, "assistant");
    assert_eq!(m.session_id, "sess1");
    assert_eq!(m.content, "response");
}
#[test]
fn pentest_message_tool_result() {
    // Tool results additionally link back to the producing attack-chain
    // node; session id and payload content must also survive.
    let m = pentest::PentestMessage::tool_result("sess1".into(), "output".into(), "node1".into());
    assert_eq!(m.role, "tool_result");
    assert_eq!(m.session_id, "sess1");
    assert_eq!(m.content, "output");
    assert_eq!(m.attack_node_id, Some("node1".to_string()));
}
// ─── AttackChainNode ───
#[test]
fn attack_chain_node_new_defaults() {
    // A new node is Pending with no lineage, findings, risk score, or
    // start timestamp; constructor arguments are stored verbatim.
    let node = pentest::AttackChainNode::new(
        "sess1".into(),
        "node1".into(),
        "recon".into(),
        serde_json::json!({"target": "example.com"}),
        "Starting recon".into(),
    );
    assert_eq!(node.session_id, "sess1");
    assert_eq!(node.node_id, "node1");
    assert_eq!(node.tool_name, "recon");
    assert_eq!(node.status, pentest::AttackNodeStatus::Pending);
    assert!(node.parent_node_ids.is_empty());
    assert!(node.findings_produced.is_empty());
    assert!(node.risk_score.is_none());
    assert!(node.started_at.is_none());
}
// ─── DastTarget ───
#[test]
fn dast_target_new_defaults() {
    // Defaults: crawl depth 5, rate limit 10, non-destructive, and no
    // optional configuration attached.
    let target = dast::DastTarget::new(
        "My App".into(),
        "https://example.com".into(),
        dast::DastTargetType::WebApp,
    );
    assert_eq!(target.name, "My App");
    assert_eq!(target.base_url, "https://example.com");
    assert_eq!(target.target_type, dast::DastTargetType::WebApp);
    assert_eq!(target.max_crawl_depth, 5);
    assert_eq!(target.rate_limit, 10);
    assert!(!target.allow_destructive);
    assert!(target.excluded_paths.is_empty());
    assert!(target.auth_config.is_none());
    assert!(target.repo_id.is_none());
}
#[test]
fn dast_target_type_display() {
    // Display renders the snake_case wire name for each target type.
    let cases = [
        (dast::DastTargetType::WebApp, "webapp"),
        (dast::DastTargetType::RestApi, "rest_api"),
        (dast::DastTargetType::GraphQl, "graphql"),
    ];
    for (target_type, expected) in cases {
        assert_eq!(target_type.to_string(), expected);
    }
}
// ─── DastScanRun ───
#[test]
fn dast_scan_run_new_defaults() {
    // A new run starts Running in the Reconnaissance phase with all
    // counters at zero and no completion timestamp.
    let sr = dast::DastScanRun::new("target1".into());
    assert_eq!(sr.status, dast::DastScanStatus::Running);
    assert_eq!(sr.current_phase, dast::DastScanPhase::Reconnaissance);
    assert!(sr.phases_completed.is_empty());
    assert_eq!(sr.endpoints_discovered, 0);
    assert_eq!(sr.findings_count, 0);
    // BUG FIX: the previous `assert!(!sr.exploitable_count > 0)` applied
    // bitwise NOT to the integer (`!0 == MAX`), so it was vacuously true
    // for every value. Assert the intended default of zero instead.
    assert_eq!(sr.exploitable_count, 0);
    assert!(sr.completed_at.is_none());
}
#[test]
fn dast_scan_phase_display() {
    // Spot-check Display for the first, a middle, and the terminal phase.
    let cases = [
        (dast::DastScanPhase::Reconnaissance, "reconnaissance"),
        (dast::DastScanPhase::Crawling, "crawling"),
        (dast::DastScanPhase::Completed, "completed"),
    ];
    for (phase, expected) in cases {
        assert_eq!(phase.to_string(), expected);
    }
}
// ─── DastVulnType ───
#[test]
fn dast_vuln_type_display_all_variants() {
    // Display must emit the snake_case wire name for every vuln type.
    let expectations = [
        (dast::DastVulnType::SqlInjection, "sql_injection"),
        (dast::DastVulnType::Xss, "xss"),
        (dast::DastVulnType::AuthBypass, "auth_bypass"),
        (dast::DastVulnType::Ssrf, "ssrf"),
        (dast::DastVulnType::Idor, "idor"),
        (dast::DastVulnType::Other, "other"),
    ];
    for (vuln, expected) in expectations {
        assert_eq!(vuln.to_string(), expected);
    }
}
// ─── DastFinding ───
#[test]
fn dast_finding_new_defaults() {
    // New findings keep the constructor arguments and default to
    // non-exploitable with no evidence or cross-links.
    let finding = dast::DastFinding::new(
        "run1".into(),
        "target1".into(),
        dast::DastVulnType::Xss,
        "XSS in search".into(),
        "Reflected XSS".into(),
        Severity::High,
        "https://example.com/search".into(),
        "GET".into(),
    );
    assert_eq!(finding.vuln_type, dast::DastVulnType::Xss);
    assert_eq!(finding.severity, Severity::High);
    assert!(!finding.exploitable);
    assert!(finding.evidence.is_empty());
    assert!(finding.session_id.is_none());
    assert!(finding.linked_sast_finding_id.is_none());
}
// ─── SbomEntry ───
#[test]
fn sbom_entry_new_defaults() {
    // Constructor arguments are stored verbatim; optional metadata and
    // vulnerability links start empty.
    let entry = SbomEntry::new(
        "repo1".into(),
        "lodash".into(),
        "4.17.21".into(),
        "npm".into(),
    );
    assert_eq!(entry.name, "lodash");
    assert_eq!(entry.version, "4.17.21");
    assert_eq!(entry.package_manager, "npm");
    assert!(entry.license.is_none());
    assert!(entry.purl.is_none());
    assert!(entry.known_vulnerabilities.is_empty());
}
// ─── TrackedRepository ───
#[test]
fn tracked_repository_new_defaults() {
    // New repositories default to the `main` branch with webhooks disabled
    // but a secret pre-generated for later enablement.
    let repo = TrackedRepository::new("My Repo".into(), "https://github.com/org/repo.git".into());
    assert_eq!(repo.name, "My Repo");
    assert_eq!(repo.git_url, "https://github.com/org/repo.git");
    assert_eq!(repo.default_branch, "main");
    assert!(!repo.webhook_enabled);
    // A UUID with the dashes stripped is exactly 32 hex characters.
    let secret = repo.webhook_secret.as_ref().expect("secret is pre-generated");
    assert_eq!(secret.len(), 32);
    assert!(repo.tracker_type.is_none());
    assert_eq!(repo.findings_count, 0);
}
// ─── ScanTrigger ───
#[test]
fn scan_trigger_serde_roundtrip() {
    // Every trigger must survive a JSON encode/decode cycle unchanged.
    let triggers = [
        ScanTrigger::Scheduled,
        ScanTrigger::Webhook,
        ScanTrigger::Manual,
    ];
    for trigger in triggers {
        let encoded = serde_json::to_string(&trigger).unwrap();
        let decoded: ScanTrigger = serde_json::from_str(&encoded).unwrap();
        assert_eq!(trigger, decoded);
    }
}
// ─── PentestEvent serde (tagged enum) ───
#[test]
fn pentest_event_serde_thinking() {
    // The enum serializes with a "type" discriminator next to the
    // variant's payload fields.
    let event = pentest::PentestEvent::Thinking {
        reasoning: "analyzing target".into(),
    };
    let encoded = serde_json::to_string(&event).unwrap();
    assert!(encoded.contains(r#""type":"thinking""#));
    assert!(encoded.contains("analyzing target"));
}
#[test]
fn pentest_event_serde_finding() {
    // A Finding event must round-trip through JSON with every field intact.
    let event = pentest::PentestEvent::Finding {
        finding_id: "f1".into(),
        title: "XSS".into(),
        severity: "high".into(),
    };
    let encoded = serde_json::to_string(&event).unwrap();
    let decoded: pentest::PentestEvent = serde_json::from_str(&encoded).unwrap();
    match decoded {
        pentest::PentestEvent::Finding {
            finding_id,
            title,
            severity,
        } => {
            assert_eq!(finding_id, "f1");
            assert_eq!(title, "XSS");
            assert_eq!(severity, "high");
        }
        _ => panic!("wrong variant"),
    }
}
// ─── Serde helpers (BSON datetime) ───
#[test]
fn bson_datetime_roundtrip_via_finding() {
    // Timestamps serialized through BSON keep millisecond precision, so a
    // round-trip must land within one second of the original.
    let finding = Finding::new(
        "repo1".into(),
        "fp".into(),
        "test".into(),
        ScanType::Sast,
        "t".into(),
        "d".into(),
        Severity::Low,
    );
    let doc = bson::to_document(&finding).unwrap();
    let decoded: Finding = bson::from_document(doc).unwrap();
    let drift_ms = (decoded.created_at - finding.created_at)
        .num_milliseconds()
        .abs();
    assert!(drift_ms < 1000);
}
#[test]
fn opt_bson_datetime_roundtrip_with_none() {
    // An absent completion time must stay None across a BSON round-trip.
    let session = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
    assert!(session.completed_at.is_none());
    let doc = bson::to_document(&session).unwrap();
    let decoded: pentest::PentestSession = bson::from_document(doc).unwrap();
    assert!(decoded.completed_at.is_none());
}
#[test]
fn opt_bson_datetime_roundtrip_with_some() {
    // A present completion time must stay Some across a BSON round-trip.
    let mut session = pentest::PentestSession::new("t".into(), pentest::PentestStrategy::Quick);
    session.completed_at = Some(chrono::Utc::now());
    let doc = bson::to_document(&session).unwrap();
    let decoded: pentest::PentestSession = bson::from_document(doc).unwrap();
    assert!(decoded.completed_at.is_some());
}

View File

@@ -0,0 +1,283 @@
use std::collections::{HashMap, VecDeque};
/// Map a tool name to its category CSS class.
///
/// Matching is case-insensitive substring search; the first matching rule
/// group wins, so more specific groups (e.g. "openapi") are listed before
/// broader ones. Unknown names fall through to `"default"`.
pub(crate) fn tool_category(name: &str) -> &'static str {
    // Ordered (needles, category) table — order is significant and mirrors
    // the original cascade of checks.
    const RULES: &[(&[&str], &str)] = &[
        (&["recon"], "recon"),
        (&["openapi", "api", "swagger"], "api"),
        (&["header"], "headers"),
        (&["csp"], "csp"),
        (&["cookie"], "cookies"),
        (&["log", "console"], "logs"),
        (&["rate", "limit"], "ratelimit"),
        (&["cors"], "cors"),
        (&["tls", "ssl"], "tls"),
        (&["redirect"], "redirect"),
        (&["dns", "dmarc", "email", "spf"], "email"),
        (&["auth", "jwt", "token", "session"], "auth"),
        (&["xss"], "xss"),
        (&["sql", "sqli"], "sqli"),
        (&["ssrf"], "ssrf"),
        (&["idor"], "idor"),
        (&["fuzz"], "fuzzer"),
        (&["cve", "exploit"], "cve"),
    ];
    let lower = name.to_lowercase();
    for &(needles, category) in RULES {
        if needles.iter().any(|needle| lower.contains(*needle)) {
            return category;
        }
    }
    "default"
}
/// Map a tool category to its emoji icon; unknown categories get a wrench.
pub(crate) fn tool_emoji(cat: &str) -> &'static str {
    // Lookup table kept in the same order as `cat_label` for easy diffing.
    const ICONS: &[(&str, &str)] = &[
        ("recon", "\u{1F50D}"),
        ("api", "\u{1F517}"),
        ("headers", "\u{1F6E1}"),
        ("csp", "\u{1F6A7}"),
        ("cookies", "\u{1F36A}"),
        ("logs", "\u{1F4DD}"),
        ("ratelimit", "\u{23F1}"),
        ("cors", "\u{1F30D}"),
        ("tls", "\u{1F510}"),
        ("redirect", "\u{21AA}"),
        ("email", "\u{1F4E7}"),
        ("auth", "\u{1F512}"),
        ("xss", "\u{26A1}"),
        ("sqli", "\u{1F489}"),
        ("ssrf", "\u{1F310}"),
        ("idor", "\u{1F511}"),
        ("fuzzer", "\u{1F9EA}"),
        ("cve", "\u{1F4A3}"),
    ];
    ICONS
        .iter()
        .find(|(key, _)| *key == cat)
        .map(|(_, icon)| *icon)
        .unwrap_or("\u{1F527}")
}
/// Map a tool category to its human-readable chip label; unknown
/// categories display as "Other".
pub(crate) fn cat_label(cat: &str) -> &'static str {
    const LABELS: &[(&str, &str)] = &[
        ("recon", "Recon"),
        ("api", "API"),
        ("headers", "Headers"),
        ("csp", "CSP"),
        ("cookies", "Cookies"),
        ("logs", "Logs"),
        ("ratelimit", "Rate Limit"),
        ("cors", "CORS"),
        ("tls", "TLS"),
        ("redirect", "Redirect"),
        ("email", "Email/DNS"),
        ("auth", "Auth"),
        ("xss", "XSS"),
        ("sqli", "SQLi"),
        ("ssrf", "SSRF"),
        ("idor", "IDOR"),
        ("fuzzer", "Fuzzer"),
        ("cve", "CVE"),
    ];
    LABELS
        .iter()
        .find(|(key, _)| *key == cat)
        .map(|(_, label)| *label)
        .unwrap_or("Other")
}
/// Human-readable phase name for a BFS depth; depths beyond the known
/// seven collapse to "Final".
pub(crate) fn phase_name(depth: usize) -> &'static str {
    const NAMES: [&str; 7] = [
        "Reconnaissance",
        "Analysis",
        "Boundary Testing",
        "Injection & Exploitation",
        "Authentication Testing",
        "Validation",
        "Deep Scan",
    ];
    NAMES.get(depth).copied().unwrap_or("Final")
}
/// Short phase label for the rail UI; depths beyond the known seven
/// collapse to "Final". Keep in sync with `phase_name`.
pub(crate) fn phase_short_name(depth: usize) -> &'static str {
    const SHORT_NAMES: [&str; 7] = [
        "Recon",
        "Analysis",
        "Boundary",
        "Exploit",
        "Auth",
        "Validate",
        "Deep",
    ];
    SHORT_NAMES.get(depth).copied().unwrap_or("Final")
}
/// Group attack-chain steps into execution phases by BFS depth.
///
/// Each step is a JSON node carrying a `node_id` string and a
/// `parent_node_ids` array. Steps with no parent inside `steps` are roots
/// (depth 0); every other step's depth is its minimum hop distance from a
/// root. The result holds one `Vec` of indices into `steps` per non-empty
/// depth, ordered shallow to deep.
///
/// Steps the BFS never reaches (e.g. nodes whose parents form a cycle
/// among themselves) are assigned depth 0 rather than dropped.
pub(crate) fn compute_phases(steps: &[serde_json::Value]) -> Vec<Vec<usize>> {
    // node_id per step, parallel to `steps`; a missing id becomes "".
    let node_ids: Vec<String> = steps
        .iter()
        .map(|s| {
            s.get("node_id")
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string()
        })
        .collect();
    // Reverse index: node_id -> position in `steps`.
    let id_to_idx: HashMap<String, usize> = node_ids
        .iter()
        .enumerate()
        .map(|(i, id)| (id.clone(), i))
        .collect();
    // Compute depth via BFS; usize::MAX marks "not yet visited".
    let mut depths = vec![usize::MAX; steps.len()];
    let mut queue = VecDeque::new();
    // Root nodes: those with no parents or parents not in the set
    for (i, step) in steps.iter().enumerate() {
        // Count only parents that resolve to a step in this slice.
        let parents = step
            .get("parent_node_ids")
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|p| p.as_str())
                    .filter(|p| id_to_idx.contains_key(*p))
                    .count()
            })
            .unwrap_or(0);
        if parents == 0 {
            depths[i] = 0;
            queue.push_back(i);
        }
    }
    // BFS to compute min depth
    // NOTE(review): children are found by rescanning every step per queue
    // pop (quadratic in the number of steps); acceptable for the small
    // chains rendered by the dashboard.
    while let Some(idx) = queue.pop_front() {
        let current_depth = depths[idx];
        let node_id = &node_ids[idx];
        // Find children: nodes that list this node as a parent
        for (j, step) in steps.iter().enumerate() {
            // Skip nodes already placed at an equal-or-shallower depth.
            if depths[j] <= current_depth + 1 {
                continue;
            }
            let is_child = step
                .get("parent_node_ids")
                .and_then(|v| v.as_array())
                .map(|arr| arr.iter().any(|p| p.as_str() == Some(node_id.as_str())))
                .unwrap_or(false);
            if is_child {
                depths[j] = current_depth + 1;
                queue.push_back(j);
            }
        }
    }
    // Handle unreachable nodes
    for d in depths.iter_mut() {
        if *d == usize::MAX {
            *d = 0;
        }
    }
    // Group by depth
    let max_depth = depths.iter().copied().max().unwrap_or(0);
    let mut phases: Vec<Vec<usize>> = Vec::new();
    for d in 0..=max_depth {
        let indices: Vec<usize> = depths
            .iter()
            .enumerate()
            .filter(|(_, &dep)| dep == d)
            .map(|(i, _)| i)
            .collect();
        if !indices.is_empty() {
            phases.push(indices);
        }
    }
    phases
}
/// Render a BSON/JSON datetime value as a short display string.
///
/// Extended-JSON `{"$date":{"$numberLong":"<ms>"}}` becomes an `HH:MM:SS`
/// clock time derived from the epoch milliseconds; `{"$date":"..."}` and
/// plain strings pass through unchanged; anything else yields "".
pub(crate) fn format_bson_time(val: &serde_json::Value) -> String {
    if let Some(date_obj) = val.get("$date") {
        // Canonical extended JSON: millisecond epoch stored as a string.
        let millis = date_obj
            .get("$numberLong")
            .and_then(|v| v.as_str())
            .and_then(|s| s.parse::<i64>().ok());
        if let Some(ms) = millis {
            let total_secs = ms / 1000;
            return format!(
                "{:02}:{:02}:{:02}",
                (total_secs / 3600) % 24,
                (total_secs / 60) % 60,
                total_secs % 60
            );
        }
        // Relaxed extended JSON: the date is already a string.
        if let Some(text) = date_obj.as_str() {
            return text.to_string();
        }
    }
    // Plain string timestamps pass through unchanged.
    match val.as_str() {
        Some(text) => text.to_string(),
        None => String::new(),
    }
}
/// Human-readable elapsed time between a step's `started_at` and
/// `completed_at` extended-JSON datetimes ("250ms" below one second,
/// "1.4s" otherwise), or "" when either timestamp is missing/malformed.
pub(crate) fn compute_duration(step: &serde_json::Value) -> String {
    // Pull the epoch-millisecond payload out of {"$date":{"$numberLong":"…"}}.
    fn epoch_ms(val: &serde_json::Value) -> Option<i64> {
        val.get("$date")?
            .get("$numberLong")?
            .as_str()?
            .parse::<i64>()
            .ok()
    }
    let started = step.get("started_at").and_then(epoch_ms);
    let completed = step.get("completed_at").and_then(epoch_ms);
    if let (Some(start), Some(end)) = (started, completed) {
        let elapsed_ms = end - start;
        if elapsed_ms < 1000 {
            format!("{elapsed_ms}ms")
        } else {
            format!("{:.1}s", elapsed_ms as f64 / 1000.0)
        }
    } else {
        String::new()
    }
}

View File

@@ -0,0 +1,4 @@
pub mod helpers;
mod view;
pub use view::AttackChainView;

View File

@@ -0,0 +1,363 @@
use dioxus::prelude::*;
use super::helpers::*;
/// (phase_index, steps, findings_count, has_failed, has_running, all_done)
type PhaseData<'a> = (usize, Vec<&'a serde_json::Value>, usize, bool, bool, bool);
/// Interactive attack-chain visualization for a pentest session.
///
/// Renders, top to bottom: a KPI bar, a clickable phase rail with a
/// per-tool status heatmap, an overall progress bar, an "Expand all"
/// control, and an accordion of phases whose tool rows expand into a
/// detail panel. Expand/collapse state lives in the DOM (class toggling
/// through `document::eval`), not in Dioxus signals.
///
/// * `steps` — raw attack-chain node documents (extended JSON) from the API.
/// * `session_findings` — session-level findings total, used as a fallback
///   when nodes carry no `findings_produced` arrays.
/// * `is_running`, `session_tool_invocations`, `session_success_rate` —
///   NOTE(review): accepted but not read anywhere in this body; confirm
///   whether callers rely on them before removing.
#[component]
pub fn AttackChainView(
    steps: Vec<serde_json::Value>,
    is_running: bool,
    session_findings: usize,
    session_tool_invocations: usize,
    session_success_rate: f64,
) -> Element {
    // Bucket step indices into BFS depth groups; one group per phase.
    let phases = compute_phases(&steps);
    // Compute KPIs — prefer session-level stats, fall back to node-level
    let total_tools = steps.len();
    let node_findings: usize = steps
        .iter()
        .map(|s| {
            s.get("findings_produced")
                .and_then(|v| v.as_array())
                .map(|a| a.len())
                .unwrap_or(0)
        })
        .sum();
    // Use session-level findings count if nodes don't have findings linked
    let total_findings = if node_findings > 0 {
        node_findings
    } else {
        session_findings
    };
    let completed_count = steps
        .iter()
        .filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("completed"))
        .count();
    let failed_count = steps
        .iter()
        .filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"))
        .count();
    let finished = completed_count + failed_count;
    // Success rate over finished tools only; empty denominator shows 100%.
    let success_pct = if finished == 0 {
        100
    } else {
        (completed_count * 100) / finished
    };
    let max_risk: u8 = steps
        .iter()
        .filter_map(|s| s.get("risk_score").and_then(|v| v.as_u64()))
        .map(|v| v as u8)
        .max()
        .unwrap_or(0);
    let progress_pct = if total_tools == 0 {
        0
    } else {
        ((completed_count + failed_count) * 100) / total_tools
    };
    // Build phase data for rail and accordion
    let phase_data: Vec<PhaseData<'_>> = phases
        .iter()
        .enumerate()
        .map(|(pi, indices)| {
            let phase_steps: Vec<&serde_json::Value> = indices.iter().map(|&i| &steps[i]).collect();
            let phase_findings: usize = phase_steps
                .iter()
                .map(|s| {
                    s.get("findings_produced")
                        .and_then(|v| v.as_array())
                        .map(|a| a.len())
                        .unwrap_or(0)
                })
                .sum();
            let has_failed = phase_steps
                .iter()
                .any(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"));
            let has_running = phase_steps
                .iter()
                .any(|s| s.get("status").and_then(|v| v.as_str()) == Some("running"));
            // "Done" includes failures and skips — anything that won't run again.
            let all_done = phase_steps.iter().all(|s| {
                let st = s.get("status").and_then(|v| v.as_str()).unwrap_or("");
                st == "completed" || st == "failed" || st == "skipped"
            });
            (
                pi,
                phase_steps,
                phase_findings,
                has_failed,
                has_running,
                all_done,
            )
        })
        .collect();
    // Index of the rail node highlighted as active (click-to-scroll target).
    let mut active_rail = use_signal(|| 0usize);
    rsx! {
        // KPI bar
        div { class: "ac-kpi-bar",
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--text-primary);", "{total_tools}" }
                div { class: "ac-kpi-label", "Tools Run" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--danger, #dc2626);", "{total_findings}" }
                div { class: "ac-kpi-label", "Findings" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--success, #16a34a);", "{success_pct}%" }
                div { class: "ac-kpi-label", "Success Rate" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--warning, #d97706);", "{max_risk}" }
                div { class: "ac-kpi-label", "Max Risk" }
            }
        }
        // Phase rail
        div { class: "ac-phase-rail",
            for (pi, (_phase_idx, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
                {
                    // Connector bar before every node except the first; its
                    // class reflects the progress of the phases it joins.
                    if pi > 0 {
                        let prev_done = phase_data.get(pi - 1).map(|p| p.5).unwrap_or(false);
                        let bar_class = if prev_done && *all_done {
                            "done"
                        } else if prev_done {
                            "running"
                        } else {
                            ""
                        };
                        rsx! {
                            div { class: "ac-rail-bar",
                                div { class: "ac-rail-bar-inner {bar_class}" }
                            }
                        }
                    } else {
                        rsx! {}
                    }
                }
                {
                    // Dot precedence: running > mixed (done with failures) > done > pending.
                    let dot_class = if *has_running {
                        "running"
                    } else if *has_failed && *all_done {
                        "mixed"
                    } else if *all_done {
                        "done"
                    } else {
                        "pending"
                    };
                    let is_active = *active_rail.read() == pi;
                    let active_cls = if is_active { " active" } else { "" };
                    let findings_cls = if *phase_findings > 0 { "has" } else { "none" };
                    let findings_text = if *phase_findings > 0 {
                        format!("{phase_findings}")
                    } else {
                        "\u{2014}".to_string()
                    };
                    let short = phase_short_name(pi);
                    rsx! {
                        div {
                            class: "ac-rail-node{active_cls}",
                            // Clicking a rail node scrolls to its accordion phase and opens it.
                            onclick: move |_| {
                                active_rail.set(pi);
                                let js = format!(
                                    "document.getElementById('ac-phase-{pi}')?.scrollIntoView({{behavior:'smooth',block:'nearest'}});document.getElementById('ac-phase-{pi}')?.classList.add('open');"
                                );
                                document::eval(&js);
                            },
                            div { class: "ac-rail-dot {dot_class}" }
                            div { class: "ac-rail-label", "{short}" }
                            div { class: "ac-rail-findings {findings_cls}", "{findings_text}" }
                            // One heatmap cell per tool in the phase, colored by status.
                            div { class: "ac-rail-heatmap",
                                for step in phase_steps.iter() {
                                    {
                                        let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
                                        let hm_cls = match st {
                                            "completed" => "ok",
                                            "failed" => "fail",
                                            "running" => "run",
                                            _ => "wait",
                                        };
                                        rsx! { div { class: "ac-hm-cell {hm_cls}" } }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        // Progress bar
        div { class: "ac-progress-track",
            div { class: "ac-progress-fill", style: "width: {progress_pct}%;" }
        }
        // Expand all
        div { class: "ac-controls",
            button {
                class: "ac-btn-toggle",
                // NOTE(review): the first statement in this eval (the initial
                // `toggle(...)`) looks redundant — `!…length === 0` compares a
                // boolean to 0 — and its effect is immediately overwritten by
                // the IIFE that follows, which opens all phases unless all are
                // already open. Consider simplifying to just the IIFE.
                onclick: move |_| {
                    document::eval(
                        "document.querySelectorAll('.ac-phase').forEach(p => p.classList.toggle('open', !document.querySelector('.ac-phase.open') || !document.querySelectorAll('.ac-phase:not(.open)').length === 0));(function(){var ps=document.querySelectorAll('.ac-phase');var allOpen=Array.from(ps).every(p=>p.classList.contains('open'));ps.forEach(p=>{if(allOpen)p.classList.remove('open');else p.classList.add('open');});})();"
                    );
                },
                "Expand all"
            }
        }
        // Phase accordion
        div { class: "ac-phases",
            for (pi, (_, phase_steps, phase_findings, _has_failed, has_running, _all_done)) in phase_data.iter().enumerate() {
                {
                    // The first phase starts open; the rest start collapsed.
                    let open_cls = if pi == 0 { " open" } else { "" };
                    let phase_label = phase_name(pi);
                    let tool_count = phase_steps.len();
                    let meta_text = if *has_running {
                        "in progress".to_string()
                    } else {
                        format!("{phase_findings} findings")
                    };
                    let meta_cls = if *has_running { "running-ct" } else { "findings-ct" };
                    let phase_num_label = format!("PHASE {}", pi + 1);
                    let phase_el_id = format!("ac-phase-{pi}");
                    let phase_el_id2 = phase_el_id.clone();
                    rsx! {
                        div {
                            class: "ac-phase{open_cls}",
                            id: "{phase_el_id}",
                            div {
                                class: "ac-phase-header",
                                onclick: move |_| {
                                    let js = format!("document.getElementById('{phase_el_id2}').classList.toggle('open');");
                                    document::eval(&js);
                                },
                                span { class: "ac-phase-num", "{phase_num_label}" }
                                span { class: "ac-phase-title", "{phase_label}" }
                                div { class: "ac-phase-dots",
                                    for step in phase_steps.iter() {
                                        {
                                            let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
                                            rsx! { div { class: "ac-phase-dot {st}" } }
                                        }
                                    }
                                }
                                div { class: "ac-phase-meta",
                                    span { "{tool_count} tools" }
                                    span { class: "{meta_cls}", "{meta_text}" }
                                }
                                span { class: "ac-phase-chevron", "\u{25B8}" }
                            }
                            div { class: "ac-phase-body",
                                div { class: "ac-phase-body-inner",
                                    for step in phase_steps.iter() {
                                        {
                                            // Per-tool presentation state derived from the raw node.
                                            let tool_name_val = step.get("tool_name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
                                            let status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
                                            let cat = tool_category(&tool_name_val);
                                            let emoji = tool_emoji(cat);
                                            let label = cat_label(cat);
                                            let findings_n = step.get("findings_produced").and_then(|v| v.as_array()).map(|a| a.len()).unwrap_or(0);
                                            let risk = step.get("risk_score").and_then(|v| v.as_u64()).map(|v| v as u8);
                                            let reasoning = step.get("llm_reasoning").and_then(|v| v.as_str()).unwrap_or("").to_string();
                                            let duration = compute_duration(step);
                                            let started = step.get("started_at").map(format_bson_time).unwrap_or_default();
                                            let is_pending = status == "pending";
                                            let pending_cls = if is_pending { " is-pending" } else { "" };
                                            let duration_cls = if status == "running" { "ac-tool-duration running-text" } else { "ac-tool-duration" };
                                            let duration_text = if status == "running" {
                                                "running\u{2026}".to_string()
                                            } else if duration.is_empty() {
                                                "\u{2014}".to_string()
                                            } else {
                                                duration
                                            };
                                            let pill_cls = if findings_n > 0 { "ac-findings-pill has" } else { "ac-findings-pill zero" };
                                            let pill_text = if findings_n > 0 { format!("{findings_n}") } else { "\u{2014}".to_string() };
                                            // Risk buckets: >=75 high, >=40 medium, else low; em dash when absent.
                                            let (risk_cls, risk_text) = match risk {
                                                Some(r) if r >= 75 => ("ac-risk-val high", format!("{r}")),
                                                Some(r) if r >= 40 => ("ac-risk-val medium", format!("{r}")),
                                                Some(r) => ("ac-risk-val low", format!("{r}")),
                                                None => ("ac-risk-val none", "\u{2014}".to_string()),
                                            };
                                            let node_id = step.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
                                            let detail_id = format!("ac-detail-{node_id}");
                                            let row_id = format!("ac-row-{node_id}");
                                            let detail_id_clone = detail_id.clone();
                                            rsx! {
                                                div {
                                                    class: "ac-tool-row{pending_cls}",
                                                    id: "{row_id}",
                                                    // Toggle this row's detail panel; pending rows have nothing to show.
                                                    onclick: move |_| {
                                                        if is_pending { return; }
                                                        let js = format!(
                                                            "(function(){{var r=document.getElementById('{row_id}');var d=document.getElementById('{detail_id}');if(r.classList.contains('expanded')){{r.classList.remove('expanded');d.classList.remove('open');}}else{{r.classList.add('expanded');d.classList.add('open');}}}})()"
                                                        );
                                                        document::eval(&js);
                                                    },
                                                    div { class: "ac-status-bar {status}" }
                                                    div { class: "ac-tool-icon", "{emoji}" }
                                                    div { class: "ac-tool-info",
                                                        div { class: "ac-tool-name", "{tool_name_val}" }
                                                        span { class: "ac-cat-chip {cat}", "{label}" }
                                                    }
                                                    div { class: "{duration_cls}", "{duration_text}" }
                                                    div { span { class: "{pill_cls}", "{pill_text}" } }
                                                    div { class: "{risk_cls}", "{risk_text}" }
                                                }
                                                div {
                                                    class: "ac-tool-detail",
                                                    id: "{detail_id_clone}",
                                                    if !reasoning.is_empty() || !started.is_empty() {
                                                        div { class: "ac-tool-detail-inner",
                                                            if !reasoning.is_empty() {
                                                                div { class: "ac-reasoning-block", "{reasoning}" }
                                                            }
                                                            if !started.is_empty() {
                                                                div { class: "ac-detail-grid",
                                                                    span { class: "ac-detail-label", "Started" }
                                                                    span { class: "ac-detail-value", "{started}" }
                                                                    if !duration_text.is_empty() && status != "running" && duration_text != "\u{2014}" {
                                                                        span { class: "ac-detail-label", "Duration" }
                                                                        span { class: "ac-detail-value", "{duration_text}" }
                                                                    }
                                                                    span { class: "ac-detail-label", "Status" }
                                                                    if status == "completed" {
                                                                        span { class: "ac-detail-value", style: "color: var(--success, #16a34a);", "Completed" }
                                                                    } else if status == "failed" {
                                                                        span { class: "ac-detail-value", style: "color: var(--danger, #dc2626);", "Failed" }
                                                                    } else if status == "running" {
                                                                        span { class: "ac-detail-value", style: "color: var(--warning, #d97706);", "Running" }
                                                                    } else {
                                                                        span { class: "ac-detail-value", "{status}" }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

View File

@@ -1,4 +1,5 @@
pub mod app_shell;
pub mod attack_chain;
pub mod code_inspector;
pub mod code_snippet;
pub mod file_tree;

View File

@@ -101,11 +101,18 @@ pub async fn fetch_pentest_session(id: String) -> Result<PentestSessionResponse,
if let Ok(tbody) = tresp.json::<serde_json::Value>().await {
if let Some(targets) = tbody.get("data").and_then(|v| v.as_array()) {
for t in targets {
let t_id = t.get("_id").and_then(|v| v.get("$oid")).and_then(|v| v.as_str()).unwrap_or("");
let t_id = t
.get("_id")
.and_then(|v| v.get("$oid"))
.and_then(|v| v.as_str())
.unwrap_or("");
if t_id == tid {
if let Some(name) = t.get("name").and_then(|v| v.as_str()) {
body.data.as_object_mut().map(|obj| {
obj.insert("target_name".to_string(), serde_json::Value::String(name.to_string()))
obj.insert(
"target_name".to_string(),
serde_json::Value::String(name.to_string()),
)
});
}
break;
@@ -155,9 +162,7 @@ pub async fn fetch_pentest_stats() -> Result<PentestStatsResponse, ServerFnError
}
#[server]
pub async fn fetch_attack_chain(
session_id: String,
) -> Result<AttackChainResponse, ServerFnError> {
pub async fn fetch_attack_chain(session_id: String) -> Result<AttackChainResponse, ServerFnError> {
let state: super::server_state::ServerState =
dioxus_fullstack::FullstackContext::extract().await?;
let url = format!(

View File

@@ -116,7 +116,10 @@ pub fn PentestDashboardPage() -> Element {
_ => serde_json::Value::Null,
}
};
let severity_critical = sev_dist.get("critical").and_then(|v| v.as_u64()).unwrap_or(0);
let severity_critical = sev_dist
.get("critical")
.and_then(|v| v.as_u64())
.unwrap_or(0);
let severity_high = sev_dist.get("high").and_then(|v| v.as_u64()).unwrap_or(0);
let severity_medium = sev_dist.get("medium").and_then(|v| v.as_u64()).unwrap_or(0);
let severity_low = sev_dist.get("low").and_then(|v| v.as_u64()).unwrap_or(0);

View File

@@ -1,10 +1,9 @@
use std::collections::{HashMap, VecDeque};
use dioxus::prelude::*;
use dioxus_free_icons::icons::bs_icons::*;
use dioxus_free_icons::Icon;
use crate::app::Route;
use crate::components::attack_chain::AttackChainView;
use crate::components::severity_badge::SeverityBadge;
use crate::infrastructure::pentest::{
export_pentest_report, fetch_attack_chain, fetch_pentest_findings, fetch_pentest_session,
@@ -115,9 +114,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
let list = &data.data;
let c = list
.iter()
.filter(|f| {
f.get("severity").and_then(|v| v.as_str()) == Some("critical")
})
.filter(|f| f.get("severity").and_then(|v| v.as_str()) == Some("critical"))
.count();
let h = list
.iter()
@@ -125,9 +122,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
.count();
let m = list
.iter()
.filter(|f| {
f.get("severity").and_then(|v| v.as_str()) == Some("medium")
})
.filter(|f| f.get("severity").and_then(|v| v.as_str()) == Some("medium"))
.count();
let l = list
.iter()
@@ -140,7 +135,9 @@ pub fn PentestSessionPage(session_id: String) -> Element {
let e = list
.iter()
.filter(|f| {
f.get("exploitable").and_then(|v| v.as_bool()).unwrap_or(false)
f.get("exploitable")
.and_then(|v| v.as_bool())
.unwrap_or(false)
})
.count();
(c, h, m, l, i, e)
@@ -171,14 +168,7 @@ pub fn PentestSessionPage(session_id: String) -> Element {
let sid = sid_for_export.clone();
spawn(async move {
// TODO: get real user info from auth context
match export_pentest_report(
sid.clone(),
pw,
String::new(),
String::new(),
)
.await
{
match export_pentest_report(sid.clone(), pw, String::new(), String::new()).await {
Ok(resp) => {
export_sha256.set(Some(resp.sha256.clone()));
// Trigger download via JS
@@ -556,586 +546,3 @@ pub fn PentestSessionPage(session_id: String) -> Element {
}
}
}
// ═══════════════════════════════════════
// Attack Chain Visualization Component
// ═══════════════════════════════════════
/// Get category CSS class from tool name
fn tool_category(name: &str) -> &'static str {
let lower = name.to_lowercase();
if lower.contains("recon") { return "recon"; }
if lower.contains("openapi") || lower.contains("api") || lower.contains("swagger") { return "api"; }
if lower.contains("header") { return "headers"; }
if lower.contains("csp") { return "csp"; }
if lower.contains("cookie") { return "cookies"; }
if lower.contains("log") || lower.contains("console") { return "logs"; }
if lower.contains("rate") || lower.contains("limit") { return "ratelimit"; }
if lower.contains("cors") { return "cors"; }
if lower.contains("tls") || lower.contains("ssl") { return "tls"; }
if lower.contains("redirect") { return "redirect"; }
if lower.contains("dns") || lower.contains("dmarc") || lower.contains("email") || lower.contains("spf") { return "email"; }
if lower.contains("auth") || lower.contains("jwt") || lower.contains("token") || lower.contains("session") { return "auth"; }
if lower.contains("xss") { return "xss"; }
if lower.contains("sql") || lower.contains("sqli") { return "sqli"; }
if lower.contains("ssrf") { return "ssrf"; }
if lower.contains("idor") { return "idor"; }
if lower.contains("fuzz") { return "fuzzer"; }
if lower.contains("cve") || lower.contains("exploit") { return "cve"; }
"default"
}
/// Get emoji icon from tool category
fn tool_emoji(cat: &str) -> &'static str {
match cat {
"recon" => "\u{1F50D}",
"api" => "\u{1F517}",
"headers" => "\u{1F6E1}",
"csp" => "\u{1F6A7}",
"cookies" => "\u{1F36A}",
"logs" => "\u{1F4DD}",
"ratelimit" => "\u{23F1}",
"cors" => "\u{1F30D}",
"tls" => "\u{1F510}",
"redirect" => "\u{21AA}",
"email" => "\u{1F4E7}",
"auth" => "\u{1F512}",
"xss" => "\u{26A1}",
"sqli" => "\u{1F489}",
"ssrf" => "\u{1F310}",
"idor" => "\u{1F511}",
"fuzzer" => "\u{1F9EA}",
"cve" => "\u{1F4A3}",
_ => "\u{1F527}",
}
}
/// Compute display label for category
fn cat_label(cat: &str) -> &'static str {
match cat {
"recon" => "Recon",
"api" => "API",
"headers" => "Headers",
"csp" => "CSP",
"cookies" => "Cookies",
"logs" => "Logs",
"ratelimit" => "Rate Limit",
"cors" => "CORS",
"tls" => "TLS",
"redirect" => "Redirect",
"email" => "Email/DNS",
"auth" => "Auth",
"xss" => "XSS",
"sqli" => "SQLi",
"ssrf" => "SSRF",
"idor" => "IDOR",
"fuzzer" => "Fuzzer",
"cve" => "CVE",
_ => "Other",
}
}
/// Phase name heuristic based on depth
fn phase_name(depth: usize) -> &'static str {
match depth {
0 => "Reconnaissance",
1 => "Analysis",
2 => "Boundary Testing",
3 => "Injection & Exploitation",
4 => "Authentication Testing",
5 => "Validation",
6 => "Deep Scan",
_ => "Final",
}
}
/// Short label for phase rail
fn phase_short_name(depth: usize) -> &'static str {
match depth {
0 => "Recon",
1 => "Analysis",
2 => "Boundary",
3 => "Exploit",
4 => "Auth",
5 => "Validate",
6 => "Deep",
_ => "Final",
}
}
/// Compute BFS phases from attack chain nodes
fn compute_phases(steps: &[serde_json::Value]) -> Vec<Vec<usize>> {
let node_ids: Vec<String> = steps
.iter()
.map(|s| s.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string())
.collect();
let id_to_idx: HashMap<String, usize> = node_ids
.iter()
.enumerate()
.map(|(i, id)| (id.clone(), i))
.collect();
// Compute depth via BFS
let mut depths = vec![usize::MAX; steps.len()];
let mut queue = VecDeque::new();
// Root nodes: those with no parents or parents not in the set
for (i, step) in steps.iter().enumerate() {
let parents = step
.get("parent_node_ids")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|p| p.as_str())
.filter(|p| id_to_idx.contains_key(*p))
.count()
})
.unwrap_or(0);
if parents == 0 {
depths[i] = 0;
queue.push_back(i);
}
}
// BFS to compute min depth
while let Some(idx) = queue.pop_front() {
let current_depth = depths[idx];
let node_id = &node_ids[idx];
// Find children: nodes that list this node as a parent
for (j, step) in steps.iter().enumerate() {
if depths[j] <= current_depth + 1 {
continue;
}
let is_child = step
.get("parent_node_ids")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().any(|p| p.as_str() == Some(node_id.as_str())))
.unwrap_or(false);
if is_child {
depths[j] = current_depth + 1;
queue.push_back(j);
}
}
}
// Handle unreachable nodes
for d in depths.iter_mut() {
if *d == usize::MAX {
*d = 0;
}
}
// Group by depth
let max_depth = depths.iter().copied().max().unwrap_or(0);
let mut phases: Vec<Vec<usize>> = Vec::new();
for d in 0..=max_depth {
let indices: Vec<usize> = depths
.iter()
.enumerate()
.filter(|(_, &dep)| dep == d)
.map(|(i, _)| i)
.collect();
if !indices.is_empty() {
phases.push(indices);
}
}
phases
}
/// Format BSON datetime to readable string
fn format_bson_time(val: &serde_json::Value) -> String {
// Handle BSON {"$date":{"$numberLong":"..."}}
if let Some(date_obj) = val.get("$date") {
if let Some(ms_str) = date_obj.get("$numberLong").and_then(|v| v.as_str()) {
if let Ok(ms) = ms_str.parse::<i64>() {
let secs = ms / 1000;
let h = (secs / 3600) % 24;
let m = (secs / 60) % 60;
let s = secs % 60;
return format!("{h:02}:{m:02}:{s:02}");
}
}
// Handle {"$date": "2025-..."}
if let Some(s) = date_obj.as_str() {
return s.to_string();
}
}
// Handle plain string
if let Some(s) = val.as_str() {
return s.to_string();
}
String::new()
}
/// Compute duration string from started_at and completed_at
fn compute_duration(step: &serde_json::Value) -> String {
let extract_ms = |val: &serde_json::Value| -> Option<i64> {
val.get("$date")?
.get("$numberLong")?
.as_str()?
.parse::<i64>()
.ok()
};
let started = step.get("started_at").and_then(extract_ms);
let completed = step.get("completed_at").and_then(extract_ms);
match (started, completed) {
(Some(s), Some(c)) => {
let diff_ms = c - s;
if diff_ms < 1000 {
format!("{}ms", diff_ms)
} else {
format!("{:.1}s", diff_ms as f64 / 1000.0)
}
}
_ => String::new(),
}
}
/// Attack-chain visualization: KPI cards, a clickable phase rail with a
/// per-step heatmap, an overall progress bar, and a per-phase accordion of
/// tool rows with expandable detail panels.
///
/// `steps` are raw attack-chain node documents (JSON values carrying
/// `node_id`, `parent_node_ids`, `status`, `tool_name`, `findings_produced`,
/// `risk_score`, `llm_reasoning`, `started_at`, `completed_at`); phases are
/// derived from them via `compute_phases`. `session_findings` acts as a
/// fallback when the node documents have no findings linked.
///
/// NOTE(review): `is_running`, `session_tool_invocations` and
/// `session_success_rate` are accepted but not read anywhere in this body —
/// confirm whether they are kept only for prop compatibility with callers.
#[component]
fn AttackChainView(
    steps: Vec<serde_json::Value>,
    is_running: bool,
    session_findings: usize,
    session_tool_invocations: usize,
    session_success_rate: f64,
) -> Element {
    let phases = compute_phases(&steps);
    // Compute KPIs — prefer session-level stats, fall back to node-level
    let total_tools = steps.len();
    let node_findings: usize = steps
        .iter()
        .map(|s| {
            s.get("findings_produced")
                .and_then(|v| v.as_array())
                .map(|a| a.len())
                .unwrap_or(0)
        })
        .sum();
    // Use session-level findings count if nodes don't have findings linked
    let total_findings = if node_findings > 0 { node_findings } else { session_findings };
    let completed_count = steps
        .iter()
        .filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("completed"))
        .count();
    let failed_count = steps
        .iter()
        .filter(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"))
        .count();
    let finished = completed_count + failed_count;
    // Success rate over finished steps only; shows an optimistic 100%
    // before anything has finished (avoids divide-by-zero).
    let success_pct = if finished == 0 {
        100
    } else {
        (completed_count * 100) / finished
    };
    // Highest node-level risk score reported (0 when no node carries one).
    let max_risk: u8 = steps
        .iter()
        .filter_map(|s| s.get("risk_score").and_then(|v| v.as_u64()))
        .map(|v| v as u8)
        .max()
        .unwrap_or(0);
    // Overall progress = share of steps that finished (completed or failed).
    let progress_pct = if total_tools == 0 {
        0
    } else {
        ((completed_count + failed_count) * 100) / total_tools
    };
    // Build phase data for rail and accordion
    // Tuple layout: (phase index, steps, findings count, any failed,
    // any running, all finished/skipped).
    let phase_data: Vec<(usize, Vec<&serde_json::Value>, usize, bool, bool, bool)> = phases
        .iter()
        .enumerate()
        .map(|(pi, indices)| {
            let phase_steps: Vec<&serde_json::Value> = indices.iter().map(|&i| &steps[i]).collect();
            let phase_findings: usize = phase_steps
                .iter()
                .map(|s| {
                    s.get("findings_produced")
                        .and_then(|v| v.as_array())
                        .map(|a| a.len())
                        .unwrap_or(0)
                })
                .sum();
            let has_failed = phase_steps
                .iter()
                .any(|s| s.get("status").and_then(|v| v.as_str()) == Some("failed"));
            let has_running = phase_steps
                .iter()
                .any(|s| s.get("status").and_then(|v| v.as_str()) == Some("running"));
            let all_done = phase_steps.iter().all(|s| {
                let st = s.get("status").and_then(|v| v.as_str()).unwrap_or("");
                st == "completed" || st == "failed" || st == "skipped"
            });
            (pi, phase_steps, phase_findings, has_failed, has_running, all_done)
        })
        .collect();
    // Index of the rail node currently highlighted/selected.
    let mut active_rail = use_signal(|| 0usize);
    rsx! {
        // KPI bar
        div { class: "ac-kpi-bar",
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--text-primary);", "{total_tools}" }
                div { class: "ac-kpi-label", "Tools Run" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--danger, #dc2626);", "{total_findings}" }
                div { class: "ac-kpi-label", "Findings" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--success, #16a34a);", "{success_pct}%" }
                div { class: "ac-kpi-label", "Success Rate" }
            }
            div { class: "ac-kpi-card",
                div { class: "ac-kpi-value", style: "color: var(--warning, #d97706);", "{max_risk}" }
                div { class: "ac-kpi-label", "Max Risk" }
            }
        }
        // Phase rail
        div { class: "ac-phase-rail",
            for (pi, (_phase_idx, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
                {
                    // Connector bar between consecutive rail nodes; colored
                    // by whether the previous/current phase have finished.
                    if pi > 0 {
                        let prev_done = phase_data.get(pi - 1).map(|p| p.5).unwrap_or(false);
                        let bar_class = if prev_done && *all_done {
                            "done"
                        } else if prev_done {
                            "running"
                        } else {
                            ""
                        };
                        rsx! {
                            div { class: "ac-rail-bar",
                                div { class: "ac-rail-bar-inner {bar_class}" }
                            }
                        }
                    } else {
                        rsx! {}
                    }
                }
                {
                    // Rail node: status dot, short label, findings badge,
                    // and one heatmap cell per step in the phase.
                    let dot_class = if *has_running {
                        "running"
                    } else if *has_failed && *all_done {
                        "mixed"
                    } else if *all_done {
                        "done"
                    } else {
                        "pending"
                    };
                    let is_active = *active_rail.read() == pi;
                    let active_cls = if is_active { " active" } else { "" };
                    let findings_cls = if *phase_findings > 0 { "has" } else { "none" };
                    let findings_text = if *phase_findings > 0 {
                        format!("{phase_findings}")
                    } else {
                        "\u{2014}".to_string()
                    };
                    let short = phase_short_name(pi);
                    rsx! {
                        div {
                            class: "ac-rail-node{active_cls}",
                            // Clicking a rail node selects it and scrolls the
                            // matching accordion phase into view (opened).
                            onclick: move |_| {
                                active_rail.set(pi);
                                let js = format!(
                                    "document.getElementById('ac-phase-{pi}')?.scrollIntoView({{behavior:'smooth',block:'nearest'}});document.getElementById('ac-phase-{pi}')?.classList.add('open');"
                                );
                                document::eval(&js);
                            },
                            div { class: "ac-rail-dot {dot_class}" }
                            div { class: "ac-rail-label", "{short}" }
                            div { class: "ac-rail-findings {findings_cls}", "{findings_text}" }
                            div { class: "ac-rail-heatmap",
                                for step in phase_steps.iter() {
                                    {
                                        let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
                                        let hm_cls = match st {
                                            "completed" => "ok",
                                            "failed" => "fail",
                                            "running" => "run",
                                            _ => "wait",
                                        };
                                        rsx! { div { class: "ac-hm-cell {hm_cls}" } }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        // Progress bar
        div { class: "ac-progress-track",
            div { class: "ac-progress-fill", style: "width: {progress_pct}%;" }
        }
        // Expand all
        div { class: "ac-controls",
            button {
                class: "ac-btn-toggle",
                // Toggles every phase panel open/closed via injected JS:
                // if all are open, close all; otherwise open all.
                onclick: move |_| {
                    document::eval(
                        "document.querySelectorAll('.ac-phase').forEach(p => p.classList.toggle('open', !document.querySelector('.ac-phase.open') || !document.querySelectorAll('.ac-phase:not(.open)').length === 0));(function(){var ps=document.querySelectorAll('.ac-phase');var allOpen=Array.from(ps).every(p=>p.classList.contains('open'));ps.forEach(p=>{if(allOpen)p.classList.remove('open');else p.classList.add('open');});})();"
                    );
                },
                "Expand all"
            }
        }
        // Phase accordion
        div { class: "ac-phases",
            for (pi, (_, phase_steps, phase_findings, has_failed, has_running, all_done)) in phase_data.iter().enumerate() {
                {
                    // Phase panel: header (number, title, per-step dots,
                    // meta) plus a body with one row per tool invocation.
                    let open_cls = if pi == 0 { " open" } else { "" };
                    let phase_label = phase_name(pi);
                    let tool_count = phase_steps.len();
                    let meta_text = if *has_running {
                        "in progress".to_string()
                    } else {
                        format!("{phase_findings} findings")
                    };
                    let meta_cls = if *has_running { "running-ct" } else { "findings-ct" };
                    let phase_num_label = format!("PHASE {}", pi + 1);
                    let phase_el_id = format!("ac-phase-{pi}");
                    let phase_el_id2 = phase_el_id.clone();
                    rsx! {
                        div {
                            class: "ac-phase{open_cls}",
                            id: "{phase_el_id}",
                            div {
                                class: "ac-phase-header",
                                onclick: move |_| {
                                    let js = format!("document.getElementById('{phase_el_id2}').classList.toggle('open');");
                                    document::eval(&js);
                                },
                                span { class: "ac-phase-num", "{phase_num_label}" }
                                span { class: "ac-phase-title", "{phase_label}" }
                                div { class: "ac-phase-dots",
                                    for step in phase_steps.iter() {
                                        {
                                            let st = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending");
                                            rsx! { div { class: "ac-phase-dot {st}" } }
                                        }
                                    }
                                }
                                div { class: "ac-phase-meta",
                                    span { "{tool_count} tools" }
                                    span { class: "{meta_cls}", "{meta_text}" }
                                }
                                span { class: "ac-phase-chevron", "\u{25B8}" }
                            }
                            div { class: "ac-phase-body",
                                div { class: "ac-phase-body-inner",
                                    for step in phase_steps.iter() {
                                        {
                                            // One tool row: status bar, emoji + category
                                            // chip, duration, findings pill, risk score,
                                            // and an expandable detail panel below it.
                                            let tool_name_val = step.get("tool_name").and_then(|v| v.as_str()).unwrap_or("Unknown").to_string();
                                            let status = step.get("status").and_then(|v| v.as_str()).unwrap_or("pending").to_string();
                                            let cat = tool_category(&tool_name_val);
                                            let emoji = tool_emoji(cat);
                                            let label = cat_label(cat);
                                            let findings_n = step.get("findings_produced").and_then(|v| v.as_array()).map(|a| a.len()).unwrap_or(0);
                                            let risk = step.get("risk_score").and_then(|v| v.as_u64()).map(|v| v as u8);
                                            let reasoning = step.get("llm_reasoning").and_then(|v| v.as_str()).unwrap_or("").to_string();
                                            let duration = compute_duration(step);
                                            let started = step.get("started_at").map(format_bson_time).unwrap_or_default();
                                            let is_pending = status == "pending";
                                            let pending_cls = if is_pending { " is-pending" } else { "" };
                                            let duration_cls = if status == "running" { "ac-tool-duration running-text" } else { "ac-tool-duration" };
                                            let duration_text = if status == "running" {
                                                "running\u{2026}".to_string()
                                            } else if duration.is_empty() {
                                                "\u{2014}".to_string()
                                            } else {
                                                duration
                                            };
                                            let pill_cls = if findings_n > 0 { "ac-findings-pill has" } else { "ac-findings-pill zero" };
                                            let pill_text = if findings_n > 0 { format!("{findings_n}") } else { "\u{2014}".to_string() };
                                            // Risk buckets: >=75 high, >=40 medium, else low.
                                            let (risk_cls, risk_text) = match risk {
                                                Some(r) if r >= 75 => ("ac-risk-val high", format!("{r}")),
                                                Some(r) if r >= 40 => ("ac-risk-val medium", format!("{r}")),
                                                Some(r) => ("ac-risk-val low", format!("{r}")),
                                                None => ("ac-risk-val none", "\u{2014}".to_string()),
                                            };
                                            let node_id = step.get("node_id").and_then(|v| v.as_str()).unwrap_or("").to_string();
                                            let detail_id = format!("ac-detail-{node_id}");
                                            let row_id = format!("ac-row-{node_id}");
                                            let detail_id_clone = detail_id.clone();
                                            rsx! {
                                                div {
                                                    class: "ac-tool-row{pending_cls}",
                                                    id: "{row_id}",
                                                    // Pending rows are inert; finished/running rows
                                                    // toggle their detail panel via injected JS.
                                                    onclick: move |_| {
                                                        if is_pending { return; }
                                                        let js = format!(
                                                            "(function(){{var r=document.getElementById('{row_id}');var d=document.getElementById('{detail_id}');if(r.classList.contains('expanded')){{r.classList.remove('expanded');d.classList.remove('open');}}else{{r.classList.add('expanded');d.classList.add('open');}}}})()"
                                                        );
                                                        document::eval(&js);
                                                    },
                                                    div { class: "ac-status-bar {status}" }
                                                    div { class: "ac-tool-icon", "{emoji}" }
                                                    div { class: "ac-tool-info",
                                                        div { class: "ac-tool-name", "{tool_name_val}" }
                                                        span { class: "ac-cat-chip {cat}", "{label}" }
                                                    }
                                                    div { class: "{duration_cls}", "{duration_text}" }
                                                    div { span { class: "{pill_cls}", "{pill_text}" } }
                                                    div { class: "{risk_cls}", "{risk_text}" }
                                                }
                                                div {
                                                    class: "ac-tool-detail",
                                                    id: "{detail_id_clone}",
                                                    if !reasoning.is_empty() || !started.is_empty() {
                                                        div { class: "ac-tool-detail-inner",
                                                            if !reasoning.is_empty() {
                                                                div { class: "ac-reasoning-block", "{reasoning}" }
                                                            }
                                                            if !started.is_empty() {
                                                                div { class: "ac-detail-grid",
                                                                    span { class: "ac-detail-label", "Started" }
                                                                    span { class: "ac-detail-value", "{started}" }
                                                                    if !duration_text.is_empty() && status != "running" && duration_text != "\u{2014}" {
                                                                        span { class: "ac-detail-label", "Duration" }
                                                                        span { class: "ac-detail-value", "{duration_text}" }
                                                                    }
                                                                    span { class: "ac-detail-label", "Status" }
                                                                    if status == "completed" {
                                                                        span { class: "ac-detail-value", style: "color: var(--success, #16a34a);", "Completed" }
                                                                    } else if status == "failed" {
                                                                        span { class: "ac-detail-value", style: "color: var(--danger, #dc2626);", "Failed" }
                                                                    } else if status == "running" {
                                                                        span { class: "ac-detail-value", style: "color: var(--warning, #d97706);", "Running" }
                                                                    } else {
                                                                        span { class: "ac-detail-value", "{status}" }
                                                                    }
                                                                }
                                                            }
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::api_fuzzer::ApiFuzzerAgent;
/// PentestTool wrapper around the existing ApiFuzzerAgent.
pub struct ApiFuzzerTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: ApiFuzzerAgent,
}
impl ApiFuzzerTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = ApiFuzzerAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl ApiFuzzerTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -98,49 +128,51 @@ impl PentestTool for ApiFuzzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let mut endpoints = Self::parse_endpoints(&input);
let mut endpoints = Self::parse_endpoints(&input);
// If a base_url is provided but no endpoints, create a default endpoint
if endpoints.is_empty() {
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
endpoints.push(DiscoveredEndpoint {
url: base.to_string(),
method: "GET".to_string(),
parameters: Vec::new(),
content_type: None,
requires_auth: false,
// If a base_url is provided but no endpoints, create a default endpoint
if endpoints.is_empty() {
if let Some(base) = input.get("base_url").and_then(|v| v.as_str()) {
endpoints.push(DiscoveredEndpoint {
url: base.to_string(),
method: "GET".to_string(),
parameters: Vec::new(),
content_type: None,
requires_auth: false,
});
}
}
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints or base_url provided to fuzz.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
}
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints or base_url provided to fuzz.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} API misconfigurations or information disclosures.")
} else {
"No API misconfigurations detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} API misconfigurations or information disclosures.")
} else {
"No API misconfigurations detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::auth_bypass::AuthBypassAgent;
/// PentestTool wrapper around the existing AuthBypassAgent.
pub struct AuthBypassTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: AuthBypassAgent,
}
impl AuthBypassTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = AuthBypassAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl AuthBypassTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -96,35 +126,37 @@ impl PentestTool for AuthBypassTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} authentication bypass vulnerabilities.")
} else {
"No authentication bypass vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} authentication bypass vulnerabilities.")
} else {
"No authentication bypass vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -54,7 +54,7 @@ impl ConsoleLogDetectorTool {
}
let quote = html.as_bytes().get(abs_start).copied();
let (open, close) = match quote {
let (_open, close) = match quote {
Some(b'"') => ('"', '"'),
Some(b'\'') => ('\'', '\''),
_ => {
@@ -122,6 +122,96 @@ impl ConsoleLogDetectorTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_js_urls_from_html() {
let html = r#"
<html>
<head>
<script src="/static/app.js"></script>
<script src="https://cdn.example.com/lib.js"></script>
<script src='//cdn2.example.com/vendor.js'></script>
</head>
</html>
"#;
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
assert_eq!(urls.len(), 3);
assert!(urls.contains(&"https://example.com/static/app.js".to_string()));
assert!(urls.contains(&"https://cdn.example.com/lib.js".to_string()));
assert!(urls.contains(&"https://cdn2.example.com/vendor.js".to_string()));
}
#[test]
fn extract_js_urls_no_scripts() {
let html = "<html><body><p>Hello</p></body></html>";
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
assert!(urls.is_empty());
}
#[test]
fn extract_js_urls_filters_non_js() {
let html = r#"<link src="/style.css"><script src="/app.js"></script>"#;
let urls = ConsoleLogDetectorTool::extract_js_urls(html, "https://example.com");
// Only .js files should be extracted
assert_eq!(urls.len(), 1);
assert!(urls[0].ends_with("/app.js"));
}
#[test]
fn scan_js_content_finds_console_log() {
let js = r#"
function init() {
console.log("debug info");
doStuff();
}
"#;
let matches = ConsoleLogDetectorTool::scan_js_content(js, "https://example.com/app.js");
assert_eq!(matches.len(), 1);
assert_eq!(matches[0].pattern, "console.log");
assert_eq!(matches[0].line_number, Some(3));
}
#[test]
fn scan_js_content_finds_multiple_patterns() {
let js =
"console.log('a');\nconsole.debug('b');\nconsole.error('c');\ndebugger;\nalert('x');";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
assert_eq!(matches.len(), 5);
}
#[test]
fn scan_js_content_skips_comments() {
let js = "// console.log('commented out');\n* console.log('also comment');\n/* console.log('block comment') */";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
assert!(matches.is_empty());
}
#[test]
fn scan_js_content_one_match_per_line() {
let js = "console.log('a'); console.debug('b');";
let matches = ConsoleLogDetectorTool::scan_js_content(js, "test.js");
// Only one match per line
assert_eq!(matches.len(), 1);
}
#[test]
fn scan_js_content_empty_input() {
let matches = ConsoleLogDetectorTool::scan_js_content("", "test.js");
assert!(matches.is_empty());
}
#[test]
fn patterns_list_is_not_empty() {
let patterns = ConsoleLogDetectorTool::patterns();
assert!(patterns.len() >= 8);
assert!(patterns.contains(&"console.log("));
assert!(patterns.contains(&"debugger;"));
}
}
impl PentestTool for ConsoleLogDetectorTool {
fn name(&self) -> &str {
"console_log_detector"
@@ -154,173 +244,180 @@ impl PentestTool for ConsoleLogDetectorTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_js: Vec<String> = input
.get("additional_js_urls")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Fetch the main page
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let html = response.text().await.unwrap_or_default();
// Scan inline scripts in the HTML
let mut all_matches = Vec::new();
let inline_matches = Self::scan_js_content(&html, url);
all_matches.extend(inline_matches);
// Extract JS file URLs from the HTML
let mut js_urls = Self::extract_js_urls(&html, url);
js_urls.extend(additional_js);
js_urls.dedup();
// Fetch and scan each JS file
for js_url in &js_urls {
match self.http.get(js_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
let js_content = resp.text().await.unwrap_or_default();
// Only scan non-minified-looking files or files where we can still
// find patterns (minifiers typically strip console calls, but not always)
let file_matches = Self::scan_js_content(&js_content, js_url);
all_matches.extend(file_matches);
}
}
Err(_) => continue,
}
}
let mut findings = Vec::new();
let match_data: Vec<serde_json::Value> = all_matches
.iter()
.map(|m| {
json!({
"pattern": m.pattern,
"file": m.file_url,
"line": m.line_number,
"snippet": m.line_snippet,
let additional_js: Vec<String> = input
.get("additional_js_urls")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
})
.collect();
.unwrap_or_default();
if !all_matches.is_empty() {
// Group by file for the finding
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
std::collections::HashMap::new();
for m in &all_matches {
by_file.entry(&m.file_url).or_default().push(m);
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Fetch the main page
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let html = response.text().await.unwrap_or_default();
// Scan inline scripts in the HTML
let mut all_matches = Vec::new();
let inline_matches = Self::scan_js_content(&html, url);
all_matches.extend(inline_matches);
// Extract JS file URLs from the HTML
let mut js_urls = Self::extract_js_urls(&html, url);
js_urls.extend(additional_js);
js_urls.dedup();
// Fetch and scan each JS file
for js_url in &js_urls {
match self.http.get(js_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
let js_content = resp.text().await.unwrap_or_default();
// Only scan non-minified-looking files or files where we can still
// find patterns (minifiers typically strip console calls, but not always)
let file_matches = Self::scan_js_content(&js_content, js_url);
all_matches.extend(file_matches);
}
}
Err(_) => continue,
}
}
for (file_url, matches) in &by_file {
let pattern_summary: Vec<String> = matches
.iter()
.take(5)
.map(|m| {
format!(
" Line {}: {} - {}",
m.line_number.unwrap_or(0),
m.pattern,
if m.line_snippet.len() > 80 {
format!("{}...", &m.line_snippet[..80])
} else {
m.line_snippet.clone()
}
)
let mut findings = Vec::new();
let match_data: Vec<serde_json::Value> = all_matches
.iter()
.map(|m| {
json!({
"pattern": m.pattern,
"file": m.file_url,
"line": m.line_number,
"snippet": m.line_snippet,
})
.collect();
})
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: file_url.to_string(),
request_headers: None,
request_body: None,
response_status: 200,
response_headers: None,
response_snippet: Some(pattern_summary.join("\n")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if !all_matches.is_empty() {
// Group by file for the finding
let mut by_file: std::collections::HashMap<&str, Vec<&ConsoleMatch>> =
std::collections::HashMap::new();
for m in &all_matches {
by_file.entry(&m.file_url).or_default().push(m);
}
let total = matches.len();
let extra = if total > 5 {
format!(" (and {} more)", total - 5)
} else {
String::new()
};
for (file_url, matches) in &by_file {
let pattern_summary: Vec<String> = matches
.iter()
.take(5)
.map(|m| {
format!(
" Line {}: {} - {}",
m.line_number.unwrap_or(0),
m.pattern,
if m.line_snippet.len() > 80 {
format!("{}...", &m.line_snippet[..80])
} else {
m.line_snippet.clone()
}
)
})
.collect();
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::ConsoleLogLeakage,
format!("Console/debug statements in {}", file_url),
format!(
"Found {total} console/debug statements in {file_url}{extra}. \
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: file_url.to_string(),
request_headers: None,
request_body: None,
response_status: 200,
response_headers: None,
response_snippet: Some(pattern_summary.join("\n")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let total = matches.len();
let extra = if total > 5 {
format!(" (and {} more)", total - 5)
} else {
String::new()
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::ConsoleLogLeakage,
format!("Console/debug statements in {}", file_url),
format!(
"Found {total} console/debug statements in {file_url}{extra}. \
These can leak sensitive information such as API responses, user data, \
or internal state to anyone with browser developer tools open."
),
Severity::Low,
file_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-532".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove console.log/debug/error statements from production code. \
),
Severity::Low,
file_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-532".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove console.log/debug/error statements from production code. \
Use a build step (e.g., babel plugin, terser) to strip console calls \
during the production build."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
let total_matches = all_matches.len();
let count = findings.len();
info!(url, js_files = js_urls.len(), total_matches, "Console log detection complete");
let total_matches = all_matches.len();
let count = findings.len();
info!(
url,
js_files = js_urls.len(),
total_matches,
"Console log detection complete"
);
Ok(PentestToolResult {
summary: if total_matches > 0 {
format!(
"Found {total_matches} console/debug statements across {} files.",
count
)
} else {
format!(
"No console/debug statements found in HTML or {} JS files.",
js_urls.len()
)
},
findings,
data: json!({
"total_matches": total_matches,
"js_files_scanned": js_urls.len(),
"matches": match_data,
}),
})
Ok(PentestToolResult {
summary: if total_matches > 0 {
format!(
"Found {total_matches} console/debug statements across {} files.",
count
)
} else {
format!(
"No console/debug statements found in HTML or {} JS files.",
js_urls.len()
)
},
findings,
data: json!({
"total_matches": total_matches,
"js_files_scanned": js_urls.len(),
"matches": match_data,
}),
})
})
}
}

View File

@@ -14,6 +14,7 @@ pub struct CookieAnalyzerTool {
#[derive(Debug)]
struct ParsedCookie {
name: String,
#[allow(dead_code)]
value: String,
secure: bool,
http_only: bool,
@@ -92,6 +93,81 @@ impl CookieAnalyzerTool {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_simple_cookie() {
        // A bare name=value pair: every optional attribute stays unset.
        let parsed = CookieAnalyzerTool::parse_set_cookie("session_id=abc123");
        assert_eq!(parsed.name, "session_id");
        assert_eq!(parsed.value, "abc123");
        assert!(!parsed.secure);
        assert!(!parsed.http_only);
        assert!(parsed.same_site.is_none());
        assert!(parsed.domain.is_none());
        assert!(parsed.path.is_none());
    }

    #[test]
    fn parse_cookie_with_all_attributes() {
        // All supported attributes present at once; the raw header is kept verbatim
        // and the SameSite value is expected lowercased by the parser.
        let raw = "token=xyz; Secure; HttpOnly; SameSite=Strict; Domain=.example.com; Path=/api";
        let parsed = CookieAnalyzerTool::parse_set_cookie(raw);
        assert_eq!(parsed.name, "token");
        assert_eq!(parsed.value, "xyz");
        assert!(parsed.secure);
        assert!(parsed.http_only);
        assert_eq!(parsed.same_site.as_deref(), Some("strict"));
        assert_eq!(parsed.domain.as_deref(), Some(".example.com"));
        assert_eq!(parsed.path.as_deref(), Some("/api"));
        assert_eq!(parsed.raw, raw);
    }

    #[test]
    fn parse_cookie_samesite_none() {
        // SameSite=None is a distinct (weak) value, not the absence of the attribute.
        let parsed = CookieAnalyzerTool::parse_set_cookie("id=1; SameSite=None; Secure");
        assert_eq!(parsed.same_site.as_deref(), Some("none"));
        assert!(parsed.secure);
    }

    #[test]
    fn parse_cookie_with_equals_in_value() {
        // Only the first '=' splits name from value; later '=' signs belong to the value.
        let parsed = CookieAnalyzerTool::parse_set_cookie("data=a=b=c; HttpOnly");
        assert_eq!(parsed.name, "data");
        assert_eq!(parsed.value, "a=b=c");
        assert!(parsed.http_only);
    }

    #[test]
    fn is_sensitive_cookie_known_names() {
        // Common session/auth/CSRF cookie names must all be classified as sensitive,
        // including substring matches like "my_sess_cookie".
        for name in [
            "session_id",
            "PHPSESSID",
            "JSESSIONID",
            "connect.sid",
            "asp.net_sessionid",
            "auth_token",
            "jwt_access",
            "csrf_token",
            "my_sess_cookie",
            "SID",
        ] {
            assert!(CookieAnalyzerTool::is_sensitive_cookie(name));
        }
    }

    #[test]
    fn is_sensitive_cookie_non_sensitive() {
        // Preference/analytics cookies must not be flagged.
        for name in ["theme", "language", "_ga", "tracking"] {
            assert!(!CookieAnalyzerTool::is_sensitive_cookie(name));
        }
    }

    #[test]
    fn parse_empty_cookie_header() {
        // An empty header degrades to empty name and value rather than panicking.
        let parsed = CookieAnalyzerTool::parse_set_cookie("");
        assert_eq!(parsed.name, "");
        assert_eq!(parsed.value, "");
    }
}
impl PentestTool for CookieAnalyzerTool {
fn name(&self) -> &str {
"cookie_analyzer"
@@ -123,96 +199,96 @@ impl PentestTool for CookieAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let login_url = input.get("login_url").and_then(|v| v.as_str());
let login_url = input.get("login_url").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut cookie_data = Vec::new();
let mut findings = Vec::new();
let mut cookie_data = Vec::new();
// Collect Set-Cookie headers from the main URL and optional login URL
let urls_to_check: Vec<&str> = std::iter::once(url)
.chain(login_url.into_iter())
.collect();
// Collect Set-Cookie headers from the main URL and optional login URL
let urls_to_check: Vec<&str> = std::iter::once(url).chain(login_url).collect();
for check_url in &urls_to_check {
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
let no_redirect_client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(15))
.build()
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
for check_url in &urls_to_check {
// Use a client that does NOT follow redirects so we catch cookies on redirect responses
let no_redirect_client = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
.redirect(reqwest::redirect::Policy::none())
.timeout(std::time::Duration::from_secs(15))
.build()
.map_err(|e| CoreError::Dast(format!("Client build error: {e}")))?;
let response = match no_redirect_client.get(*check_url).send().await {
Ok(r) => r,
Err(e) => {
// Try with the main client that follows redirects
match self.http.get(*check_url).send().await {
Ok(r) => r,
Err(_) => continue,
let response = match no_redirect_client.get(*check_url).send().await {
Ok(r) => r,
Err(_e) => {
// Try with the main client that follows redirects
match self.http.get(*check_url).send().await {
Ok(r) => r,
Err(_) => continue,
}
}
}
};
};
let status = response.status().as_u16();
let set_cookie_headers: Vec<String> = response
.headers()
.get_all("set-cookie")
.iter()
.filter_map(|v| v.to_str().ok().map(String::from))
.collect();
let status = response.status().as_u16();
let set_cookie_headers: Vec<String> = response
.headers()
.get_all("set-cookie")
.iter()
.filter_map(|v| v.to_str().ok().map(String::from))
.collect();
for raw_cookie in &set_cookie_headers {
let cookie = Self::parse_set_cookie(raw_cookie);
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
let is_https = check_url.starts_with("https://");
for raw_cookie in &set_cookie_headers {
let cookie = Self::parse_set_cookie(raw_cookie);
let is_sensitive = Self::is_sensitive_cookie(&cookie.name);
let is_https = check_url.starts_with("https://");
let cookie_info = json!({
"name": cookie.name,
"secure": cookie.secure,
"http_only": cookie.http_only,
"same_site": cookie.same_site,
"domain": cookie.domain,
"path": cookie.path,
"is_sensitive": is_sensitive,
"url": check_url,
});
cookie_data.push(cookie_info);
let cookie_info = json!({
"name": cookie.name,
"secure": cookie.secure,
"http_only": cookie.http_only,
"same_site": cookie.same_site,
"domain": cookie.domain,
"path": cookie.path,
"is_sensitive": is_sensitive,
"url": check_url,
});
cookie_data.push(cookie_info);
// Check: missing Secure flag
if !cookie.secure && (is_https || is_sensitive) {
let severity = if is_sensitive {
Severity::High
} else {
Severity::Medium
};
// Check: missing Secure flag
if !cookie.secure && (is_https || is_sensitive) {
let severity = if is_sensitive {
Severity::High
} else {
Severity::Medium
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
@@ -226,32 +302,32 @@ impl PentestTool for CookieAnalyzerTool {
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-614".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
finding.cwe = Some("CWE-614".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'Secure' attribute to the Set-Cookie header to ensure the \
cookie is only sent over HTTPS connections."
.to_string(),
);
findings.push(finding);
}
.to_string(),
);
findings.push(finding);
}
// Check: missing HttpOnly flag on sensitive cookies
if !cookie.http_only && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// Check: missing HttpOnly flag on sensitive cookies
if !cookie.http_only && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
@@ -265,137 +341,137 @@ impl PentestTool for CookieAnalyzerTool {
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the 'HttpOnly' attribute to the Set-Cookie header to prevent \
JavaScript access to the cookie."
.to_string(),
);
findings.push(finding);
}
.to_string(),
);
findings.push(finding);
}
// Check: missing or weak SameSite
if is_sensitive {
let weak_same_site = match &cookie.same_site {
None => true,
Some(ss) => ss == "none",
};
if weak_same_site {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
// Check: missing or weak SameSite
if is_sensitive {
let weak_same_site = match &cookie.same_site {
None => true,
Some(ss) => ss == "none",
};
let desc = if cookie.same_site.is_none() {
format!(
if weak_same_site {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let desc = if cookie.same_site.is_none() {
format!(
"The session/auth cookie '{}' does not have a SameSite attribute. \
This may allow cross-site request forgery (CSRF) attacks.",
cookie.name
)
} else {
format!(
} else {
format!(
"The session/auth cookie '{}' has SameSite=None, which allows it \
to be sent in cross-site requests, enabling CSRF attacks.",
cookie.name
)
};
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing or weak SameSite", cookie.name),
desc,
Severity::Medium,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1275".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' missing or weak SameSite", cookie.name),
desc,
Severity::Medium,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1275".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Set 'SameSite=Strict' or 'SameSite=Lax' on session/auth cookies \
to prevent cross-site request inclusion."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
// Check: overly broad domain
if let Some(ref domain) = cookie.domain {
// A domain starting with a dot applies to all subdomains
let dot_domain = domain.starts_with('.');
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
if dot_domain && parts.len() <= 2 && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// Check: overly broad domain
if let Some(ref domain) = cookie.domain {
// A domain starting with a dot applies to all subdomains
let dot_domain = domain.starts_with('.');
// Count domain parts - if only 2 parts (e.g., .example.com), it's broad
let parts: Vec<&str> = domain.trim_start_matches('.').split('.').collect();
if dot_domain && parts.len() <= 2 && is_sensitive {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: check_url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(cookie.raw.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' has overly broad domain", cookie.name),
format!(
"The cookie '{}' is scoped to domain '{}' which includes all \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CookieSecurity,
format!("Cookie '{}' has overly broad domain", cookie.name),
format!(
"The cookie '{}' is scoped to domain '{}' which includes all \
subdomains. If any subdomain is compromised, the attacker can \
access this cookie.",
cookie.name, domain
),
Severity::Low,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
cookie.name, domain
),
Severity::Low,
check_url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-1004".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Restrict the cookie domain to the specific subdomain that needs it \
rather than the entire parent domain."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
}
}
let count = findings.len();
info!(url, findings = count, "Cookie analysis complete");
let count = findings.len();
info!(url, findings = count, "Cookie analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} cookie security issues.")
} else if cookie_data.is_empty() {
"No cookies were set by the target.".to_string()
} else {
"All cookies have proper security attributes.".to_string()
},
findings,
data: json!({
"cookies": cookie_data,
"total_cookies": cookie_data.len(),
}),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} cookie security issues.")
} else if cookie_data.is_empty() {
"No cookies were set by the target.".to_string()
} else {
"All cookies have proper security attributes.".to_string()
},
findings,
data: json!({
"cookies": cookie_data,
"total_cookies": cookie_data.len(),
}),
})
})
}
}

View File

@@ -22,19 +22,60 @@ impl CorsCheckerTool {
vec![
("null_origin", "null".to_string()),
("evil_domain", "https://evil.com".to_string()),
(
"subdomain_spoof",
format!("https://{target_host}.evil.com"),
),
(
"prefix_spoof",
format!("https://evil-{target_host}"),
),
("subdomain_spoof", format!("https://{target_host}.evil.com")),
("prefix_spoof", format!("https://evil-{target_host}")),
("http_downgrade", format!("http://{target_host}")),
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Looks up the origin string generated for a given test-case name.
    // Panics if the variant is missing, which fails the calling test.
    fn origin_for<'a>(origins: &'a [(&str, String)], name: &str) -> &'a str {
        origins
            .iter()
            .find(|(n, _)| *n == name)
            .map(|(_, origin)| origin.as_str())
            .unwrap()
    }

    #[test]
    fn test_origins_contains_expected_variants() {
        // Exactly five probe origins, one per attack variant.
        let origins = CorsCheckerTool::test_origins("example.com");
        assert_eq!(origins.len(), 5);
        for expected in [
            "null_origin",
            "evil_domain",
            "subdomain_spoof",
            "prefix_spoof",
            "http_downgrade",
        ] {
            assert!(origins.iter().any(|(name, _)| *name == expected));
        }
    }

    #[test]
    fn test_origins_uses_target_host() {
        // Host-derived variants must embed the exact target host string.
        let origins = CorsCheckerTool::test_origins("myapp.io");
        assert_eq!(
            origin_for(&origins, "subdomain_spoof"),
            "https://myapp.io.evil.com"
        );
        assert_eq!(origin_for(&origins, "prefix_spoof"), "https://evil-myapp.io");
        assert_eq!(origin_for(&origins, "http_downgrade"), "http://myapp.io");
    }

    #[test]
    fn test_origins_null_and_evil_are_static() {
        // These two probes are fixed values, independent of the target host.
        let origins = CorsCheckerTool::test_origins("anything.com");
        assert_eq!(origin_for(&origins, "null_origin"), "null");
        assert_eq!(origin_for(&origins, "evil_domain"), "https://evil.com");
    }
}
impl PentestTool for CorsCheckerTool {
fn name(&self) -> &str {
"cors_checker"
@@ -68,82 +109,82 @@ impl PentestTool for CorsCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_origins: Vec<String> = input
.get("additional_origins")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let additional_origins: Vec<String> = input
.get("additional_origins")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_host = url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
.unwrap_or_else(|| url.to_string());
let target_host = url::Url::parse(url)
.ok()
.and_then(|u| u.host_str().map(String::from))
.unwrap_or_else(|| url.to_string());
let mut findings = Vec::new();
let mut cors_data: Vec<serde_json::Value> = Vec::new();
let mut findings = Vec::new();
let mut cors_data: Vec<serde_json::Value> = Vec::new();
// First, send a request without Origin to get baseline
let baseline = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
// First, send a request without Origin to get baseline
let baseline = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let baseline_acao = baseline
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let baseline_acao = baseline
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"origin": null,
"acao": baseline_acao,
}));
cors_data.push(json!({
"origin": null,
"acao": baseline_acao,
}));
// Check for wildcard + credentials (dangerous combo)
if let Some(ref acao) = baseline_acao {
if acao == "*" {
let acac = baseline
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Check for wildcard + credentials (dangerous combo)
if let Some(ref acao) = baseline_acao {
if acao == "*" {
let acac = baseline
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if acac.to_lowercase() == "true" {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: baseline.status().as_u16(),
response_headers: None,
response_snippet: Some(format!(
"Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if acac.to_lowercase() == "true" {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: baseline.status().as_u16(),
response_headers: None,
response_snippet: Some("Access-Control-Allow-Origin: *\nAccess-Control-Allow-Credentials: true".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
@@ -157,254 +198,251 @@ impl PentestTool for CorsCheckerTool {
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Never combine Access-Control-Allow-Origin: * with \
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Never combine Access-Control-Allow-Origin: * with \
Access-Control-Allow-Credentials: true. Specify explicit allowed origins."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
}
// Test with various Origin headers
let mut test_origins = Self::test_origins(&target_host);
for origin in &additional_origins {
test_origins.push(("custom", origin.clone()));
}
// Test with various Origin headers
let mut test_origins = Self::test_origins(&target_host);
for origin in &additional_origins {
test_origins.push(("custom", origin.clone()));
}
for (test_name, origin) in &test_origins {
let resp = match self
for (test_name, origin) in &test_origins {
let resp = match self
.http
.get(url)
.header("Origin", origin.as_str())
.send()
.await
{
Ok(r) => r,
Err(_) => continue,
};
let status = resp.status().as_u16();
let acao = resp
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acac = resp
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": test_name,
"origin": origin,
"acao": acao,
"acac": acac,
"acam": acam,
"status": status,
}));
// Check if the origin was reflected back
if let Some(ref acao_val) = acao {
let origin_reflected = acao_val == origin;
let credentials_allowed = acac
.as_ref()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
if origin_reflected && *test_name != "http_downgrade" {
let severity = if credentials_allowed {
Severity::Critical
} else {
Severity::High
};
let resp_headers: HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: Some(resp_headers),
response_snippet: Some(format!(
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
Access-Control-Allow-Credentials: {}",
acac.as_deref().unwrap_or("not set")
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let title = match *test_name {
"null_origin" => "CORS accepts null origin".to_string(),
"evil_domain" => "CORS reflects arbitrary origin".to_string(),
"subdomain_spoof" => {
"CORS vulnerable to subdomain spoofing".to_string()
}
"prefix_spoof" => "CORS vulnerable to prefix spoofing".to_string(),
_ => format!("CORS reflects untrusted origin ({test_name})"),
};
let cred_note = if credentials_allowed {
" Combined with Access-Control-Allow-Credentials: true, this allows \
the attacker to steal authenticated data."
} else {
""
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
title,
format!(
"The endpoint {url} reflects the Origin header '{origin}' back in \
Access-Control-Allow-Origin, allowing cross-origin requests from \
untrusted domains.{cred_note}"
),
severity,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.exploitable = credentials_allowed;
finding.evidence = vec![evidence];
finding.remediation = Some(
"Validate the Origin header against a whitelist of trusted origins. \
Do not reflect the Origin header value directly. Use specific allowed \
origins instead of wildcards."
.to_string(),
);
findings.push(finding);
warn!(
url,
test_name,
origin,
credentials = credentials_allowed,
"CORS misconfiguration detected"
);
}
// Special case: HTTP downgrade
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS allows HTTP origin with credentials".to_string(),
format!(
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
credentials. An attacker performing a man-in-the-middle attack on \
the HTTP version could steal authenticated data."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
origin validation enforces the https:// scheme."
.to_string(),
);
findings.push(finding);
}
}
}
// Also send a preflight OPTIONS request
if let Ok(resp) = self
.http
.get(url)
.header("Origin", origin.as_str())
.request(reqwest::Method::OPTIONS, url)
.header("Origin", "https://evil.com")
.header("Access-Control-Request-Method", "POST")
.header(
"Access-Control-Request-Headers",
"Authorization, Content-Type",
)
.send()
.await
{
Ok(r) => r,
Err(_) => continue,
};
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
let status = resp.status().as_u16();
let acao = resp
.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acah = resp
.headers()
.get("access-control-allow-headers")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acac = resp
.headers()
.get("access-control-allow-credentials")
.and_then(|v| v.to_str().ok())
.map(String::from);
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": test_name,
"origin": origin,
"acao": acao,
"acac": acac,
"acam": acam,
"status": status,
}));
// Check if the origin was reflected back
if let Some(ref acao_val) = acao {
let origin_reflected = acao_val == origin;
let credentials_allowed = acac
.as_ref()
.map(|v| v.to_lowercase() == "true")
.unwrap_or(false);
if origin_reflected && *test_name != "http_downgrade" {
let severity = if credentials_allowed {
Severity::Critical
} else {
Severity::High
};
let resp_headers: HashMap<String, String> = resp
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: Some(resp_headers),
response_snippet: Some(format!(
"Origin: {origin}\nAccess-Control-Allow-Origin: {acao_val}\n\
Access-Control-Allow-Credentials: {}",
acac.as_deref().unwrap_or("not set")
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let title = match *test_name {
"null_origin" => {
"CORS accepts null origin".to_string()
}
"evil_domain" => {
"CORS reflects arbitrary origin".to_string()
}
"subdomain_spoof" => {
"CORS vulnerable to subdomain spoofing".to_string()
}
"prefix_spoof" => {
"CORS vulnerable to prefix spoofing".to_string()
}
_ => format!("CORS reflects untrusted origin ({test_name})"),
};
let cred_note = if credentials_allowed {
" Combined with Access-Control-Allow-Credentials: true, this allows \
the attacker to steal authenticated data."
} else {
""
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
title,
format!(
"The endpoint {url} reflects the Origin header '{origin}' back in \
Access-Control-Allow-Origin, allowing cross-origin requests from \
untrusted domains.{cred_note}"
),
severity,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.exploitable = credentials_allowed;
finding.evidence = vec![evidence];
finding.remediation = Some(
"Validate the Origin header against a whitelist of trusted origins. \
Do not reflect the Origin header value directly. Use specific allowed \
origins instead of wildcards."
.to_string(),
);
findings.push(finding);
warn!(
url,
test_name,
origin,
credentials = credentials_allowed,
"CORS misconfiguration detected"
);
}
// Special case: HTTP downgrade
if *test_name == "http_downgrade" && origin_reflected && credentials_allowed {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: Some(
[("Origin".to_string(), origin.clone())]
.into_iter()
.collect(),
),
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"HTTP origin accepted: {origin} -> ACAO: {acao_val}"
)),
screenshot_path: None,
payload: Some(origin.clone()),
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CorsMisconfiguration,
"CORS allows HTTP origin with credentials".to_string(),
format!(
"The HTTPS endpoint {url} accepts the HTTP origin {origin} with \
credentials. An attacker performing a man-in-the-middle attack on \
the HTTP version could steal authenticated data."
),
Severity::High,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-942".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Do not accept HTTP origins for HTTPS endpoints. Ensure CORS \
origin validation enforces the https:// scheme."
.to_string(),
);
findings.push(finding);
}
cors_data.push(json!({
"test": "preflight",
"status": resp.status().as_u16(),
"allow_methods": acam,
"allow_headers": acah,
}));
}
}
// Also send a preflight OPTIONS request
if let Ok(resp) = self
.http
.request(reqwest::Method::OPTIONS, url)
.header("Origin", "https://evil.com")
.header("Access-Control-Request-Method", "POST")
.header("Access-Control-Request-Headers", "Authorization, Content-Type")
.send()
.await
{
let acam = resp
.headers()
.get("access-control-allow-methods")
.and_then(|v| v.to_str().ok())
.map(String::from);
let count = findings.len();
info!(url, findings = count, "CORS check complete");
let acah = resp
.headers()
.get("access-control-allow-headers")
.and_then(|v| v.to_str().ok())
.map(String::from);
cors_data.push(json!({
"test": "preflight",
"status": resp.status().as_u16(),
"allow_methods": acam,
"allow_headers": acah,
}));
}
let count = findings.len();
info!(url, findings = count, "CORS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CORS misconfiguration issues for {url}.")
} else {
format!("CORS configuration appears secure for {url}.")
},
findings,
data: json!({
"tests": cors_data,
}),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CORS misconfiguration issues for {url}.")
} else {
format!("CORS configuration appears secure for {url}.")
},
findings,
data: json!({
"tests": cors_data,
}),
})
})
}
}

View File

@@ -47,7 +47,7 @@ impl CspAnalyzerTool {
url: &str,
target_id: &str,
status: u16,
csp_raw: &str,
_csp_raw: &str,
) -> Vec<DastFinding> {
let mut findings = Vec::new();
@@ -216,12 +216,18 @@ impl CspAnalyzerTool {
("object-src", "Controls plugins like Flash"),
("base-uri", "Controls the base URL for relative URLs"),
("form-action", "Controls where forms can submit to"),
("frame-ancestors", "Controls who can embed this page in iframes"),
(
"frame-ancestors",
"Controls who can embed this page in iframes",
),
];
for (dir_name, desc) in &important_directives {
if !directive_names.contains(dir_name)
&& !(has_default_src && *dir_name != "frame-ancestors" && *dir_name != "base-uri" && *dir_name != "form-action")
&& (!has_default_src
|| *dir_name == "frame-ancestors"
|| *dir_name == "base-uri"
|| *dir_name == "form-action")
{
let evidence = make_evidence(format!("CSP missing directive: {dir_name}"));
let mut finding = DastFinding::new(
@@ -258,6 +264,125 @@ impl CspAnalyzerTool {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: parse `csp` and run the directive analyzer against a fixed
    /// dummy context (`https://example.com`, target id `t1`, HTTP 200, empty raw CSP).
    fn analyze(csp: &str) -> Vec<DastFinding> {
        let directives = CspAnalyzerTool::parse_csp(csp);
        CspAnalyzerTool::analyze_directives(&directives, "https://example.com", "t1", 200, "")
    }

    /// True when at least one finding title contains `needle`.
    fn any_title_contains(findings: &[DastFinding], needle: &str) -> bool {
        findings.iter().any(|f| f.title.contains(needle))
    }

    #[test]
    fn parse_csp_basic() {
        let parsed = CspAnalyzerTool::parse_csp(
            "default-src 'self'; script-src 'self' https://cdn.example.com",
        );
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].name, "default-src");
        assert_eq!(parsed[0].values, vec!["'self'"]);
        assert_eq!(parsed[1].name, "script-src");
        assert_eq!(parsed[1].values, vec!["'self'", "https://cdn.example.com"]);
    }

    #[test]
    fn parse_csp_empty() {
        assert!(CspAnalyzerTool::parse_csp("").is_empty());
    }

    #[test]
    fn parse_csp_trailing_semicolons() {
        // Empty segments produced by repeated ';' must be dropped, not parsed.
        let parsed = CspAnalyzerTool::parse_csp("default-src 'none';;;");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "default-src");
        assert_eq!(parsed[0].values, vec!["'none'"]);
    }

    #[test]
    fn parse_csp_directive_without_value() {
        // Valueless directives (e.g. upgrade-insecure-requests) are legal CSP.
        let parsed = CspAnalyzerTool::parse_csp("upgrade-insecure-requests");
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].name, "upgrade-insecure-requests");
        assert!(parsed[0].values.is_empty());
    }

    #[test]
    fn parse_csp_names_are_lowercased() {
        let parsed = CspAnalyzerTool::parse_csp("Script-Src 'self'");
        assert_eq!(parsed[0].name, "script-src");
    }

    #[test]
    fn analyze_detects_unsafe_inline() {
        let findings = analyze("script-src 'self' 'unsafe-inline'");
        assert!(any_title_contains(&findings, "unsafe-inline"));
    }

    #[test]
    fn analyze_detects_unsafe_eval() {
        let findings = analyze("script-src 'self' 'unsafe-eval'");
        assert!(any_title_contains(&findings, "unsafe-eval"));
    }

    #[test]
    fn analyze_detects_wildcard() {
        let findings = analyze("img-src *");
        assert!(any_title_contains(&findings, "wildcard"));
    }

    #[test]
    fn analyze_detects_data_uri_in_script_src() {
        let findings = analyze("script-src 'self' data:");
        assert!(any_title_contains(&findings, "data:"));
    }

    #[test]
    fn analyze_detects_http_sources() {
        let findings = analyze("script-src http:");
        assert!(any_title_contains(&findings, "HTTP sources"));
    }

    #[test]
    fn analyze_detects_missing_directives_without_default_src() {
        let findings = analyze("img-src 'self'");
        // Should flag script-src, object-src, base-uri, form-action, frame-ancestors
        let missing = findings
            .iter()
            .filter(|f| f.title.contains("missing"))
            .count();
        assert!(missing >= 4);
    }

    #[test]
    fn analyze_good_csp_no_unsafe_findings() {
        let findings = analyze(
            "default-src 'none'; script-src 'self'; style-src 'self'; \
             img-src 'self'; object-src 'none'; base-uri 'self'; \
             form-action 'self'; frame-ancestors 'none'",
        );
        // A well-configured CSP should not produce unsafe-inline/eval/wildcard findings
        for f in &findings {
            assert!(!f.title.contains("unsafe-inline"));
            assert!(!f.title.contains("unsafe-eval"));
            assert!(!f.title.contains("wildcard"));
        }
    }
}
impl PentestTool for CspAnalyzerTool {
fn name(&self) -> &str {
"csp_analyzer"
@@ -285,163 +410,167 @@ impl PentestTool for CspAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
let status = response.status().as_u16();
// Check for CSP header
let csp_header = response
.headers()
.get("content-security-policy")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Check for CSP header
let csp_header = response
.headers()
.get("content-security-policy")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Also check for report-only variant
let csp_report_only = response
.headers()
.get("content-security-policy-report-only")
.and_then(|v| v.to_str().ok())
.map(String::from);
// Also check for report-only variant
let csp_report_only = response
.headers()
.get("content-security-policy-report-only")
.and_then(|v| v.to_str().ok())
.map(String::from);
let mut findings = Vec::new();
let mut csp_data = json!({});
let mut findings = Vec::new();
let mut csp_data = json!({});
match &csp_header {
Some(csp) => {
let directives = Self::parse_csp(csp);
let directive_map: serde_json::Value = directives
.iter()
.map(|d| (d.name.clone(), json!(d.values)))
.collect::<serde_json::Map<String, serde_json::Value>>()
.into();
match &csp_header {
Some(csp) => {
let directives = Self::parse_csp(csp);
let directive_map: serde_json::Value = directives
.iter()
.map(|d| (d.name.clone(), json!(d.values)))
.collect::<serde_json::Map<String, serde_json::Value>>()
.into();
csp_data["csp_header"] = json!(csp);
csp_data["directives"] = directive_map;
csp_data["csp_header"] = json!(csp);
csp_data["directives"] = directive_map;
findings.extend(Self::analyze_directives(
&directives,
url,
&target_id,
status,
csp,
));
}
None => {
csp_data["csp_header"] = json!(null);
findings.extend(Self::analyze_directives(
&directives,
url,
&target_id,
status,
csp,
));
}
None => {
csp_data["csp_header"] = json!(null);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some("Content-Security-Policy header is missing".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(
"Content-Security-Policy header is missing".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"Missing Content-Security-Policy header".to_string(),
format!(
"No Content-Security-Policy header is present on {url}. \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"Missing Content-Security-Policy header".to_string(),
format!(
"No Content-Security-Policy header is present on {url}. \
Without CSP, the browser has no instructions on which sources are \
trusted, making XSS exploitation much easier."
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
),
Severity::Medium,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a Content-Security-Policy header. Start with a restrictive policy like \
\"default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self' data:; \
object-src 'none'; frame-ancestors 'none'; base-uri 'self'\"."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
if let Some(ref report_only) = csp_report_only {
csp_data["csp_report_only"] = json!(report_only);
// If ONLY report-only exists (no enforcing CSP), warn
if csp_header.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"Content-Security-Policy-Report-Only: {}",
report_only
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if let Some(ref report_only) = csp_report_only {
csp_data["csp_report_only"] = json!(report_only);
// If ONLY report-only exists (no enforcing CSP), warn
if csp_header.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: None,
response_snippet: Some(format!(
"Content-Security-Policy-Report-Only: {}",
report_only
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"CSP is report-only, not enforcing".to_string(),
"A Content-Security-Policy-Report-Only header is present but no enforcing \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::CspIssue,
"CSP is report-only, not enforcing".to_string(),
"A Content-Security-Policy-Report-Only header is present but no enforcing \
Content-Security-Policy header exists. Report-only mode only logs violations \
but does not block them."
.to_string(),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
.to_string(),
Severity::Low,
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-16".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Once you have verified the CSP policy works correctly in report-only mode, \
deploy it as an enforcing Content-Security-Policy header."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
}
let count = findings.len();
info!(url, findings = count, "CSP analysis complete");
let count = findings.len();
info!(url, findings = count, "CSP analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CSP issues for {url}.")
} else {
format!("Content-Security-Policy looks good for {url}.")
},
findings,
data: csp_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} CSP issues for {url}.")
} else {
format!("Content-Security-Policy looks good for {url}.")
},
findings,
data: csp_data,
})
})
}
}

View File

@@ -8,6 +8,12 @@ use tracing::{info, warn};
/// Tool that checks email security configuration (DMARC and SPF records).
pub struct DmarcCheckerTool;
impl Default for DmarcCheckerTool {
fn default() -> Self {
Self::new()
}
}
impl DmarcCheckerTool {
pub fn new() -> Self {
Self
@@ -78,6 +84,105 @@ impl DmarcCheckerTool {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_dmarc_policy_reject() {
        let policy = DmarcCheckerTool::parse_dmarc_policy(
            "v=DMARC1; p=reject; rua=mailto:dmarc@example.com",
        );
        assert_eq!(policy.as_deref(), Some("reject"));
    }

    #[test]
    fn parse_dmarc_policy_none() {
        let policy = DmarcCheckerTool::parse_dmarc_policy("v=DMARC1; p=none");
        assert_eq!(policy.as_deref(), Some("none"));
    }

    #[test]
    fn parse_dmarc_policy_quarantine() {
        let policy = DmarcCheckerTool::parse_dmarc_policy("v=DMARC1; p=quarantine; sp=none");
        assert_eq!(policy.as_deref(), Some("quarantine"));
    }

    #[test]
    fn parse_dmarc_policy_missing() {
        // Record without a 'p=' tag yields no policy.
        let policy = DmarcCheckerTool::parse_dmarc_policy("v=DMARC1; rua=mailto:test@example.com");
        assert!(policy.is_none());
    }

    #[test]
    fn parse_dmarc_subdomain_policy() {
        let sp = DmarcCheckerTool::parse_dmarc_subdomain_policy("v=DMARC1; p=reject; sp=quarantine");
        assert_eq!(sp.as_deref(), Some("quarantine"));
    }

    #[test]
    fn parse_dmarc_subdomain_policy_missing() {
        let sp = DmarcCheckerTool::parse_dmarc_subdomain_policy("v=DMARC1; p=reject");
        assert!(sp.is_none());
    }

    #[test]
    fn parse_dmarc_rua_present() {
        let rua = DmarcCheckerTool::parse_dmarc_rua(
            "v=DMARC1; p=reject; rua=mailto:dmarc@example.com",
        );
        assert_eq!(rua.as_deref(), Some("mailto:dmarc@example.com"));
    }

    #[test]
    fn parse_dmarc_rua_missing() {
        assert!(DmarcCheckerTool::parse_dmarc_rua("v=DMARC1; p=none").is_none());
    }

    #[test]
    fn is_spf_record_valid() {
        for record in ["v=spf1 include:_spf.google.com -all", "v=spf1 -all"] {
            assert!(DmarcCheckerTool::is_spf_record(record));
        }
    }

    #[test]
    fn is_spf_record_invalid() {
        for record in ["v=DMARC1; p=reject", "some random txt record"] {
            assert!(!DmarcCheckerTool::is_spf_record(record));
        }
    }

    #[test]
    fn spf_soft_fail_detection() {
        // '~all' is a soft fail; '-all' is a hard fail and must not match.
        assert!(DmarcCheckerTool::spf_uses_soft_fail(
            "v=spf1 include:_spf.google.com ~all"
        ));
        assert!(!DmarcCheckerTool::spf_uses_soft_fail(
            "v=spf1 include:_spf.google.com -all"
        ));
    }

    #[test]
    fn spf_allows_all_detection() {
        // Only the permissive '+all' qualifier counts as "allows all".
        assert!(DmarcCheckerTool::spf_allows_all("v=spf1 +all"));
        assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 -all"));
        assert!(!DmarcCheckerTool::spf_allows_all("v=spf1 ~all"));
    }
}
impl PentestTool for DmarcCheckerTool {
fn name(&self) -> &str {
"dmarc_checker"
@@ -105,43 +210,89 @@ impl PentestTool for DmarcCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'domain' parameter".to_string())
})?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut email_data = json!({});
let mut findings = Vec::new();
let mut email_data = json!({});
// ---- DMARC check ----
let dmarc_domain = format!("_dmarc.{domain}");
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
// ---- DMARC check ----
let dmarc_domain = format!("_dmarc.{domain}");
let dmarc_records = Self::query_txt(&dmarc_domain).await.unwrap_or_default();
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
let dmarc_record = dmarc_records.iter().find(|r| r.starts_with("v=DMARC1"));
match dmarc_record {
Some(record) => {
email_data["dmarc_record"] = json!(record);
match dmarc_record {
Some(record) => {
email_data["dmarc_record"] = json!(record);
let policy = Self::parse_dmarc_policy(record);
let sp = Self::parse_dmarc_subdomain_policy(record);
let rua = Self::parse_dmarc_rua(record);
let policy = Self::parse_dmarc_policy(record);
let sp = Self::parse_dmarc_subdomain_policy(record);
let rua = Self::parse_dmarc_rua(record);
email_data["dmarc_policy"] = json!(policy);
email_data["dmarc_subdomain_policy"] = json!(sp);
email_data["dmarc_rua"] = json!(rua);
email_data["dmarc_policy"] = json!(policy);
email_data["dmarc_subdomain_policy"] = json!(sp);
email_data["dmarc_rua"] = json!(rua);
// Warn on weak policy
if let Some(ref p) = policy {
if p == "none" {
// Warn on weak policy
if let Some(ref p) = policy {
if p == "none" {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Weak DMARC policy for {domain}"),
format!(
"The DMARC policy for {domain} is set to 'none', which only monitors \
but does not enforce email authentication. Attackers can spoof emails \
from this domain."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
'p=reject' after verifying legitimate email flows are properly \
authenticated."
.to_string(),
);
findings.push(finding);
warn!(domain, "DMARC policy is 'none'");
}
}
// Warn if no reporting URI
if rua.is_none() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
@@ -156,48 +307,6 @@ impl PentestTool for DmarcCheckerTool {
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Weak DMARC policy for {domain}"),
format!(
"The DMARC policy for {domain} is set to 'none', which only monitors \
but does not enforce email authentication. Attackers can spoof emails \
from this domain."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Upgrade the DMARC policy from 'p=none' to 'p=quarantine' or \
'p=reject' after verifying legitimate email flows are properly \
authenticated."
.to_string(),
);
findings.push(finding);
warn!(domain, "DMARC policy is 'none'");
}
}
// Warn if no reporting URI
if rua.is_none() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
@@ -211,80 +320,80 @@ impl PentestTool for DmarcCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-778".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
finding.cwe = Some("CWE-778".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add a 'rua=' tag to your DMARC record to receive aggregate reports. \
Example: 'rua=mailto:dmarc-reports@example.com'."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
None => {
email_data["dmarc_record"] = json!(null);
None => {
email_data["dmarc_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DMARC record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing DMARC record for {domain}"),
format!(
"No DMARC record was found for {domain}. Without DMARC, there is no \
policy to prevent email spoofing and phishing attacks using this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No DMARC record found");
}
}
// ---- SPF check ----
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
match spf_record {
Some(record) => {
email_data["spf_record"] = json!(record);
if Self::spf_allows_all(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_url: dmarc_domain.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
response_snippet: Some("No DMARC record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing DMARC record for {domain}"),
format!(
"No DMARC record was found for {domain}. Without DMARC, there is no \
policy to prevent email spoofing and phishing attacks using this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create a DMARC TXT record at _dmarc.<domain>. Start with 'v=DMARC1; p=none; \
rua=mailto:dmarc@example.com' and gradually move to 'p=reject'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No DMARC record found");
}
}
// ---- SPF check ----
let txt_records = Self::query_txt(domain).await.unwrap_or_default();
let spf_record = txt_records.iter().find(|r| Self::is_spf_record(r));
match spf_record {
Some(record) => {
email_data["spf_record"] = json!(record);
if Self::spf_allows_all(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
@@ -297,15 +406,55 @@ impl PentestTool for DmarcCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Change '+all' to '-all' (hard fail) or '~all' (soft fail) in your SPF record. \
Only list authorized mail servers."
.to_string(),
);
findings.push(finding);
} else if Self::spf_uses_soft_fail(record) {
findings.push(finding);
} else if Self::spf_uses_soft_fail(record) {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF soft fail for {domain}"),
format!(
"The SPF record for {domain} uses '~all' (soft fail) instead of \
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
but does not reject them."
),
Severity::Low,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Consider changing '~all' to '-all' in your SPF record once you have \
confirmed all legitimate mail sources are listed."
.to_string(),
);
findings.push(finding);
}
}
None => {
email_data["spf_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
@@ -313,7 +462,7 @@ impl PentestTool for DmarcCheckerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(record.clone()),
response_snippet: Some("No SPF record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -323,79 +472,39 @@ impl PentestTool for DmarcCheckerTool {
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("SPF soft fail for {domain}"),
format!("Missing SPF record for {domain}"),
format!(
"The SPF record for {domain} uses '~all' (soft fail) instead of \
'-all' (hard fail). Soft fail marks unauthorized emails as suspicious \
but does not reject them."
),
Severity::Low,
"No SPF record was found for {domain}. Without SPF, any server can claim \
to send email on behalf of this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Consider changing '~all' to '-all' in your SPF record once you have \
confirmed all legitimate mail sources are listed."
"Create an SPF TXT record for your domain. Example: \
'v=spf1 include:_spf.google.com -all'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No SPF record found");
}
}
None => {
email_data["spf_record"] = json!(null);
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No SPF record found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let count = findings.len();
info!(domain, findings = count, "DMARC/SPF check complete");
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::EmailSecurity,
format!("Missing SPF record for {domain}"),
format!(
"No SPF record was found for {domain}. Without SPF, any server can claim \
to send email on behalf of this domain."
),
Severity::High,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-290".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Create an SPF TXT record for your domain. Example: \
'v=spf1 include:_spf.google.com -all'."
.to_string(),
);
findings.push(finding);
warn!(domain, "No SPF record found");
}
}
let count = findings.len();
info!(domain, findings = count, "DMARC/SPF check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} email security issues for {domain}.")
} else {
format!("Email security configuration looks good for {domain}.")
},
findings,
data: email_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} email security issues for {domain}.")
} else {
format!("Email security configuration looks good for {domain}.")
},
findings,
data: email_data,
})
})
}
}

View File

@@ -16,6 +16,12 @@ use tracing::{info, warn};
/// `tokio::process::Command` wrapper around `dig` where available.
pub struct DnsCheckerTool;
impl Default for DnsCheckerTool {
fn default() -> Self {
Self::new()
}
}
impl DnsCheckerTool {
pub fn new() -> Self {
Self
@@ -54,7 +60,9 @@ impl DnsCheckerTool {
}
}
Err(e) => {
return Err(CoreError::Dast(format!("DNS resolution failed for {domain}: {e}")));
return Err(CoreError::Dast(format!(
"DNS resolution failed for {domain}: {e}"
)));
}
}
@@ -94,107 +102,111 @@ impl PentestTool for DnsCheckerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'domain' parameter".to_string()))?;
let domain = input
.get("domain")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'domain' parameter".to_string())
})?;
let subdomains: Vec<String> = input
.get("subdomains")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let subdomains: Vec<String> = input
.get("subdomains")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let mut findings = Vec::new();
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
let mut findings = Vec::new();
let mut dns_data: HashMap<String, serde_json::Value> = HashMap::new();
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// --- A / AAAA records ---
match Self::resolve_addresses(domain).await {
Ok((ipv4, ipv6)) => {
dns_data.insert("a_records".to_string(), json!(ipv4));
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
// --- A / AAAA records ---
match Self::resolve_addresses(domain).await {
Ok((ipv4, ipv6)) => {
dns_data.insert("a_records".to_string(), json!(ipv4));
dns_data.insert("aaaa_records".to_string(), json!(ipv6));
}
Err(e) => {
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("a_records_error".to_string(), json!(e.to_string()));
}
}
// --- MX records ---
match Self::dig_query(domain, "MX").await {
Ok(mx) => {
dns_data.insert("mx_records".to_string(), json!(mx));
// --- MX records ---
match Self::dig_query(domain, "MX").await {
Ok(mx) => {
dns_data.insert("mx_records".to_string(), json!(mx));
}
Err(e) => {
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("mx_records_error".to_string(), json!(e.to_string()));
}
}
// --- NS records ---
let ns_records = match Self::dig_query(domain, "NS").await {
Ok(ns) => {
dns_data.insert("ns_records".to_string(), json!(ns));
ns
}
Err(e) => {
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
Vec::new()
}
};
// --- NS records ---
let ns_records = match Self::dig_query(domain, "NS").await {
Ok(ns) => {
dns_data.insert("ns_records".to_string(), json!(ns));
ns
}
Err(e) => {
dns_data.insert("ns_records_error".to_string(), json!(e.to_string()));
Vec::new()
}
};
// --- TXT records ---
match Self::dig_query(domain, "TXT").await {
Ok(txt) => {
dns_data.insert("txt_records".to_string(), json!(txt));
// --- TXT records ---
match Self::dig_query(domain, "TXT").await {
Ok(txt) => {
dns_data.insert("txt_records".to_string(), json!(txt));
}
Err(e) => {
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
}
}
Err(e) => {
dns_data.insert("txt_records_error".to_string(), json!(e.to_string()));
// --- CNAME records (for subdomains) ---
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
let mut domains_to_check = vec![domain.to_string()];
for sub in &subdomains {
domains_to_check.push(format!("{sub}.{domain}"));
}
}
// --- CNAME records (for subdomains) ---
let mut cname_data: HashMap<String, Vec<String>> = HashMap::new();
let mut domains_to_check = vec![domain.to_string()];
for sub in &subdomains {
domains_to_check.push(format!("{sub}.{domain}"));
}
for fqdn in &domains_to_check {
match Self::dig_query(fqdn, "CNAME").await {
Ok(cnames) if !cnames.is_empty() => {
// Check for dangling CNAME
for cname in &cnames {
let cname_clean = cname.trim_end_matches('.');
let check_addr = format!("{cname_clean}:443");
let is_dangling = lookup_host(&check_addr).await.is_err();
if is_dangling {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: fqdn.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"CNAME {fqdn} -> {cname} (target does not resolve)"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
for fqdn in &domains_to_check {
match Self::dig_query(fqdn, "CNAME").await {
Ok(cnames) if !cnames.is_empty() => {
// Check for dangling CNAME
for cname in &cnames {
let cname_clean = cname.trim_end_matches('.');
let check_addr = format!("{cname_clean}:443");
let is_dangling = lookup_host(&check_addr).await.is_err();
if is_dangling {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: fqdn.clone(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"CNAME {fqdn} -> {cname} (target does not resolve)"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -207,44 +219,47 @@ impl PentestTool for DnsCheckerTool {
fqdn.clone(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling CNAME records or ensure the target hostname is \
properly configured and resolvable."
.to_string(),
);
findings.push(finding);
warn!(fqdn, cname, "Dangling CNAME detected - potential subdomain takeover");
findings.push(finding);
warn!(
fqdn,
cname, "Dangling CNAME detected - potential subdomain takeover"
);
}
}
cname_data.insert(fqdn.clone(), cnames);
}
cname_data.insert(fqdn.clone(), cnames);
_ => {}
}
_ => {}
}
}
if !cname_data.is_empty() {
dns_data.insert("cname_records".to_string(), json!(cname_data));
}
if !cname_data.is_empty() {
dns_data.insert("cname_records".to_string(), json!(cname_data));
}
// --- CAA records ---
match Self::dig_query(domain, "CAA").await {
Ok(caa) => {
if caa.is_empty() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No CAA records found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
// --- CAA records ---
match Self::dig_query(domain, "CAA").await {
Ok(caa) => {
if caa.is_empty() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No CAA records found".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -257,35 +272,83 @@ impl PentestTool for DnsCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add CAA DNS records to restrict which certificate authorities can issue \
certificates for your domain. Example: '0 issue \"letsencrypt.org\"'."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
dns_data.insert("caa_records".to_string(), json!(caa));
}
Err(e) => {
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
}
dns_data.insert("caa_records".to_string(), json!(caa));
}
Err(e) => {
dns_data.insert("caa_records_error".to_string(), json!(e.to_string()));
// --- DNSSEC check ---
let dnssec_output = tokio::process::Command::new("dig")
.args(["+dnssec", "+short", "DNSKEY", domain])
.output()
.await;
match dnssec_output {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let has_dnssec = !stdout.trim().is_empty();
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
if !has_dnssec {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(
"No DNSKEY records found - DNSSEC not enabled".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("DNSSEC not enabled for {domain}"),
format!(
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
can be spoofed, allowing man-in-the-middle attacks."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-350".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
with your DNS provider and domain registrar."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
}
}
}
// --- DNSSEC check ---
let dnssec_output = tokio::process::Command::new("dig")
.args(["+dnssec", "+short", "DNSKEY", domain])
.output()
.await;
match dnssec_output {
Ok(output) => {
let stdout = String::from_utf8_lossy(&output.stdout);
let has_dnssec = !stdout.trim().is_empty();
dns_data.insert("dnssec_enabled".to_string(), json!(has_dnssec));
if !has_dnssec {
// --- Check NS records for dangling ---
for ns in &ns_records {
let ns_clean = ns.trim_end_matches('.');
let check_addr = format!("{ns_clean}:53");
if lookup_host(&check_addr).await.is_err() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
@@ -293,61 +356,13 @@ impl PentestTool for DnsCheckerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("No DNSKEY records found - DNSSEC not enabled".to_string()),
response_snippet: Some(format!("NS record {ns} does not resolve")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
format!("DNSSEC not enabled for {domain}"),
format!(
"DNSSEC is not enabled for {domain}. Without DNSSEC, DNS responses \
can be spoofed, allowing man-in-the-middle attacks."
),
Severity::Medium,
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-350".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Enable DNSSEC for your domain by configuring DNSKEY and DS records \
with your DNS provider and domain registrar."
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
dns_data.insert("dnssec_check_error".to_string(), json!("dig not available"));
}
}
// --- Check NS records for dangling ---
for ns in &ns_records {
let ns_clean = ns.trim_end_matches('.');
let check_addr = format!("{ns_clean}:53");
if lookup_host(&check_addr).await.is_err() {
let evidence = DastEvidence {
request_method: "DNS".to_string(),
request_url: domain.to_string(),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"NS record {ns} does not resolve"
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::DnsMisconfiguration,
@@ -360,30 +375,33 @@ impl PentestTool for DnsCheckerTool {
domain.to_string(),
"DNS".to_string(),
);
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling NS records or ensure the nameserver hostname is properly \
finding.cwe = Some("CWE-923".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Remove dangling NS records or ensure the nameserver hostname is properly \
configured. Dangling NS records can lead to full domain takeover."
.to_string(),
);
findings.push(finding);
warn!(domain, ns, "Dangling NS record detected - potential domain takeover");
.to_string(),
);
findings.push(finding);
warn!(
domain,
ns, "Dangling NS record detected - potential domain takeover"
);
}
}
}
let count = findings.len();
info!(domain, findings = count, "DNS check complete");
let count = findings.len();
info!(domain, findings = count, "DNS check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} DNS configuration issues for {domain}.")
} else {
format!("No DNS configuration issues found for {domain}.")
},
findings,
data: json!(dns_data),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} DNS configuration issues for {domain}.")
} else {
format!("No DNS configuration issues found for {domain}.")
},
findings,
data: json!(dns_data),
})
})
}
}

View File

@@ -33,8 +33,15 @@ pub struct ToolRegistry {
tools: HashMap<String, Box<dyn PentestTool>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
/// Create a new registry with all built-in tools pre-registered.
#[allow(clippy::expect_used)]
pub fn new() -> Self {
let http = reqwest::Client::builder()
.danger_accept_invalid_certs(true)
@@ -67,13 +74,10 @@ impl ToolRegistry {
);
// New infrastructure / analysis tools
register(&mut tools, Box::<dns_checker::DnsCheckerTool>::default());
register(
&mut tools,
Box::new(dns_checker::DnsCheckerTool::new()),
);
register(
&mut tools,
Box::new(dmarc_checker::DmarcCheckerTool::new()),
Box::<dmarc_checker::DmarcCheckerTool>::default(),
);
register(
&mut tools,
@@ -109,10 +113,7 @@ impl ToolRegistry {
&mut tools,
Box::new(openapi_parser::OpenApiParserTool::new(http.clone())),
);
register(
&mut tools,
Box::new(recon::ReconTool::new(http)),
);
register(&mut tools, Box::new(recon::ReconTool::new(http)));
Self { tools }
}

View File

@@ -92,7 +92,10 @@ impl OpenApiParserTool {
// If content type suggests YAML, we can't easily parse without a YAML dep,
// so just report the URL as found
if content_type.contains("yaml") || body.starts_with("openapi:") || body.starts_with("swagger:") {
if content_type.contains("yaml")
|| body.starts_with("openapi:")
|| body.starts_with("swagger:")
{
// Return a minimal JSON indicating YAML was found
return Some((
url.to_string(),
@@ -107,7 +110,7 @@ impl OpenApiParserTool {
}
/// Parse an OpenAPI 3.x or Swagger 2.x spec into structured endpoints.
fn parse_spec(spec: &serde_json::Value, base_url: &str) -> Vec<ParsedEndpoint> {
fn parse_spec(spec: &serde_json::Value, _base_url: &str) -> Vec<ParsedEndpoint> {
let mut endpoints = Vec::new();
// Determine base path
@@ -258,6 +261,166 @@ impl OpenApiParserTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn common_spec_paths_not_empty() {
let paths = OpenApiParserTool::common_spec_paths();
assert!(paths.len() >= 5);
assert!(paths.contains(&"/openapi.json"));
assert!(paths.contains(&"/swagger.json"));
}
#[test]
fn parse_spec_openapi3_basic() {
let spec = json!({
"openapi": "3.0.0",
"info": { "title": "Test API", "version": "1.0" },
"paths": {
"/users": {
"get": {
"operationId": "listUsers",
"summary": "List all users",
"parameters": [
{
"name": "limit",
"in": "query",
"required": false,
"schema": { "type": "integer" }
}
],
"responses": {
"200": { "description": "OK" },
"401": { "description": "Unauthorized" }
},
"tags": ["users"]
},
"post": {
"operationId": "createUser",
"requestBody": {
"content": {
"application/json": {}
}
},
"responses": { "201": {} },
"security": [{ "bearerAuth": [] }]
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
assert_eq!(endpoints.len(), 2);
let get_ep = endpoints.iter().find(|e| e.method == "GET").unwrap();
assert_eq!(get_ep.path, "/users");
assert_eq!(get_ep.operation_id.as_deref(), Some("listUsers"));
assert_eq!(get_ep.summary.as_deref(), Some("List all users"));
assert_eq!(get_ep.parameters.len(), 1);
assert_eq!(get_ep.parameters[0].name, "limit");
assert_eq!(get_ep.parameters[0].location, "query");
assert!(!get_ep.parameters[0].required);
assert_eq!(get_ep.parameters[0].param_type.as_deref(), Some("integer"));
assert_eq!(get_ep.response_codes.len(), 2);
assert_eq!(get_ep.tags, vec!["users"]);
let post_ep = endpoints.iter().find(|e| e.method == "POST").unwrap();
assert_eq!(
post_ep.request_body_content_type.as_deref(),
Some("application/json")
);
assert_eq!(post_ep.security, vec!["bearerAuth"]);
}
#[test]
fn parse_spec_swagger2_with_base_path() {
let spec = json!({
"swagger": "2.0",
"basePath": "/api/v1",
"paths": {
"/items": {
"get": {
"parameters": [
{ "name": "id", "in": "path", "required": true, "type": "string" }
],
"responses": { "200": {} }
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://api.example.com");
assert_eq!(endpoints.len(), 1);
assert_eq!(endpoints[0].path, "/api/v1/items");
assert_eq!(
endpoints[0].parameters[0].param_type.as_deref(),
Some("string")
);
}
#[test]
fn parse_spec_empty_paths() {
let spec = json!({ "openapi": "3.0.0", "paths": {} });
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert!(endpoints.is_empty());
}
#[test]
fn parse_spec_no_paths_key() {
let spec = json!({ "openapi": "3.0.0" });
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert!(endpoints.is_empty());
}
#[test]
fn parse_spec_servers_base_url() {
let spec = json!({
"openapi": "3.0.0",
"servers": [{ "url": "/api/v2" }],
"paths": {
"/health": {
"get": { "responses": { "200": {} } }
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert_eq!(endpoints[0].path, "/api/v2/health");
}
#[test]
fn parse_spec_path_level_parameters_merged() {
let spec = json!({
"openapi": "3.0.0",
"paths": {
"/items/{id}": {
"parameters": [
{ "name": "id", "in": "path", "required": true, "schema": { "type": "string" } }
],
"get": {
"parameters": [
{ "name": "fields", "in": "query", "schema": { "type": "string" } }
],
"responses": { "200": {} }
}
}
}
});
let endpoints = OpenApiParserTool::parse_spec(&spec, "https://example.com");
assert_eq!(endpoints[0].parameters.len(), 2);
assert!(endpoints[0]
.parameters
.iter()
.any(|p| p.name == "id" && p.location == "path"));
assert!(endpoints[0]
.parameters
.iter()
.any(|p| p.name == "fields" && p.location == "query"));
}
}
impl PentestTool for OpenApiParserTool {
fn name(&self) -> &str {
"openapi_parser"
@@ -289,134 +452,138 @@ impl PentestTool for OpenApiParserTool {
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
_context: &'a PentestToolContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let base_url = input
.get("base_url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'base_url' parameter".to_string()))?;
let base_url = input
.get("base_url")
.and_then(|v| v.as_str())
.ok_or_else(|| {
CoreError::Dast("Missing required 'base_url' parameter".to_string())
})?;
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
let explicit_spec_url = input.get("spec_url").and_then(|v| v.as_str());
let base_url_trimmed = base_url.trim_end_matches('/');
let base_url_trimmed = base_url.trim_end_matches('/');
// If an explicit spec URL is provided, try it first
let mut spec_result: Option<(String, serde_json::Value)> = None;
if let Some(spec_url) = explicit_spec_url {
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
}
// If an explicit spec URL is provided, try it first
let mut spec_result: Option<(String, serde_json::Value)> = None;
if let Some(spec_url) = explicit_spec_url {
spec_result = Self::try_fetch_spec(&self.http, spec_url).await;
}
// If no explicit URL or it failed, try common paths
if spec_result.is_none() {
for path in Self::common_spec_paths() {
let url = format!("{base_url_trimmed}{path}");
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
spec_result = Some(result);
break;
// If no explicit URL or it failed, try common paths
if spec_result.is_none() {
for path in Self::common_spec_paths() {
let url = format!("{base_url_trimmed}{path}");
if let Some(result) = Self::try_fetch_spec(&self.http, &url).await {
spec_result = Some(result);
break;
}
}
}
}
match spec_result {
Some((spec_url, spec)) => {
let spec_version = spec
.get("openapi")
.or_else(|| spec.get("swagger"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
match spec_result {
Some((spec_url, spec)) => {
let spec_version = spec
.get("openapi")
.or_else(|| spec.get("swagger"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let api_title = spec
.get("info")
.and_then(|i| i.get("title"))
.and_then(|t| t.as_str())
.unwrap_or("Unknown API");
let api_title = spec
.get("info")
.and_then(|i| i.get("title"))
.and_then(|t| t.as_str())
.unwrap_or("Unknown API");
let api_version = spec
.get("info")
.and_then(|i| i.get("version"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let api_version = spec
.get("info")
.and_then(|i| i.get("version"))
.and_then(|v| v.as_str())
.unwrap_or("unknown");
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
let endpoints = Self::parse_spec(&spec, base_url_trimmed);
let endpoint_data: Vec<serde_json::Value> = endpoints
.iter()
.map(|ep| {
let params: Vec<serde_json::Value> = ep
.parameters
.iter()
.map(|p| {
json!({
"name": p.name,
"in": p.location,
"required": p.required,
"type": p.param_type,
"description": p.description,
let endpoint_data: Vec<serde_json::Value> = endpoints
.iter()
.map(|ep| {
let params: Vec<serde_json::Value> = ep
.parameters
.iter()
.map(|p| {
json!({
"name": p.name,
"in": p.location,
"required": p.required,
"type": p.param_type,
"description": p.description,
})
})
.collect();
json!({
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"parameters": params,
"request_body_content_type": ep.request_body_content_type,
"response_codes": ep.response_codes,
"security": ep.security,
"tags": ep.tags,
})
.collect();
json!({
"path": ep.path,
"method": ep.method,
"operation_id": ep.operation_id,
"summary": ep.summary,
"parameters": params,
"request_body_content_type": ep.request_body_content_type,
"response_codes": ep.response_codes,
"security": ep.security,
"tags": ep.tags,
})
})
.collect();
.collect();
let endpoint_count = endpoints.len();
info!(
spec_url = %spec_url,
spec_version,
api_title,
endpoints = endpoint_count,
"OpenAPI spec parsed"
);
let endpoint_count = endpoints.len();
info!(
spec_url = %spec_url,
spec_version,
api_title,
endpoints = endpoint_count,
"OpenAPI spec parsed"
);
Ok(PentestToolResult {
summary: format!(
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
Ok(PentestToolResult {
summary: format!(
"Found OpenAPI spec ({spec_version}) at {spec_url}. \
API: {api_title} v{api_version}. \
Parsed {endpoint_count} endpoints."
),
findings: Vec::new(), // This tool produces data, not findings
data: json!({
"spec_url": spec_url,
"spec_version": spec_version,
"api_title": api_title,
"api_version": api_version,
"endpoint_count": endpoint_count,
"endpoints": endpoint_data,
"security_schemes": spec.get("components")
.and_then(|c| c.get("securitySchemes"))
.or_else(|| spec.get("securityDefinitions")),
}),
})
}
None => {
info!(base_url, "No OpenAPI spec found");
),
findings: Vec::new(), // This tool produces data, not findings
data: json!({
"spec_url": spec_url,
"spec_version": spec_version,
"api_title": api_title,
"api_version": api_version,
"endpoint_count": endpoint_count,
"endpoints": endpoint_data,
"security_schemes": spec.get("components")
.and_then(|c| c.get("securitySchemes"))
.or_else(|| spec.get("securityDefinitions")),
}),
})
}
None => {
info!(base_url, "No OpenAPI spec found");
Ok(PentestToolResult {
summary: format!(
"No OpenAPI/Swagger specification found for {base_url}. \
Ok(PentestToolResult {
summary: format!(
"No OpenAPI/Swagger specification found for {base_url}. \
Tried {} common paths.",
Self::common_spec_paths().len()
),
findings: Vec::new(),
data: json!({
"spec_found": false,
"paths_tried": Self::common_spec_paths(),
}),
})
Self::common_spec_paths().len()
),
findings: Vec::new(),
data: json!({
"spec_found": false,
"paths_tried": Self::common_spec_paths(),
}),
})
}
}
}
})
}
}

View File

@@ -62,224 +62,229 @@ impl PentestTool for RateLimitTesterTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let method = input
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let method = input
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET");
let request_count = input
.get("request_count")
.and_then(|v| v.as_u64())
.unwrap_or(50)
.min(200) as usize;
let request_count = input
.get("request_count")
.and_then(|v| v.as_u64())
.unwrap_or(50)
.min(200) as usize;
let body = input.get("body").and_then(|v| v.as_str());
let body = input.get("body").and_then(|v| v.as_str());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
// Respect the context rate limit if set
let max_requests = if context.rate_limit > 0 {
request_count.min(context.rate_limit as usize * 5)
} else {
request_count
};
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
let mut got_429 = false;
let mut rate_limit_at_request: Option<usize> = None;
for i in 0..max_requests {
let start = Instant::now();
let request = match method {
"POST" => {
let mut req = self.http.post(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PUT" => {
let mut req = self.http.put(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PATCH" => {
let mut req = self.http.patch(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"DELETE" => self.http.delete(url),
_ => self.http.get(url),
// Respect the context rate limit if set
let max_requests = if context.rate_limit > 0 {
request_count.min(context.rate_limit as usize * 5)
} else {
request_count
};
match request.send().await {
Ok(resp) => {
let elapsed = start.elapsed().as_millis();
let status = resp.status().as_u16();
status_codes.push(status);
response_times.push(elapsed);
let mut status_codes: Vec<u16> = Vec::with_capacity(max_requests);
let mut response_times: Vec<u128> = Vec::with_capacity(max_requests);
let mut got_429 = false;
let mut rate_limit_at_request: Option<usize> = None;
if status == 429 && !got_429 {
got_429 = true;
rate_limit_at_request = Some(i + 1);
info!(url, request_num = i + 1, "Rate limit triggered (429)");
for i in 0..max_requests {
let start = Instant::now();
let request = match method {
"POST" => {
let mut req = self.http.post(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PUT" => {
let mut req = self.http.put(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"PATCH" => {
let mut req = self.http.patch(url);
if let Some(b) = body {
req = req.body(b.to_string());
}
req
}
"DELETE" => self.http.delete(url),
_ => self.http.get(url),
};
// Check for rate limit headers even on 200
if !got_429 {
let headers = resp.headers();
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|| headers.contains_key("x-ratelimit-remaining")
|| headers.contains_key("ratelimit-limit")
|| headers.contains_key("ratelimit-remaining")
|| headers.contains_key("retry-after");
match request.send().await {
Ok(resp) => {
let elapsed = start.elapsed().as_millis();
let status = resp.status().as_u16();
status_codes.push(status);
response_times.push(elapsed);
if has_rate_headers && rate_limit_at_request.is_none() {
// Server has rate limit headers but hasn't blocked yet
if status == 429 && !got_429 {
got_429 = true;
rate_limit_at_request = Some(i + 1);
info!(url, request_num = i + 1, "Rate limit triggered (429)");
}
// Check for rate limit headers even on 200
if !got_429 {
let headers = resp.headers();
let has_rate_headers = headers.contains_key("x-ratelimit-limit")
|| headers.contains_key("x-ratelimit-remaining")
|| headers.contains_key("ratelimit-limit")
|| headers.contains_key("ratelimit-remaining")
|| headers.contains_key("retry-after");
if has_rate_headers && rate_limit_at_request.is_none() {
// Server has rate limit headers but hasn't blocked yet
}
}
}
}
Err(e) => {
let elapsed = start.elapsed().as_millis();
status_codes.push(0);
response_times.push(elapsed);
Err(_e) => {
let elapsed = start.elapsed().as_millis();
status_codes.push(0);
response_times.push(elapsed);
}
}
}
}
let mut findings = Vec::new();
let total_sent = status_codes.len();
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
let count_success = status_codes.iter().filter(|&&s| (200..300).contains(&s)).count();
let mut findings = Vec::new();
let total_sent = status_codes.len();
let count_429 = status_codes.iter().filter(|&&s| s == 429).count();
let count_success = status_codes
.iter()
.filter(|&&s| (200..300).contains(&s))
.count();
// Calculate response time statistics
let avg_time = if !response_times.is_empty() {
response_times.iter().sum::<u128>() / response_times.len() as u128
} else {
0
};
let first_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[..half].iter().sum::<u128>() / half as u128
} else {
avg_time
};
let second_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
} else {
avg_time
};
// Significant time degradation suggests possible (weak) rate limiting
let time_degradation = if first_half_avg > 0 {
(second_half_avg as f64 / first_half_avg as f64) - 1.0
} else {
0.0
};
let rate_data = json!({
"total_requests_sent": total_sent,
"status_429_count": count_429,
"success_count": count_success,
"rate_limit_at_request": rate_limit_at_request,
"avg_response_time_ms": avg_time,
"first_half_avg_ms": first_half_avg,
"second_half_avg_ms": second_half_avg,
"time_degradation_pct": (time_degradation * 100.0).round(),
});
if !got_429 && count_success == total_sent {
// No rate limiting detected at all
let evidence = DastEvidence {
request_method: method.to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: body.map(String::from),
response_status: 200,
response_headers: None,
response_snippet: Some(format!(
"Sent {total_sent} rapid requests. All returned success (2xx). \
No 429 responses received. Avg response time: {avg_time}ms."
)),
screenshot_path: None,
payload: None,
response_time_ms: Some(avg_time as u64),
// Calculate response time statistics
let avg_time = if !response_times.is_empty() {
response_times.iter().sum::<u128>() / response_times.len() as u128
} else {
0
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::RateLimitAbsent,
format!("No rate limiting on {} {}", method, url),
format!(
"The endpoint {} {} does not enforce rate limiting. \
let first_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[..half].iter().sum::<u128>() / half as u128
} else {
avg_time
};
let second_half_avg = if response_times.len() >= 4 {
let half = response_times.len() / 2;
response_times[half..].iter().sum::<u128>() / (response_times.len() - half) as u128
} else {
avg_time
};
// Significant time degradation suggests possible (weak) rate limiting
let time_degradation = if first_half_avg > 0 {
(second_half_avg as f64 / first_half_avg as f64) - 1.0
} else {
0.0
};
let rate_data = json!({
"total_requests_sent": total_sent,
"status_429_count": count_429,
"success_count": count_success,
"rate_limit_at_request": rate_limit_at_request,
"avg_response_time_ms": avg_time,
"first_half_avg_ms": first_half_avg,
"second_half_avg_ms": second_half_avg,
"time_degradation_pct": (time_degradation * 100.0).round(),
});
if !got_429 && count_success == total_sent {
// No rate limiting detected at all
let evidence = DastEvidence {
request_method: method.to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: body.map(String::from),
response_status: 200,
response_headers: None,
response_snippet: Some(format!(
"Sent {total_sent} rapid requests. All returned success (2xx). \
No 429 responses received. Avg response time: {avg_time}ms."
)),
screenshot_path: None,
payload: None,
response_time_ms: Some(avg_time as u64),
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::RateLimitAbsent,
format!("No rate limiting on {} {}", method, url),
format!(
"The endpoint {} {} does not enforce rate limiting. \
{total_sent} rapid requests were all accepted with no 429 responses \
or noticeable degradation. This makes the endpoint vulnerable to \
brute force attacks and abuse.",
method, url
),
Severity::Medium,
url.to_string(),
method.to_string(),
);
finding.cwe = Some("CWE-770".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
method, url
),
Severity::Medium,
url.to_string(),
method.to_string(),
);
finding.cwe = Some("CWE-770".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Implement rate limiting on this endpoint. Use token bucket or sliding window \
algorithms. Return 429 Too Many Requests with a Retry-After header when the \
limit is exceeded."
.to_string(),
);
findings.push(finding);
warn!(url, method, total_sent, "No rate limiting detected");
} else if got_429 {
info!(
url,
method,
rate_limit_at = ?rate_limit_at_request,
"Rate limiting is enforced"
);
}
.to_string(),
);
findings.push(finding);
warn!(url, method, total_sent, "No rate limiting detected");
} else if got_429 {
info!(
url,
method,
rate_limit_at = ?rate_limit_at_request,
"Rate limiting is enforced"
);
}
let count = findings.len();
let count = findings.len();
Ok(PentestToolResult {
summary: if got_429 {
format!(
"Rate limiting is enforced on {method} {url}. \
Ok(PentestToolResult {
summary: if got_429 {
format!(
"Rate limiting is enforced on {method} {url}. \
429 response received after {} requests.",
rate_limit_at_request.unwrap_or(0)
)
} else if count > 0 {
format!(
"No rate limiting detected on {method} {url} after {total_sent} requests."
)
} else {
format!("Rate limit testing complete for {method} {url}.")
},
findings,
data: rate_data,
})
rate_limit_at_request.unwrap_or(0)
)
} else if count > 0 {
format!(
"No rate limiting detected on {method} {url} after {total_sent} requests."
)
} else {
format!("Rate limit testing complete for {method} {url}.")
},
findings,
data: rate_data,
})
})
}
}

View File

@@ -54,72 +54,75 @@ impl PentestTool for ReconTool {
fn execute<'a>(
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
_context: &'a PentestToolContext,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let additional_paths: Vec<String> = input
.get("additional_paths")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let additional_paths: Vec<String> = input
.get("additional_paths")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let result = self.agent.scan(url).await?;
let result = self.agent.scan(url).await?;
// Scan additional paths for more technology signals
let mut extra_technologies: Vec<String> = Vec::new();
let mut extra_headers: HashMap<String, String> = HashMap::new();
// Scan additional paths for more technology signals
let mut extra_technologies: Vec<String> = Vec::new();
let mut extra_headers: HashMap<String, String> = HashMap::new();
let base_url = url.trim_end_matches('/');
for path in &additional_paths {
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
if let Ok(resp) = self.http.get(&probe_url).send().await {
for (key, value) in resp.headers() {
let k = key.to_string().to_lowercase();
let v = value.to_str().unwrap_or("").to_string();
let base_url = url.trim_end_matches('/');
for path in &additional_paths {
let probe_url = format!("{base_url}/{}", path.trim_start_matches('/'));
if let Ok(resp) = self.http.get(&probe_url).send().await {
for (key, value) in resp.headers() {
let k = key.to_string().to_lowercase();
let v = value.to_str().unwrap_or("").to_string();
// Look for technology indicators
if k == "x-powered-by" || k == "server" || k == "x-generator" {
if !result.technologies.contains(&v) && !extra_technologies.contains(&v) {
// Look for technology indicators
if (k == "x-powered-by" || k == "server" || k == "x-generator")
&& !result.technologies.contains(&v)
&& !extra_technologies.contains(&v)
{
extra_technologies.push(v.clone());
}
extra_headers.insert(format!("{probe_url} -> {k}"), v);
}
extra_headers.insert(format!("{probe_url} -> {k}"), v);
}
}
}
let mut all_technologies = result.technologies.clone();
all_technologies.extend(extra_technologies);
all_technologies.dedup();
let mut all_technologies = result.technologies.clone();
all_technologies.extend(extra_technologies);
all_technologies.dedup();
let tech_count = all_technologies.len();
info!(url, technologies = tech_count, "Recon complete");
let tech_count = all_technologies.len();
info!(url, technologies = tech_count, "Recon complete");
Ok(PentestToolResult {
summary: format!(
"Recon complete for {url}. Detected {} technologies. Server: {}.",
tech_count,
result.server.as_deref().unwrap_or("unknown")
),
findings: Vec::new(), // Recon produces data, not findings
data: json!({
"base_url": url,
"server": result.server,
"technologies": all_technologies,
"interesting_headers": result.interesting_headers,
"extra_headers": extra_headers,
"open_ports": result.open_ports,
}),
})
Ok(PentestToolResult {
summary: format!(
"Recon complete for {url}. Detected {} technologies. Server: {}.",
tech_count,
result.server.as_deref().unwrap_or("unknown")
),
findings: Vec::new(), // Recon produces data, not findings
data: json!({
"base_url": url,
"server": result.server,
"technologies": all_technologies,
"interesting_headers": result.interesting_headers,
"extra_headers": extra_headers,
"open_ports": result.open_ports,
}),
})
})
}
}

View File

@@ -111,57 +111,107 @@ impl PentestTool for SecurityHeadersTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let response = self
.http
.get(url)
.send()
.await
.map_err(|e| CoreError::Dast(format!("Failed to fetch {url}: {e}")))?;
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string().to_lowercase(), v.to_str().unwrap_or("").to_string()))
.collect();
let status = response.status().as_u16();
let response_headers: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| {
(
k.to_string().to_lowercase(),
v.to_str().unwrap_or("").to_string(),
)
})
.collect();
let mut findings = Vec::new();
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
let mut findings = Vec::new();
let mut header_results: HashMap<String, serde_json::Value> = HashMap::new();
for expected in Self::expected_headers() {
let header_value = response_headers.get(expected.name);
for expected in Self::expected_headers() {
let header_value = response_headers.get(expected.name);
match header_value {
Some(value) => {
let mut is_valid = true;
if let Some(ref valid) = expected.valid_values {
let lower = value.to_lowercase();
is_valid = valid.iter().any(|v| lower.contains(v));
match header_value {
Some(value) => {
let mut is_valid = true;
if let Some(ref valid) = expected.valid_values {
let lower = value.to_lowercase();
is_valid = valid.iter().any(|v| lower.contains(v));
}
header_results.insert(
expected.name.to_string(),
json!({
"present": true,
"value": value,
"valid": is_valid,
}),
);
if !is_valid {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{}: {}", expected.name, value)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Invalid {} header value", expected.name),
format!(
"The {} header is present but has an invalid or weak value: '{}'. \
{}",
expected.name, value, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
None => {
header_results.insert(
expected.name.to_string(),
json!({
"present": false,
"value": null,
"valid": false,
}),
);
header_results.insert(
expected.name.to_string(),
json!({
"present": true,
"value": value,
"valid": is_valid,
}),
);
if !is_valid {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
@@ -169,7 +219,7 @@ impl PentestTool for SecurityHeadersTool {
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{}: {}", expected.name, value)),
response_snippet: Some(format!("{} header is missing", expected.name)),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -179,11 +229,10 @@ impl PentestTool for SecurityHeadersTool {
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Invalid {} header value", expected.name),
format!("Missing {} header", expected.name),
format!(
"The {} header is present but has an invalid or weak value: '{}'. \
{}",
expected.name, value, expected.description
"The {} header is not present in the response. {}",
expected.name, expected.description
),
expected.severity.clone(),
url.to_string(),
@@ -195,14 +244,20 @@ impl PentestTool for SecurityHeadersTool {
findings.push(finding);
}
}
None => {
}
// Also check for information disclosure headers
let disclosure_headers = [
"server",
"x-powered-by",
"x-aspnet-version",
"x-aspnetmvc-version",
];
for h in &disclosure_headers {
if let Some(value) = response_headers.get(*h) {
header_results.insert(
expected.name.to_string(),
json!({
"present": false,
"value": null,
"valid": false,
}),
format!("{h}_disclosure"),
json!({ "present": true, "value": value }),
);
let evidence = DastEvidence {
@@ -212,56 +267,13 @@ impl PentestTool for SecurityHeadersTool {
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{} header is missing", expected.name)),
response_snippet: Some(format!("{h}: {value}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
format!("Missing {} header", expected.name),
format!(
"The {} header is not present in the response. {}",
expected.name, expected.description
),
expected.severity.clone(),
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some(expected.cwe.to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(expected.remediation.to_string());
findings.push(finding);
}
}
}
// Also check for information disclosure headers
let disclosure_headers = ["server", "x-powered-by", "x-aspnet-version", "x-aspnetmvc-version"];
for h in &disclosure_headers {
if let Some(value) = response_headers.get(*h) {
header_results.insert(
format!("{h}_disclosure"),
json!({ "present": true, "value": value }),
);
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: url.to_string(),
request_headers: None,
request_body: None,
response_status: status,
response_headers: Some(response_headers.clone()),
response_snippet: Some(format!("{h}: {value}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::SecurityHeaderMissing,
@@ -274,27 +286,27 @@ impl PentestTool for SecurityHeadersTool {
url.to_string(),
"GET".to_string(),
);
finding.cwe = Some("CWE-200".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove or suppress the {h} header in your server configuration."
));
findings.push(finding);
finding.cwe = Some("CWE-200".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(format!(
"Remove or suppress the {h} header in your server configuration."
));
findings.push(finding);
}
}
}
let count = findings.len();
info!(url, findings = count, "Security headers check complete");
let count = findings.len();
info!(url, findings = count, "Security headers check complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} security header issues for {url}.")
} else {
format!("All checked security headers are present and valid for {url}.")
},
findings,
data: json!(header_results),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} security header issues for {url}.")
} else {
format!("All checked security headers are present and valid for {url}.")
},
findings,
data: json!(header_results),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,29 +9,51 @@ use crate::agents::injection::SqlInjectionAgent;
/// PentestTool wrapper around the existing SqlInjectionAgent.
pub struct SqlInjectionTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: SqlInjectionAgent,
}
impl SqlInjectionTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SqlInjectionAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
let name = p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let location = p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string();
let param_type = p.get("param_type").and_then(|v| v.as_str()).map(String::from);
let example_value = p.get("example_value").and_then(|v| v.as_str()).map(String::from);
let name = p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let location = p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string();
let param_type = p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from);
let example_value = p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from);
parameters.push(EndpointParameter {
name,
location,
@@ -42,8 +66,14 @@ impl SqlInjectionTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -51,6 +81,62 @@ impl SqlInjectionTool {
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_endpoints_basic() {
let input = json!({
"endpoints": [{
"url": "https://example.com/api/users",
"method": "POST",
"parameters": [
{ "name": "id", "location": "body", "param_type": "integer" }
]
}]
});
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints.len(), 1);
assert_eq!(endpoints[0].url, "https://example.com/api/users");
assert_eq!(endpoints[0].method, "POST");
assert_eq!(endpoints[0].parameters[0].name, "id");
assert_eq!(endpoints[0].parameters[0].location, "body");
assert_eq!(
endpoints[0].parameters[0].param_type.as_deref(),
Some("integer")
);
}
#[test]
fn parse_endpoints_empty_input() {
assert!(SqlInjectionTool::parse_endpoints(&json!({})).is_empty());
assert!(SqlInjectionTool::parse_endpoints(&json!({ "endpoints": [] })).is_empty());
}
#[test]
fn parse_endpoints_multiple() {
let input = json!({
"endpoints": [
{ "url": "https://a.com/1", "method": "GET", "parameters": [] },
{ "url": "https://b.com/2", "method": "DELETE", "parameters": [] }
]
});
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints.len(), 2);
assert_eq!(endpoints[0].url, "https://a.com/1");
assert_eq!(endpoints[1].method, "DELETE");
}
#[test]
fn parse_endpoints_default_method() {
let input = json!({ "endpoints": [{ "url": "https://x.com", "parameters": [] }] });
let endpoints = SqlInjectionTool::parse_endpoints(&input);
assert_eq!(endpoints[0].method, "GET");
}
}
impl PentestTool for SqlInjectionTool {
fn name(&self) -> &str {
"sql_injection_scanner"
@@ -104,35 +190,37 @@ impl PentestTool for SqlInjectionTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SQL injection vulnerabilities.")
} else {
"No SQL injection vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SQL injection vulnerabilities.")
} else {
"No SQL injection vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::ssrf::SsrfAgent;
/// PentestTool wrapper around the existing SsrfAgent.
pub struct SsrfTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: SsrfAgent,
}
impl SsrfTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = SsrfAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl SsrfTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -100,35 +130,37 @@ impl PentestTool for SsrfTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SSRF vulnerabilities.")
} else {
"No SSRF vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} SSRF vulnerabilities.")
} else {
"No SSRF vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -39,10 +39,7 @@ impl TlsAnalyzerTool {
/// TLS client hello. We test SSLv3 / old protocol support by attempting
/// connection with the system's native-tls which typically negotiates the
/// best available, then inspect what was negotiated.
async fn check_tls(
host: &str,
port: u16,
) -> Result<TlsInfo, CoreError> {
async fn check_tls(host: &str, port: u16) -> Result<TlsInfo, CoreError> {
let addr = format!("{host}:{port}");
let tcp = TcpStream::connect(&addr)
@@ -62,11 +59,13 @@ impl TlsAnalyzerTool {
.await
.map_err(|e| CoreError::Dast(format!("TLS handshake with {addr} failed: {e}")))?;
let peer_cert = tls_stream.get_ref().peer_certificate()
let peer_cert = tls_stream
.get_ref()
.peer_certificate()
.map_err(|e| CoreError::Dast(format!("Failed to get peer certificate: {e}")))?;
let mut tls_info = TlsInfo {
protocol_version: String::new(),
_protocol_version: String::new(),
cert_subject: String::new(),
cert_issuer: String::new(),
cert_not_before: String::new(),
@@ -78,7 +77,8 @@ impl TlsAnalyzerTool {
};
if let Some(cert) = peer_cert {
let der = cert.to_der()
let der = cert
.to_der()
.map_err(|e| CoreError::Dast(format!("Certificate DER encoding failed: {e}")))?;
// native_tls doesn't give rich access, so we parse what we can
@@ -93,7 +93,7 @@ impl TlsAnalyzerTool {
/// Best-effort parse of DER-encoded X.509 certificate for dates and subject.
/// This is a simplified parser; in production you would use a proper x509 crate.
fn parse_cert_der(der: &[u8], mut info: TlsInfo) -> TlsInfo {
fn parse_cert_der(_der: &[u8], mut info: TlsInfo) -> TlsInfo {
// We rely on the native_tls debug output stored in cert_subject
// and just mark fields as "see certificate details"
if info.cert_subject.contains("self signed") || info.cert_subject.contains("Self-Signed") {
@@ -104,7 +104,7 @@ impl TlsAnalyzerTool {
}
struct TlsInfo {
protocol_version: String,
_protocol_version: String,
cert_subject: String,
cert_issuer: String,
cert_not_before: String,
@@ -152,111 +152,194 @@ impl PentestTool for TlsAnalyzerTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let url = input
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| CoreError::Dast("Missing required 'url' parameter".to_string()))?;
let host = Self::extract_host(url)
.unwrap_or_else(|| url.to_string());
let host = Self::extract_host(url).unwrap_or_else(|| url.to_string());
let port = input
.get("port")
.and_then(|v| v.as_u64())
.map(|p| p as u16)
.unwrap_or_else(|| Self::extract_port(url));
let port = input
.get("port")
.and_then(|v| v.as_u64())
.map(|p| p as u16)
.unwrap_or_else(|| Self::extract_port(url));
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let target_id = context
.target
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "unknown".to_string());
let mut findings = Vec::new();
let mut tls_data = json!({});
let mut findings = Vec::new();
let mut tls_data = json!({});
// First check: does the server even support HTTPS?
let https_url = if url.starts_with("https://") {
url.to_string()
} else if url.starts_with("http://") {
url.replace("http://", "https://")
} else {
format!("https://{url}")
};
// First check: does the server even support HTTPS?
let https_url = if url.starts_with("https://") {
url.to_string()
} else if url.starts_with("http://") {
url.replace("http://", "https://")
} else {
format!("https://{url}")
};
// Check if HTTP redirects to HTTPS
let http_url = if url.starts_with("http://") {
url.to_string()
} else if url.starts_with("https://") {
url.replace("https://", "http://")
} else {
format!("http://{url}")
};
// Check if HTTP redirects to HTTPS
let http_url = if url.starts_with("http://") {
url.to_string()
} else if url.starts_with("https://") {
url.replace("https://", "http://")
} else {
format!("http://{url}")
};
match self.http.get(&http_url).send().await {
Ok(resp) => {
let final_url = resp.url().to_string();
let redirects_to_https = final_url.starts_with("https://");
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
match self.http.get(&http_url).send().await {
Ok(resp) => {
let final_url = resp.url().to_string();
let redirects_to_https = final_url.starts_with("https://");
tls_data["http_redirects_to_https"] = json!(redirects_to_https);
if !redirects_to_https {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: http_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(format!("Final URL: {final_url}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if !redirects_to_https {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: http_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(format!("Final URL: {final_url}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("HTTP does not redirect to HTTPS for {host}"),
format!(
"HTTP requests to {host} are not redirected to HTTPS. \
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("HTTP does not redirect to HTTPS for {host}"),
format!(
"HTTP requests to {host} are not redirected to HTTPS. \
Users accessing the site via HTTP will have their traffic \
transmitted in cleartext."
),
Severity::Medium,
http_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Configure the web server to redirect all HTTP requests to HTTPS \
),
Severity::Medium,
http_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Configure the web server to redirect all HTTP requests to HTTPS \
using a 301 redirect."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
Err(_) => {
tls_data["http_check_error"] = json!("Could not connect via HTTP");
}
}
Err(_) => {
tls_data["http_check_error"] = json!("Could not connect via HTTP");
}
}
// Perform TLS analysis
match Self::check_tls(&host, port).await {
Ok(tls_info) => {
tls_data["host"] = json!(host);
tls_data["port"] = json!(port);
tls_data["cert_subject"] = json!(tls_info.cert_subject);
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
tls_data["san_names"] = json!(tls_info.san_names);
// Perform TLS analysis
match Self::check_tls(&host, port).await {
Ok(tls_info) => {
tls_data["host"] = json!(host);
tls_data["port"] = json!(port);
tls_data["cert_subject"] = json!(tls_info.cert_subject);
tls_data["cert_issuer"] = json!(tls_info.cert_issuer);
tls_data["cert_not_before"] = json!(tls_info.cert_not_before);
tls_data["cert_not_after"] = json!(tls_info.cert_not_after);
tls_data["alpn_protocol"] = json!(tls_info.alpn_protocol);
tls_data["san_names"] = json!(tls_info.san_names);
if tls_info.cert_expired {
if tls_info.cert_expired {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"Certificate expired. Not After: {}",
tls_info.cert_not_after
)),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Expired TLS certificate for {host}"),
format!(
"The TLS certificate for {host} has expired. \
Browsers will show security warnings to users."
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Renew the TLS certificate. Consider using automated certificate \
management with Let's Encrypt or a similar CA."
.to_string(),
);
findings.push(finding);
warn!(host, "Expired TLS certificate");
}
if tls_info.cert_self_signed {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("Self-signed certificate detected".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Self-signed TLS certificate for {host}"),
format!(
"The TLS certificate for {host} is self-signed and not issued by a \
trusted certificate authority. Browsers will show security warnings."
),
Severity::Medium,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Replace the self-signed certificate with one issued by a trusted \
certificate authority."
.to_string(),
);
findings.push(finding);
warn!(host, "Self-signed certificate");
}
}
Err(e) => {
tls_data["tls_error"] = json!(e.to_string());
// TLS handshake failure itself is a finding
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
@@ -264,10 +347,7 @@ impl PentestTool for TlsAnalyzerTool {
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!(
"Certificate expired. Not After: {}",
tls_info.cert_not_after
)),
response_snippet: Some(format!("TLS error: {e}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
@@ -277,10 +357,9 @@ impl PentestTool for TlsAnalyzerTool {
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Expired TLS certificate for {host}"),
format!("TLS handshake failure for {host}"),
format!(
"The TLS certificate for {host} has expired. \
Browsers will show security warnings to users."
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
),
Severity::High,
format!("https://{host}:{port}"),
@@ -289,115 +368,37 @@ impl PentestTool for TlsAnalyzerTool {
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Renew the TLS certificate. Consider using automated certificate \
management with Let's Encrypt or a similar CA."
.to_string(),
);
findings.push(finding);
warn!(host, "Expired TLS certificate");
}
if tls_info.cert_self_signed {
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some("Self-signed certificate detected".to_string()),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("Self-signed TLS certificate for {host}"),
format!(
"The TLS certificate for {host} is self-signed and not issued by a \
trusted certificate authority. Browsers will show security warnings."
),
Severity::Medium,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Replace the self-signed certificate with one issued by a trusted \
certificate authority."
.to_string(),
);
findings.push(finding);
warn!(host, "Self-signed certificate");
}
}
Err(e) => {
tls_data["tls_error"] = json!(e.to_string());
// TLS handshake failure itself is a finding
let evidence = DastEvidence {
request_method: "TLS".to_string(),
request_url: format!("{host}:{port}"),
request_headers: None,
request_body: None,
response_status: 0,
response_headers: None,
response_snippet: Some(format!("TLS error: {e}")),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
format!("TLS handshake failure for {host}"),
format!(
"Could not establish a TLS connection to {host}:{port}. Error: {e}"
),
Severity::High,
format!("https://{host}:{port}"),
"TLS".to_string(),
);
finding.cwe = Some("CWE-295".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Ensure TLS is properly configured on the server. Check that the \
"Ensure TLS is properly configured on the server. Check that the \
certificate is valid and the server supports modern TLS versions."
.to_string(),
);
findings.push(finding);
.to_string(),
);
findings.push(finding);
}
}
}
// Check strict transport security via an HTTPS request
match self.http.get(&https_url).send().await {
Ok(resp) => {
let hsts = resp.headers().get("strict-transport-security");
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
// Check strict transport security via an HTTPS request
match self.http.get(&https_url).send().await {
Ok(resp) => {
let hsts = resp.headers().get("strict-transport-security");
tls_data["hsts_header"] = json!(hsts.map(|v| v.to_str().unwrap_or("")));
if hsts.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: https_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(
"Strict-Transport-Security header not present".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
if hsts.is_none() {
let evidence = DastEvidence {
request_method: "GET".to_string(),
request_url: https_url.clone(),
request_headers: None,
request_body: None,
response_status: resp.status().as_u16(),
response_headers: None,
response_snippet: Some(
"Strict-Transport-Security header not present".to_string(),
),
screenshot_path: None,
payload: None,
response_time_ms: None,
};
let mut finding = DastFinding::new(
let mut finding = DastFinding::new(
String::new(),
target_id.clone(),
DastVulnType::TlsMisconfiguration,
@@ -410,33 +411,33 @@ impl PentestTool for TlsAnalyzerTool {
https_url.clone(),
"GET".to_string(),
);
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
finding.cwe = Some("CWE-319".to_string());
finding.evidence = vec![evidence];
finding.remediation = Some(
"Add the Strict-Transport-Security header with an appropriate max-age. \
Example: 'Strict-Transport-Security: max-age=31536000; includeSubDomains'."
.to_string(),
);
findings.push(finding);
findings.push(finding);
}
}
Err(_) => {
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
}
}
Err(_) => {
tls_data["https_check_error"] = json!("Could not connect via HTTPS");
}
}
let count = findings.len();
info!(host = %host, findings = count, "TLS analysis complete");
let count = findings.len();
info!(host = %host, findings = count, "TLS analysis complete");
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} TLS configuration issues for {host}.")
} else {
format!("TLS configuration looks good for {host}.")
},
findings,
data: tls_data,
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} TLS configuration issues for {host}.")
} else {
format!("TLS configuration looks good for {host}.")
},
findings,
data: tls_data,
})
})
}
}

View File

@@ -1,5 +1,7 @@
use compliance_core::error::CoreError;
use compliance_core::traits::dast_agent::{DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter};
use compliance_core::traits::dast_agent::{
DastAgent, DastContext, DiscoveredEndpoint, EndpointParameter,
};
use compliance_core::traits::pentest_tool::{PentestTool, PentestToolContext, PentestToolResult};
use serde_json::json;
@@ -7,30 +9,52 @@ use crate::agents::xss::XssAgent;
/// PentestTool wrapper around the existing XssAgent.
pub struct XssTool {
http: reqwest::Client,
_http: reqwest::Client,
agent: XssAgent,
}
impl XssTool {
pub fn new(http: reqwest::Client) -> Self {
let agent = XssAgent::new(http.clone());
Self { http, agent }
Self { _http: http, agent }
}
fn parse_endpoints(input: &serde_json::Value) -> Vec<DiscoveredEndpoint> {
let mut endpoints = Vec::new();
if let Some(arr) = input.get("endpoints").and_then(|v| v.as_array()) {
for ep in arr {
let url = ep.get("url").and_then(|v| v.as_str()).unwrap_or_default().to_string();
let method = ep.get("method").and_then(|v| v.as_str()).unwrap_or("GET").to_string();
let url = ep
.get("url")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string();
let method = ep
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("GET")
.to_string();
let mut parameters = Vec::new();
if let Some(params) = ep.get("parameters").and_then(|v| v.as_array()) {
for p in params {
parameters.push(EndpointParameter {
name: p.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(),
location: p.get("location").and_then(|v| v.as_str()).unwrap_or("query").to_string(),
param_type: p.get("param_type").and_then(|v| v.as_str()).map(String::from),
example_value: p.get("example_value").and_then(|v| v.as_str()).map(String::from),
name: p
.get("name")
.and_then(|v| v.as_str())
.unwrap_or_default()
.to_string(),
location: p
.get("location")
.and_then(|v| v.as_str())
.unwrap_or("query")
.to_string(),
param_type: p
.get("param_type")
.and_then(|v| v.as_str())
.map(String::from),
example_value: p
.get("example_value")
.and_then(|v| v.as_str())
.map(String::from),
});
}
}
@@ -38,8 +62,14 @@ impl XssTool {
url,
method,
parameters,
content_type: ep.get("content_type").and_then(|v| v.as_str()).map(String::from),
requires_auth: ep.get("requires_auth").and_then(|v| v.as_bool()).unwrap_or(false),
content_type: ep
.get("content_type")
.and_then(|v| v.as_str())
.map(String::from),
requires_auth: ep
.get("requires_auth")
.and_then(|v| v.as_bool())
.unwrap_or(false),
});
}
}
@@ -47,6 +77,91 @@ impl XssTool {
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Shorthand: run the endpoint parser over an owned JSON value.
    fn parse(input: serde_json::Value) -> Vec<DiscoveredEndpoint> {
        XssTool::parse_endpoints(&input)
    }

    #[test]
    fn parse_endpoints_basic() {
        let endpoints = parse(json!({
            "endpoints": [{
                "url": "https://example.com/search",
                "method": "GET",
                "parameters": [{ "name": "q", "location": "query" }]
            }]
        }));
        assert_eq!(endpoints.len(), 1);
        let ep = &endpoints[0];
        assert_eq!(ep.url, "https://example.com/search");
        assert_eq!(ep.method, "GET");
        assert_eq!(ep.parameters.len(), 1);
        assert_eq!(ep.parameters[0].name, "q");
        assert_eq!(ep.parameters[0].location, "query");
    }

    #[test]
    fn parse_endpoints_empty() {
        // An explicit empty list yields no endpoints.
        assert!(parse(json!({ "endpoints": [] })).is_empty());
    }

    #[test]
    fn parse_endpoints_missing_key() {
        // Absent "endpoints" key is treated the same as an empty list.
        assert!(parse(json!({})).is_empty());
    }

    #[test]
    fn parse_endpoints_defaults() {
        let endpoints = parse(json!({
            "endpoints": [
                { "url": "https://example.com/api", "parameters": [] }
            ]
        }));
        // Omitted fields fall back to their defaults: GET and no auth.
        assert_eq!(endpoints[0].method, "GET");
        assert!(!endpoints[0].requires_auth);
    }

    #[test]
    fn parse_endpoints_full_params() {
        let endpoints = parse(json!({
            "endpoints": [{
                "url": "https://example.com",
                "method": "POST",
                "content_type": "application/json",
                "requires_auth": true,
                "parameters": [{
                    "name": "body",
                    "location": "body",
                    "param_type": "string",
                    "example_value": "test"
                }]
            }]
        }));
        let ep = &endpoints[0];
        assert_eq!(ep.method, "POST");
        assert_eq!(ep.content_type.as_deref(), Some("application/json"));
        assert!(ep.requires_auth);
        let param = &ep.parameters[0];
        assert_eq!(param.param_type.as_deref(), Some("string"));
        assert_eq!(param.example_value.as_deref(), Some("test"));
    }
}
impl PentestTool for XssTool {
fn name(&self) -> &str {
"xss_scanner"
@@ -100,35 +215,37 @@ impl PentestTool for XssTool {
&'a self,
input: serde_json::Value,
context: &'a PentestToolContext,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>> {
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = Result<PentestToolResult, CoreError>> + Send + 'a>,
> {
Box::pin(async move {
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let endpoints = Self::parse_endpoints(&input);
if endpoints.is_empty() {
return Ok(PentestToolResult {
summary: "No endpoints provided to test.".to_string(),
findings: Vec::new(),
data: json!({}),
});
}
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let dast_context = DastContext {
endpoints,
technologies: Vec::new(),
sast_hints: Vec::new(),
};
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
let findings = self.agent.run(&context.target, &dast_context).await?;
let count = findings.len();
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} XSS vulnerabilities.")
} else {
"No XSS vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
Ok(PentestToolResult {
summary: if count > 0 {
format!("Found {count} XSS vulnerabilities.")
} else {
"No XSS vulnerabilities detected.".to_string()
},
findings,
data: json!({ "endpoints_tested": dast_context.endpoints.len() }),
})
})
}
}

View File

@@ -0,0 +1,4 @@
// Integration tests for DAST agents.
//
// Test individual security testing agents (XSS, SQLi, SSRF, etc.)
// against controlled test targets.

View File

@@ -94,3 +94,64 @@ fn build_context_header(file_path: &str, qualified_name: &str, kind: &str) -> St
format!("// {file_path} | {kind}")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_build_context_header_with_parent() {
        // The last "::" segment is stripped to recover the parent scope.
        let header =
            build_context_header("src/main.rs", "src/main.rs::MyStruct::my_method", "method");
        assert_eq!(header, "// src/main.rs | method in src/main.rs::MyStruct");
    }

    #[test]
    fn test_build_context_header_top_level() {
        // No "::" in the qualified name means no "in <parent>" suffix.
        assert_eq!(
            build_context_header("src/lib.rs", "main", "function"),
            "// src/lib.rs | function"
        );
    }

    #[test]
    fn test_build_context_header_single_parent() {
        assert_eq!(
            build_context_header("src/lib.rs", "src/lib.rs::do_stuff", "function"),
            "// src/lib.rs | function in src/lib.rs"
        );
    }

    #[test]
    fn test_build_context_header_deep_nesting() {
        // Only the final segment is dropped; intermediate scopes remain.
        let header = build_context_header(
            "src/mod.rs",
            "src/mod.rs::Outer::Inner::deep_fn",
            "function",
        );
        assert_eq!(header, "// src/mod.rs | function in src/mod.rs::Outer::Inner");
    }

    #[test]
    fn test_build_context_header_empty_strings() {
        // Degenerate input still produces a well-formed header.
        assert_eq!(build_context_header("", "", "function"), "// | function");
    }

    #[test]
    fn test_code_chunk_struct_fields() {
        // Smoke-test that CodeChunk can be constructed field-by-field and
        // round-trips the values we set.
        let chunk = CodeChunk {
            qualified_name: "main".to_string(),
            kind: "function".to_string(),
            file_path: "src/main.rs".to_string(),
            start_line: 1,
            end_line: 10,
            language: "rust".to_string(),
            content: "fn main() {}".to_string(),
            context_header: "// src/main.rs | function".to_string(),
            token_estimate: 3,
        };
        assert_eq!((chunk.start_line, chunk.end_line), (1, 10));
        assert_eq!(chunk.language, "rust");
    }
}

View File

@@ -253,3 +253,215 @@ fn detect_communities_with_assignment(code_graph: &mut CodeGraph) -> u32 {
next_id
}
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};
    use petgraph::graph::DiGraph;

    /// Build a minimal function-kind `CodeNode` fixture; only the qualified
    /// name and the pre-assigned `graph_index` vary between tests.
    fn make_node(qualified_name: &str, graph_index: u32) -> CodeNode {
        CodeNode {
            id: None,
            repo_id: "test".to_string(),
            graph_build_id: "build1".to_string(),
            qualified_name: qualified_name.to_string(),
            name: qualified_name.to_string(),
            kind: CodeNodeKind::Function,
            file_path: "test.rs".to_string(),
            start_line: 1,
            end_line: 10,
            language: "rust".to_string(),
            community_id: None,
            is_entry_point: false,
            graph_index: Some(graph_index),
        }
    }

    /// A `CodeGraph` with no nodes, edges, or name mappings.
    fn make_empty_code_graph() -> CodeGraph {
        CodeGraph {
            graph: DiGraph::new(),
            node_map: HashMap::new(),
            nodes: Vec::new(),
            edges: Vec::new(),
        }
    }

    #[test]
    fn test_detect_communities_empty_graph() {
        // No nodes at all => zero communities.
        let cg = make_empty_code_graph();
        assert_eq!(detect_communities(&cg), 0);
    }

    #[test]
    fn test_detect_communities_single_node_no_edges() {
        let mut graph = DiGraph::new();
        let idx = graph.add_node("a".to_string());
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), idx);
        let cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![make_node("a", 0)],
            edges: Vec::new(),
        };
        // Single node with no edges => 1 community (itself)
        assert_eq!(detect_communities(&cg), 1);
    }

    #[test]
    fn test_detect_communities_isolated_nodes() {
        // Three nodes, no edges between them.
        let mut graph = DiGraph::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), a);
        node_map.insert("b".to_string(), b);
        node_map.insert("c".to_string(), c);
        let cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)],
            edges: Vec::new(),
        };
        // 3 isolated nodes => 3 communities
        assert_eq!(detect_communities(&cg), 3);
    }

    #[test]
    fn test_detect_communities_fully_connected() {
        // Directed triangle a -> b -> c -> a.
        let mut graph = DiGraph::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        graph.add_edge(a, b, CodeEdgeKind::Calls);
        graph.add_edge(b, c, CodeEdgeKind::Calls);
        graph.add_edge(c, a, CodeEdgeKind::Calls);
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), a);
        node_map.insert("b".to_string(), b);
        node_map.insert("c".to_string(), c);
        let cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![make_node("a", 0), make_node("b", 1), make_node("c", 2)],
            edges: Vec::new(),
        };
        let num = detect_communities(&cg);
        // Fully connected triangle should converge to 1 community.
        // NOTE(review): only a range is asserted because label propagation
        // can be sensitive to iteration order — confirm if tightening.
        assert!(num >= 1);
        assert!(num <= 3);
    }

    #[test]
    fn test_detect_communities_two_clusters() {
        // Two separate triangles connected by a single weak edge
        // (NOTE(review): as built below the clusters are actually fully
        // disconnected — no bridging edge is added).
        let mut graph = DiGraph::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        let d = graph.add_node("d".to_string());
        let e = graph.add_node("e".to_string());
        let f = graph.add_node("f".to_string());
        // Cluster 1: a-b-c fully connected (both directions for each pair)
        graph.add_edge(a, b, CodeEdgeKind::Calls);
        graph.add_edge(b, a, CodeEdgeKind::Calls);
        graph.add_edge(b, c, CodeEdgeKind::Calls);
        graph.add_edge(c, b, CodeEdgeKind::Calls);
        graph.add_edge(a, c, CodeEdgeKind::Calls);
        graph.add_edge(c, a, CodeEdgeKind::Calls);
        // Cluster 2: d-e-f fully connected
        graph.add_edge(d, e, CodeEdgeKind::Calls);
        graph.add_edge(e, d, CodeEdgeKind::Calls);
        graph.add_edge(e, f, CodeEdgeKind::Calls);
        graph.add_edge(f, e, CodeEdgeKind::Calls);
        graph.add_edge(d, f, CodeEdgeKind::Calls);
        graph.add_edge(f, d, CodeEdgeKind::Calls);
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), a);
        node_map.insert("b".to_string(), b);
        node_map.insert("c".to_string(), c);
        node_map.insert("d".to_string(), d);
        node_map.insert("e".to_string(), e);
        node_map.insert("f".to_string(), f);
        let cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![
                make_node("a", 0),
                make_node("b", 1),
                make_node("c", 2),
                make_node("d", 3),
                make_node("e", 4),
                make_node("f", 5),
            ],
            edges: Vec::new(),
        };
        let num = detect_communities(&cg);
        // Two disconnected clusters should yield 2 communities
        assert_eq!(num, 2);
    }

    #[test]
    fn test_apply_communities_assigns_ids() {
        // Two connected nodes: apply_communities must both count the
        // communities and write an id back onto every node.
        let mut graph = DiGraph::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        graph.add_edge(a, b, CodeEdgeKind::Calls);
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), a);
        node_map.insert("b".to_string(), b);
        let mut cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![make_node("a", 0), make_node("b", 1)],
            edges: Vec::new(),
        };
        let count = apply_communities(&mut cg);
        assert!(count >= 1);
        // All nodes should have a community_id assigned
        for node in &cg.nodes {
            assert!(node.community_id.is_some());
        }
    }

    #[test]
    fn test_apply_communities_empty() {
        // Empty graph: nothing to assign, zero communities reported.
        let mut cg = make_empty_code_graph();
        assert_eq!(apply_communities(&mut cg), 0);
    }

    #[test]
    fn test_apply_communities_isolated_nodes_get_own_community() {
        // Two nodes with no edges must not share a community id.
        let mut graph = DiGraph::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let mut node_map = HashMap::new();
        node_map.insert("a".to_string(), a);
        node_map.insert("b".to_string(), b);
        let mut cg = CodeGraph {
            graph,
            node_map,
            nodes: vec![make_node("a", 0), make_node("b", 1)],
            edges: Vec::new(),
        };
        let count = apply_communities(&mut cg);
        assert_eq!(count, 2);
        // Each isolated node should be in a different community
        let c0 = cg.nodes[0].community_id.unwrap();
        let c1 = cg.nodes[1].community_id.unwrap();
        assert_ne!(c0, c1);
    }
}

View File

@@ -172,3 +172,185 @@ impl GraphEngine {
ImpactAnalyzer::new(code_graph)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};

    /// Build a function-kind `CodeNode` whose short `name` is the last
    /// "::" segment of the qualified name (mirroring parser output).
    fn make_node(qualified_name: &str) -> CodeNode {
        CodeNode {
            id: None,
            repo_id: "test".to_string(),
            graph_build_id: "build1".to_string(),
            qualified_name: qualified_name.to_string(),
            name: qualified_name
                .split("::")
                .last()
                .unwrap_or(qualified_name)
                .to_string(),
            kind: CodeNodeKind::Function,
            file_path: "src/main.rs".to_string(),
            start_line: 1,
            end_line: 10,
            language: "rust".to_string(),
            community_id: None,
            is_entry_point: false,
            graph_index: None,
        }
    }

    /// Build a name -> NodeIndex map for `resolve_edge_target` tests.
    /// A throwaway DiGraph is used only to mint valid `NodeIndex` values
    /// (indices are assigned in insertion order, starting at 0).
    fn build_test_node_map(names: &[&str]) -> HashMap<String, NodeIndex> {
        let mut graph: DiGraph<String, String> = DiGraph::new();
        let mut map = HashMap::new();
        for name in names {
            let idx = graph.add_node(name.to_string());
            map.insert(name.to_string(), idx);
        }
        map
    }

    #[test]
    fn test_resolve_edge_target_direct_match() {
        // An exact qualified-name key resolves directly.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/main.rs::foo", "src/main.rs::bar"]);
        let result = engine.resolve_edge_target("src/main.rs::foo", &node_map);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]);
    }

    #[test]
    fn test_resolve_edge_target_short_name_match() {
        // A bare short name falls back to suffix matching against
        // fully qualified keys.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/main.rs::foo", "src/main.rs::bar"]);
        let result = engine.resolve_edge_target("foo", &node_map);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), node_map["src/main.rs::foo"]);
    }

    #[test]
    fn test_resolve_edge_target_method_match() {
        // A method's short name resolves to its Struct::method key.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/main.rs::MyStruct::do_thing"]);
        let result = engine.resolve_edge_target("do_thing", &node_map);
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_edge_target_self_method() {
        // A "self.method" call site resolves to the method's node.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/main.rs::MyStruct::process"]);
        let result = engine.resolve_edge_target("self.process", &node_map);
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_edge_target_no_match() {
        // Unknown targets must yield None rather than a spurious match.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/main.rs::foo"]);
        let result = engine.resolve_edge_target("nonexistent", &node_map);
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_edge_target_empty_map() {
        // Resolution against an empty map is a clean miss.
        let engine = GraphEngine::new(1000);
        let node_map = HashMap::new();
        let result = engine.resolve_edge_target("anything", &node_map);
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_edge_target_dot_notation() {
        // Dot-separated qualified names (e.g. JS) also support
        // short-name resolution.
        let engine = GraphEngine::new(1000);
        let node_map = build_test_node_map(&["src/app.js.handler"]);
        let result = engine.resolve_edge_target("handler", &node_map);
        assert!(result.is_some());
    }

    #[test]
    fn test_build_petgraph_empty() {
        // Empty parse output builds an empty (but valid) CodeGraph.
        let engine = GraphEngine::new(1000);
        let output = ParseOutput::default();
        let code_graph = engine.build_petgraph(output).unwrap();
        assert_eq!(code_graph.nodes.len(), 0);
        assert_eq!(code_graph.edges.len(), 0);
        assert_eq!(code_graph.graph.node_count(), 0);
    }

    #[test]
    fn test_build_petgraph_nodes_get_graph_index() {
        let engine = GraphEngine::new(1000);
        let mut output = ParseOutput::default();
        output.nodes.push(make_node("src/main.rs::foo"));
        output.nodes.push(make_node("src/main.rs::bar"));
        let code_graph = engine.build_petgraph(output).unwrap();
        assert_eq!(code_graph.nodes.len(), 2);
        assert_eq!(code_graph.graph.node_count(), 2);
        // All nodes should have a graph_index assigned
        for node in &code_graph.nodes {
            assert!(node.graph_index.is_some());
        }
    }

    #[test]
    fn test_build_petgraph_resolves_edges() {
        let engine = GraphEngine::new(1000);
        let mut output = ParseOutput::default();
        output.nodes.push(make_node("src/main.rs::foo"));
        output.nodes.push(make_node("src/main.rs::bar"));
        output.edges.push(CodeEdge {
            id: None,
            repo_id: "test".to_string(),
            graph_build_id: "build1".to_string(),
            source: "src/main.rs::foo".to_string(),
            target: "bar".to_string(), // short name, should resolve
            kind: CodeEdgeKind::Calls,
            file_path: "src/main.rs".to_string(),
            line_number: Some(5),
        });
        let code_graph = engine.build_petgraph(output).unwrap();
        assert_eq!(code_graph.edges.len(), 1);
        assert_eq!(code_graph.graph.edge_count(), 1);
        // The resolved edge target should be the full qualified name
        assert_eq!(code_graph.edges[0].target, "src/main.rs::bar");
    }

    #[test]
    fn test_build_petgraph_skips_unresolved_edges() {
        // Edges whose target lives outside the graph (e.g. an external
        // crate) are dropped rather than kept dangling.
        let engine = GraphEngine::new(1000);
        let mut output = ParseOutput::default();
        output.nodes.push(make_node("src/main.rs::foo"));
        output.edges.push(CodeEdge {
            id: None,
            repo_id: "test".to_string(),
            graph_build_id: "build1".to_string(),
            source: "src/main.rs::foo".to_string(),
            target: "external_crate::something".to_string(),
            kind: CodeEdgeKind::Calls,
            file_path: "src/main.rs".to_string(),
            line_number: Some(5),
        });
        let code_graph = engine.build_petgraph(output).unwrap();
        assert_eq!(code_graph.edges.len(), 0);
        assert_eq!(code_graph.graph.edge_count(), 0);
    }

    #[test]
    fn test_code_graph_node_map_consistency() {
        // Every input node must be keyed by qualified name in node_map.
        let engine = GraphEngine::new(1000);
        let mut output = ParseOutput::default();
        output.nodes.push(make_node("a::b"));
        output.nodes.push(make_node("a::c"));
        output.nodes.push(make_node("a::d"));
        let code_graph = engine.build_petgraph(output).unwrap();
        assert_eq!(code_graph.node_map.len(), 3);
        assert!(code_graph.node_map.contains_key("a::b"));
        assert!(code_graph.node_map.contains_key("a::c"));
        assert!(code_graph.node_map.contains_key("a::d"));
    }
}

View File

@@ -222,3 +222,378 @@ impl<'a> ImpactAnalyzer<'a> {
.find(|n| n.graph_index == Some(target_gi))
}
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::graph::{CodeEdgeKind, CodeNode, CodeNodeKind};
use petgraph::graph::DiGraph;
use std::collections::HashMap;
fn make_node(
qualified_name: &str,
file_path: &str,
start: u32,
end: u32,
graph_index: u32,
is_entry: bool,
kind: CodeNodeKind,
) -> CodeNode {
CodeNode {
id: None,
repo_id: "test".to_string(),
graph_build_id: "build1".to_string(),
qualified_name: qualified_name.to_string(),
name: qualified_name
.split("::")
.last()
.unwrap_or(qualified_name)
.to_string(),
kind,
file_path: file_path.to_string(),
start_line: start,
end_line: end,
language: "rust".to_string(),
community_id: None,
is_entry_point: is_entry,
graph_index: Some(graph_index),
}
}
fn make_fn_node(
qualified_name: &str,
file_path: &str,
start: u32,
end: u32,
gi: u32,
) -> CodeNode {
make_node(
qualified_name,
file_path,
start,
end,
gi,
false,
CodeNodeKind::Function,
)
}
/// Build a simple linear graph: A -> B -> C
fn build_linear_graph() -> CodeGraph {
let mut graph = DiGraph::new();
let a = graph.add_node("a".to_string());
let b = graph.add_node("b".to_string());
let c = graph.add_node("c".to_string());
graph.add_edge(a, b, CodeEdgeKind::Calls);
graph.add_edge(b, c, CodeEdgeKind::Calls);
let mut node_map = HashMap::new();
node_map.insert("a".to_string(), a);
node_map.insert("b".to_string(), b);
node_map.insert("c".to_string(), c);
CodeGraph {
graph,
node_map,
nodes: vec![
make_fn_node("a", "src/main.rs", 1, 5, 0),
make_fn_node("b", "src/main.rs", 7, 12, 1),
make_fn_node("c", "src/main.rs", 14, 20, 2),
],
edges: Vec::new(),
}
}
#[test]
fn test_bfs_reachable_outgoing_linear() {
let cg = build_linear_graph();
let analyzer = ImpactAnalyzer::new(&cg);
let start = cg.node_map["a"];
let reachable = analyzer.bfs_reachable(start, Direction::Outgoing);
// From a, we can reach b and c
assert_eq!(reachable.len(), 2);
assert!(reachable.contains(&cg.node_map["b"]));
assert!(reachable.contains(&cg.node_map["c"]));
}
#[test]
fn test_bfs_reachable_incoming_linear() {
let cg = build_linear_graph();
let analyzer = ImpactAnalyzer::new(&cg);
let start = cg.node_map["c"];
let reachable = analyzer.bfs_reachable(start, Direction::Incoming);
// c is reached by a and b
assert_eq!(reachable.len(), 2);
assert!(reachable.contains(&cg.node_map["a"]));
assert!(reachable.contains(&cg.node_map["b"]));
}
#[test]
fn test_bfs_reachable_no_neighbors() {
let mut graph = DiGraph::new();
let a = graph.add_node("a".to_string());
let cg = CodeGraph {
graph,
node_map: [("a".to_string(), a)].into_iter().collect(),
nodes: vec![make_fn_node("a", "src/main.rs", 1, 5, 0)],
edges: Vec::new(),
};
let analyzer = ImpactAnalyzer::new(&cg);
let reachable = analyzer.bfs_reachable(a, Direction::Outgoing);
assert!(reachable.is_empty());
}
#[test]
fn test_bfs_reachable_cycle() {
let mut graph = DiGraph::new();
let a = graph.add_node("a".to_string());
let b = graph.add_node("b".to_string());
graph.add_edge(a, b, CodeEdgeKind::Calls);
graph.add_edge(b, a, CodeEdgeKind::Calls);
let cg = CodeGraph {
graph,
node_map: [("a".to_string(), a), ("b".to_string(), b)]
.into_iter()
.collect(),
nodes: vec![
make_fn_node("a", "f.rs", 1, 5, 0),
make_fn_node("b", "f.rs", 6, 10, 1),
],
edges: Vec::new(),
};
let analyzer = ImpactAnalyzer::new(&cg);
let reachable = analyzer.bfs_reachable(a, Direction::Outgoing);
// Should handle cycle without infinite loop
assert_eq!(reachable.len(), 1);
assert!(reachable.contains(&b));
}
#[test]
fn test_find_path_exists() {
let cg = build_linear_graph();
let analyzer = ImpactAnalyzer::new(&cg);
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["c"], 10);
assert!(path.is_some());
let names = path.unwrap();
assert_eq!(names, vec!["a", "b", "c"]);
}
#[test]
fn test_find_path_direct() {
let cg = build_linear_graph();
let analyzer = ImpactAnalyzer::new(&cg);
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["b"], 10);
assert!(path.is_some());
let names = path.unwrap();
assert_eq!(names, vec!["a", "b"]);
}
#[test]
fn test_find_path_same_node() {
let cg = build_linear_graph();
let analyzer = ImpactAnalyzer::new(&cg);
let path = analyzer.find_path(cg.node_map["a"], cg.node_map["a"], 10);
assert!(path.is_some());
let names = path.unwrap();
assert_eq!(names, vec!["a"]);
}
#[test]
fn test_find_path_no_connection() {
let mut graph = DiGraph::new();
let a = graph.add_node("a".to_string());
let b = graph.add_node("b".to_string());
// No edge between a and b
let cg = CodeGraph {
graph,
node_map: [("a".to_string(), a), ("b".to_string(), b)]
.into_iter()
.collect(),
nodes: vec![
make_fn_node("a", "f.rs", 1, 5, 0),
make_fn_node("b", "f.rs", 6, 10, 1),
],
edges: Vec::new(),
};
let analyzer = ImpactAnalyzer::new(&cg);
let path = analyzer.find_path(a, b, 10);
assert!(path.is_none());
}
#[test]
fn test_find_path_depth_limited() {
// Build a long chain: a -> b -> c -> d -> e
let mut graph = DiGraph::new();
let a = graph.add_node("a".to_string());
let b = graph.add_node("b".to_string());
let c = graph.add_node("c".to_string());
let d = graph.add_node("d".to_string());
let e = graph.add_node("e".to_string());
graph.add_edge(a, b, CodeEdgeKind::Calls);
graph.add_edge(b, c, CodeEdgeKind::Calls);
graph.add_edge(c, d, CodeEdgeKind::Calls);
graph.add_edge(d, e, CodeEdgeKind::Calls);
let mut node_map = HashMap::new();
node_map.insert("a".to_string(), a);
node_map.insert("b".to_string(), b);
node_map.insert("c".to_string(), c);
node_map.insert("d".to_string(), d);
node_map.insert("e".to_string(), e);
let cg = CodeGraph {
graph,
node_map,
nodes: vec![
make_fn_node("a", "f.rs", 1, 2, 0),
make_fn_node("b", "f.rs", 3, 4, 1),
make_fn_node("c", "f.rs", 5, 6, 2),
make_fn_node("d", "f.rs", 7, 8, 3),
make_fn_node("e", "f.rs", 9, 10, 4),
],
edges: Vec::new(),
};
let analyzer = ImpactAnalyzer::new(&cg);
// Depth 3 won't reach e from a (path length 5)
let path = analyzer.find_path(a, e, 3);
assert!(path.is_none());
// Depth 5 should reach
let path = analyzer.find_path(a, e, 5);
assert!(path.is_some());
}
#[test]
fn test_find_node_at_location_exact_line() {
    let cg = build_linear_graph();
    // Line 9 falls inside node "b", whose span is lines 7-12.
    let hit = ImpactAnalyzer::new(&cg).find_node_at_location("src/main.rs", Some(9));
    assert_eq!(hit, Some(cg.node_map["b"]));
}
#[test]
fn test_find_node_at_location_narrowest_match() {
    // An inner function (lines 5-10) nested in an outer one (lines 1-20):
    // the lookup must prefer the tighter span.
    let mut graph = DiGraph::new();
    let outer_idx = graph.add_node("outer".to_string());
    let inner_idx = graph.add_node("inner".to_string());
    let mut node_map = HashMap::new();
    node_map.insert("outer".to_string(), outer_idx);
    node_map.insert("inner".to_string(), inner_idx);
    let cg = CodeGraph {
        graph,
        node_map,
        nodes: vec![
            make_fn_node("outer", "src/main.rs", 1, 20, 0),
            make_fn_node("inner", "src/main.rs", 5, 10, 1),
        ],
        edges: Vec::new(),
    };
    // Line 7 sits inside both spans; "inner" is the narrower match.
    let hit = ImpactAnalyzer::new(&cg).find_node_at_location("src/main.rs", Some(7));
    assert_eq!(hit, Some(inner_idx));
}
#[test]
fn test_find_node_at_location_no_line_returns_file_node() {
    // Without a line hint the lookup resolves to the file-level node,
    // not to any function defined inside the file.
    let mut graph = DiGraph::new();
    let file_idx = graph.add_node("src/main.rs".to_string());
    let fn_idx = graph.add_node("src/main.rs::foo".to_string());
    let mut node_map = HashMap::new();
    node_map.insert("src/main.rs".to_string(), file_idx);
    node_map.insert("src/main.rs::foo".to_string(), fn_idx);
    let cg = CodeGraph {
        graph,
        node_map,
        nodes: vec![
            make_node(
                "src/main.rs",
                "src/main.rs",
                1,
                100,
                0,
                false,
                CodeNodeKind::File,
            ),
            make_fn_node("src/main.rs::foo", "src/main.rs", 5, 10, 1),
        ],
        edges: Vec::new(),
    };
    let hit = ImpactAnalyzer::new(&cg).find_node_at_location("src/main.rs", None);
    assert_eq!(hit, Some(file_idx));
}
#[test]
fn test_find_node_at_location_wrong_file() {
    // A path absent from the graph yields no node at all.
    let cg = build_linear_graph();
    let analyzer = ImpactAnalyzer::new(&cg);
    assert!(analyzer
        .find_node_at_location("nonexistent.rs", Some(5))
        .is_none());
}
#[test]
fn test_find_node_at_location_line_out_of_range() {
    // A line number beyond every node's span produces no match.
    let cg = build_linear_graph();
    let analyzer = ImpactAnalyzer::new(&cg);
    assert!(analyzer
        .find_node_at_location("src/main.rs", Some(999))
        .is_none());
}
#[test]
fn test_analyze_basic() {
    // Call chain a (entry point) -> b -> c; the finding lands inside b.
    let mut graph = DiGraph::new();
    let mut node_map = HashMap::new();
    for &name in ["a", "b", "c"].iter() {
        let idx = graph.add_node(name.to_string());
        node_map.insert(name.to_string(), idx);
    }
    graph.add_edge(node_map["a"], node_map["b"], CodeEdgeKind::Calls);
    graph.add_edge(node_map["b"], node_map["c"], CodeEdgeKind::Calls);
    let cg = CodeGraph {
        graph,
        node_map,
        nodes: vec![
            make_node("a", "src/main.rs", 1, 5, 0, true, CodeNodeKind::Function),
            make_fn_node("b", "src/main.rs", 7, 12, 1),
            make_fn_node("c", "src/main.rs", 14, 20, 2),
        ],
        edges: Vec::new(),
    };
    let analyzer = ImpactAnalyzer::new(&cg);
    // Line 9 lies inside b (lines 7-12).
    let result = analyzer.analyze("repo1", "finding1", "build1", "src/main.rs", Some(9));
    // Only c is reachable forward from b.
    assert_eq!(result.blast_radius, 1);
    // a is b's sole direct caller, and the only entry point that reaches b.
    assert_eq!(result.direct_callers, vec!["a"]);
    // b's sole callee is c.
    assert_eq!(result.direct_callees, vec!["c"]);
    assert_eq!(result.affected_entry_points, vec!["a"]);
}
#[test]
fn test_analyze_no_matching_node() {
    // An unknown file yields an empty impact result rather than an error.
    let cg = build_linear_graph();
    let analyzer = ImpactAnalyzer::new(&cg);
    let result = analyzer.analyze("repo1", "f1", "b1", "nonexistent.rs", Some(1));
    assert_eq!(result.blast_radius, 0);
    assert!(result.affected_entry_points.is_empty());
    assert!(result.direct_callers.is_empty());
    assert!(result.direct_callees.is_empty());
}
}

View File

@@ -184,3 +184,115 @@ impl Default for ParserRegistry {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_supports_rust_extension() {
        assert!(ParserRegistry::new().supports_extension("rs"));
    }

    #[test]
    fn test_supports_python_extension() {
        assert!(ParserRegistry::new().supports_extension("py"));
    }

    #[test]
    fn test_supports_javascript_extension() {
        assert!(ParserRegistry::new().supports_extension("js"));
    }

    #[test]
    fn test_supports_typescript_extension() {
        assert!(ParserRegistry::new().supports_extension("ts"));
    }

    #[test]
    fn test_does_not_support_unknown_extension() {
        // Languages without a registered parser, plus the empty extension.
        let registry = ParserRegistry::new();
        for &ext in ["go", "java", "cpp", ""].iter() {
            assert!(!registry.supports_extension(ext));
        }
    }

    #[test]
    fn test_supported_extensions_includes_all() {
        let exts = ParserRegistry::new().supported_extensions();
        for &expected in ["rs", "py", "js", "ts"].iter() {
            assert!(exts.contains(&expected));
        }
    }

    #[test]
    fn test_supported_extensions_count() {
        // At least rs/py/js/ts; variants such as tsx or jsx may add more.
        assert!(ParserRegistry::new().supported_extensions().len() >= 4);
    }

    #[test]
    fn test_parse_file_returns_none_for_unsupported() {
        // Unsupported extensions are skipped (Ok(None)), not treated as errors.
        let registry = ParserRegistry::new();
        let outcome = registry.parse_file(&PathBuf::from("test.go"), "package main", "repo1", "build1");
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
    }

    #[test]
    fn test_parse_file_rust_source() {
        let registry = ParserRegistry::new();
        let parsed = registry
            .parse_file(
                &PathBuf::from("src/main.rs"),
                "fn main() {\n println!(\"hello\");\n}\n",
                "repo1",
                "build1",
            )
            .unwrap()
            .expect("rust source should produce output");
        // Expect at minimum the file node plus the main function node.
        assert!(parsed.nodes.len() >= 2);
    }

    #[test]
    fn test_parse_file_python_source() {
        let registry = ParserRegistry::new();
        let parsed = registry
            .parse_file(
                &PathBuf::from("app.py"),
                "def hello():\n print('hi')\n",
                "repo1",
                "build1",
            )
            .unwrap()
            .expect("python source should produce output");
        assert!(!parsed.nodes.is_empty());
    }

    #[test]
    fn test_parse_file_empty_source() {
        // Empty input still yields output containing at least the file node.
        let registry = ParserRegistry::new();
        let parsed = registry
            .parse_file(&PathBuf::from("empty.rs"), "", "repo1", "build1")
            .unwrap()
            .expect("empty source should still produce output");
        assert!(!parsed.nodes.is_empty());
    }

    #[test]
    fn test_default_trait() {
        // Default must behave like new().
        assert!(ParserRegistry::default().supports_extension("rs"));
    }
}

View File

@@ -363,6 +363,214 @@ impl RustParser {
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::traits::graph_builder::LanguageParser;
    use std::path::PathBuf;

    /// Run the Rust parser over `source` as a file named `test.rs`,
    /// panicking on any parse failure.
    fn parse_rust(source: &str) -> ParseOutput {
        RustParser::new()
            .parse_file(&PathBuf::from("test.rs"), source, "repo1", "build1")
            .unwrap()
    }

    #[test]
    fn test_extract_use_path_simple() {
        assert_eq!(
            RustParser::new().extract_use_path("use std::collections::HashMap;"),
            Some("std::collections::HashMap".to_string())
        );
    }

    #[test]
    fn test_extract_use_path_nested() {
        assert_eq!(
            RustParser::new().extract_use_path("use crate::models::graph::CodeNode;"),
            Some("crate::models::graph::CodeNode".to_string())
        );
    }

    #[test]
    fn test_extract_use_path_no_prefix() {
        // A statement that is not a `use` declaration yields nothing.
        assert_eq!(RustParser::new().extract_use_path("let x = 5;"), None);
    }

    #[test]
    fn test_extract_use_path_empty() {
        assert_eq!(RustParser::new().extract_use_path(""), None);
    }

    #[test]
    fn test_parse_function() {
        let output = parse_rust("fn hello() {\n let x = 1;\n}\n");
        let mut functions = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Function);
        let hello = functions.next().expect("exactly one function node");
        assert!(functions.next().is_none());
        assert_eq!(hello.name, "hello");
        assert!(hello.qualified_name.contains("hello"));
    }

    #[test]
    fn test_parse_struct() {
        let output = parse_rust("struct Foo {\n x: i32,\n}\n");
        let names: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Struct)
            .map(|n| n.name.as_str())
            .collect();
        assert_eq!(names, vec!["Foo"]);
    }

    #[test]
    fn test_parse_enum() {
        let output = parse_rust("enum Color {\n Red,\n Blue,\n}\n");
        let names: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Enum)
            .map(|n| n.name.as_str())
            .collect();
        assert_eq!(names, vec!["Color"]);
    }

    #[test]
    fn test_parse_trait() {
        let output = parse_rust("trait Drawable {\n fn draw(&self);\n}\n");
        let names: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Trait)
            .map(|n| n.name.as_str())
            .collect();
        assert_eq!(names, vec!["Drawable"]);
    }

    #[test]
    fn test_parse_file_node_always_created() {
        // Even empty source produces exactly one file node tagged "rust".
        let output = parse_rust("");
        let file_nodes: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::File)
            .collect();
        assert_eq!(file_nodes.len(), 1);
        assert_eq!(file_nodes[0].language, "rust");
    }

    #[test]
    fn test_parse_multiple_functions() {
        let output = parse_rust("fn foo() {}\nfn bar() {}\nfn baz() {}\n");
        let count = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Function)
            .count();
        assert_eq!(count, 3);
    }

    #[test]
    fn test_parse_main_is_entry_point() {
        let output = parse_rust("fn main() {\n println!(\"hi\");\n}\n");
        let main_node = output
            .nodes
            .iter()
            .find(|n| n.name == "main")
            .expect("main node");
        assert!(main_node.is_entry_point);
    }

    #[test]
    fn test_parse_pub_fn_is_entry_point() {
        let output = parse_rust("pub fn handler() {}\n");
        let handler = output
            .nodes
            .iter()
            .find(|n| n.name == "handler")
            .expect("handler node");
        assert!(handler.is_entry_point);
    }

    #[test]
    fn test_parse_private_fn_is_not_entry_point() {
        let output = parse_rust("fn helper() {}\n");
        let helper = output
            .nodes
            .iter()
            .find(|n| n.name == "helper")
            .expect("helper node");
        assert!(!helper.is_entry_point);
    }

    #[test]
    fn test_parse_function_calls_create_edges() {
        let output = parse_rust("fn caller() {\n callee();\n}\nfn callee() {}\n");
        // At least one call edge exists, and one of them targets callee.
        assert!(output.edges.iter().any(|e| e.kind == CodeEdgeKind::Calls));
        assert!(output
            .edges
            .iter()
            .any(|e| e.kind == CodeEdgeKind::Calls && e.target.contains("callee")));
    }

    #[test]
    fn test_parse_use_declaration_creates_import_edge() {
        let output = parse_rust("use std::collections::HashMap;\nfn foo() {}\n");
        assert!(output.edges.iter().any(|e| e.kind == CodeEdgeKind::Imports));
    }

    #[test]
    fn test_parse_impl_methods() {
        let output = parse_rust("struct Foo {}\nimpl Foo {\n fn do_thing(&self) {}\n}\n");
        let methods: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Function)
            .collect();
        assert_eq!(methods.len(), 1);
        assert_eq!(methods[0].name, "do_thing");
        // The method is qualified under the impl type's name.
        assert!(methods[0].qualified_name.contains("Foo"));
    }

    #[test]
    fn test_parse_mod_item() {
        let output = parse_rust("mod inner {\n fn nested() {}\n}\n");
        let names: Vec<_> = output
            .nodes
            .iter()
            .filter(|n| n.kind == CodeNodeKind::Module)
            .map(|n| n.name.as_str())
            .collect();
        assert_eq!(names, vec!["inner"]);
    }

    #[test]
    fn test_parse_line_numbers() {
        let output = parse_rust("fn first() {}\n\n\nfn second() {}\n");
        let locate = |name: &str| {
            output
                .nodes
                .iter()
                .find(|n| n.name == name)
                .expect("node by name")
        };
        let (first, second) = (locate("first"), locate("second"));
        assert_eq!(first.start_line, 1);
        // Blank lines separate the two, so second must start later.
        assert!(second.start_line > first.start_line);
    }

    #[test]
    fn test_language_and_extensions() {
        let parser = RustParser::new();
        assert_eq!(parser.language(), "rust");
        assert_eq!(parser.extensions(), &["rs"]);
    }
}
impl LanguageParser for RustParser {
fn language(&self) -> &str {
"rust"

View File

@@ -128,3 +128,186 @@ impl SymbolIndex {
Ok(results)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use compliance_core::models::graph::CodeNodeKind;

    /// Build a CodeNode fixture with fixed repo/build ids and a 1..10 span.
    fn make_node(
        qualified_name: &str,
        name: &str,
        kind: CodeNodeKind,
        file_path: &str,
        language: &str,
    ) -> CodeNode {
        CodeNode {
            id: None,
            repo_id: "test".to_string(),
            graph_build_id: "build1".to_string(),
            qualified_name: qualified_name.to_string(),
            name: name.to_string(),
            kind,
            file_path: file_path.to_string(),
            start_line: 1,
            end_line: 10,
            language: language.to_string(),
            community_id: None,
            is_entry_point: false,
            graph_index: None,
        }
    }

    #[test]
    fn test_new_creates_index() {
        assert!(SymbolIndex::new().is_ok());
    }

    #[test]
    fn test_index_empty_nodes() {
        // Indexing an empty slice is a no-op, not an error.
        assert!(SymbolIndex::new().unwrap().index_nodes(&[]).is_ok());
    }

    #[test]
    fn test_index_and_search_single_node() {
        let index = SymbolIndex::new().unwrap();
        index
            .index_nodes(&[make_node(
                "src/main.rs::main",
                "main",
                CodeNodeKind::Function,
                "src/main.rs",
                "rust",
            )])
            .unwrap();
        let hits = index.search("main", 10).unwrap();
        let top = hits.first().expect("search should return a hit");
        assert_eq!(top.name, "main");
        assert_eq!(top.qualified_name, "src/main.rs::main");
    }

    #[test]
    fn test_search_no_results() {
        let index = SymbolIndex::new().unwrap();
        index
            .index_nodes(&[make_node(
                "src/main.rs::foo",
                "foo",
                CodeNodeKind::Function,
                "src/main.rs",
                "rust",
            )])
            .unwrap();
        assert!(index.search("zzzznonexistent", 10).unwrap().is_empty());
    }

    #[test]
    fn test_search_multiple_nodes() {
        let index = SymbolIndex::new().unwrap();
        let fixtures = [
            ("a.rs::handle_request", "handle_request", "a.rs"),
            ("b.rs::handle_response", "handle_response", "b.rs"),
            ("c.rs::process_data", "process_data", "c.rs"),
        ];
        let nodes: Vec<_> = fixtures
            .iter()
            .map(|&(qn, name, path)| make_node(qn, name, CodeNodeKind::Function, path, "rust"))
            .collect();
        index.index_nodes(&nodes).unwrap();
        // Both handle_* functions should match a "handle" query.
        assert!(index.search("handle", 10).unwrap().len() >= 2);
    }

    #[test]
    fn test_search_limit() {
        let index = SymbolIndex::new().unwrap();
        let nodes: Vec<_> = (0..20)
            .map(|i| {
                make_node(
                    &format!("mod::func_{i}"),
                    &format!("func_{i}"),
                    CodeNodeKind::Function,
                    "mod.rs",
                    "rust",
                )
            })
            .collect();
        index.index_nodes(&nodes).unwrap();
        // The result count must respect the requested limit.
        assert!(index.search("func", 5).unwrap().len() <= 5);
    }

    #[test]
    fn test_search_result_has_score() {
        let index = SymbolIndex::new().unwrap();
        index
            .index_nodes(&[make_node(
                "src/lib.rs::compute",
                "compute",
                CodeNodeKind::Function,
                "src/lib.rs",
                "rust",
            )])
            .unwrap();
        let hits = index.search("compute", 10).unwrap();
        assert!(hits.first().expect("hit expected").score > 0.0);
    }

    #[test]
    fn test_search_result_fields() {
        let index = SymbolIndex::new().unwrap();
        index
            .index_nodes(&[make_node(
                "src/app.py::MyClass",
                "MyClass",
                CodeNodeKind::Class,
                "src/app.py",
                "python",
            )])
            .unwrap();
        let hits = index.search("MyClass", 10).unwrap();
        assert_eq!(hits.len(), 1);
        let hit = &hits[0];
        assert_eq!(hit.name, "MyClass");
        assert_eq!(hit.kind, "class");
        assert_eq!(hit.file_path, "src/app.py");
        assert_eq!(hit.language, "python");
    }

    #[test]
    fn test_search_empty_query() {
        let index = SymbolIndex::new().unwrap();
        index
            .index_nodes(&[make_node(
                "src/lib.rs::foo",
                "foo",
                CodeNodeKind::Function,
                "src/lib.rs",
                "rust",
            )])
            .unwrap();
        // An empty query may error or return nothing; only require no panic.
        let _ = index.search("", 10);
    }
}

View File

@@ -0,0 +1,4 @@
// Tests for language parsers (Rust, TypeScript, JavaScript, Python).
//
// Test AST parsing, symbol extraction, and dependency graph construction
// using fixture source files.

View File

@@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 {
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cap_limit_default() {
        assert_eq!(cap_limit(None), DEFAULT_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_high() {
        // Values above the cap collapse to MAX_LIMIT.
        assert_eq!(cap_limit(Some(300)), MAX_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_low() {
        // Non-positive limits are raised to the minimum of 1.
        assert_eq!(cap_limit(Some(0)), 1);
    }

    #[test]
    fn list_dast_findings_params_deserialize() {
        let params: ListDastFindingsParams = serde_json::from_value(serde_json::json!({
            "target_id": "t1",
            "scan_run_id": "sr1",
            "severity": "critical",
            "exploitable": true,
            "vuln_type": "sql_injection",
            "limit": 10
        }))
        .unwrap();
        assert_eq!(params.target_id.as_deref(), Some("t1"));
        assert_eq!(params.scan_run_id.as_deref(), Some("sr1"));
        assert_eq!(params.severity.as_deref(), Some("critical"));
        assert_eq!(params.exploitable, Some(true));
        assert_eq!(params.vuln_type.as_deref(), Some("sql_injection"));
        assert_eq!(params.limit, Some(10));
    }

    #[test]
    fn list_dast_findings_params_all_optional() {
        // An empty payload leaves every filter unset.
        let params: ListDastFindingsParams = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(params.target_id.is_none());
        assert!(params.scan_run_id.is_none());
        assert!(params.severity.is_none());
        assert!(params.exploitable.is_none());
        assert!(params.vuln_type.is_none());
        assert!(params.limit.is_none());
    }

    #[test]
    fn dast_scan_summary_params_deserialize() {
        let with_target: DastScanSummaryParams =
            serde_json::from_value(serde_json::json!({ "target_id": "abc" })).unwrap();
        assert_eq!(with_target.target_id.as_deref(), Some("abc"));
        let without: DastScanSummaryParams =
            serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(without.target_id.is_none());
    }
}
#[derive(Debug, Deserialize, JsonSchema)]
pub struct ListDastFindingsParams {
/// Filter by DAST target ID

View File

@@ -12,6 +12,89 @@ fn cap_limit(limit: Option<i64>) -> i64 {
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cap_limit_default() {
        assert_eq!(cap_limit(None), DEFAULT_LIMIT);
    }

    #[test]
    fn cap_limit_normal_value() {
        // In-range values pass through unchanged.
        assert_eq!(cap_limit(Some(100)), 100);
    }

    #[test]
    fn cap_limit_exceeds_max() {
        for over in [500_i64, 201].iter().copied() {
            assert_eq!(cap_limit(Some(over)), MAX_LIMIT);
        }
    }

    #[test]
    fn cap_limit_zero_clamped_to_one() {
        assert_eq!(cap_limit(Some(0)), 1);
    }

    #[test]
    fn cap_limit_negative_clamped_to_one() {
        assert_eq!(cap_limit(Some(-10)), 1);
    }

    #[test]
    fn cap_limit_boundary_values() {
        // The extremes of the valid range are preserved exactly.
        assert_eq!(cap_limit(Some(1)), 1);
        assert_eq!(cap_limit(Some(MAX_LIMIT)), MAX_LIMIT);
    }

    #[test]
    fn list_findings_params_deserialize() {
        let params: ListFindingsParams = serde_json::from_value(serde_json::json!({
            "repo_id": "abc",
            "severity": "high",
            "status": "open",
            "scan_type": "sast",
            "limit": 25
        }))
        .unwrap();
        assert_eq!(params.repo_id.as_deref(), Some("abc"));
        assert_eq!(params.severity.as_deref(), Some("high"));
        assert_eq!(params.status.as_deref(), Some("open"));
        assert_eq!(params.scan_type.as_deref(), Some("sast"));
        assert_eq!(params.limit, Some(25));
    }

    #[test]
    fn list_findings_params_all_optional() {
        // Every field defaults to None when absent from the payload.
        let params: ListFindingsParams = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(params.repo_id.is_none());
        assert!(params.severity.is_none());
        assert!(params.status.is_none());
        assert!(params.scan_type.is_none());
        assert!(params.limit.is_none());
    }

    #[test]
    fn get_finding_params_deserialize() {
        let params: GetFindingParams =
            serde_json::from_value(serde_json::json!({ "id": "507f1f77bcf86cd799439011" }))
                .unwrap();
        assert_eq!(params.id, "507f1f77bcf86cd799439011");
    }

    #[test]
    fn findings_summary_params_deserialize() {
        let with_repo: FindingsSummaryParams =
            serde_json::from_value(serde_json::json!({ "repo_id": "r1" })).unwrap();
        assert_eq!(with_repo.repo_id.as_deref(), Some("r1"));
        let without: FindingsSummaryParams =
            serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(without.repo_id.is_none());
    }
}
#[derive(Debug, Deserialize, JsonSchema)]
pub struct ListFindingsParams {
/// Filter by repository ID

View File

@@ -12,6 +12,90 @@ fn cap_limit(limit: Option<i64>) -> i64 {
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cap_limit_default() {
        assert_eq!(cap_limit(None), DEFAULT_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_high() {
        assert_eq!(cap_limit(Some(1000)), MAX_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_low() {
        // Negative and zero limits are raised to the minimum of 1.
        for low in [-100_i64, 0].iter().copied() {
            assert_eq!(cap_limit(Some(low)), 1);
        }
    }

    #[test]
    fn cap_limit_normal() {
        assert_eq!(cap_limit(Some(42)), 42);
    }

    #[test]
    fn list_pentest_sessions_params_deserialize() {
        let params: ListPentestSessionsParams = serde_json::from_value(serde_json::json!({
            "target_id": "tgt",
            "status": "running",
            "strategy": "aggressive",
            "limit": 5
        }))
        .unwrap();
        assert_eq!(params.target_id.as_deref(), Some("tgt"));
        assert_eq!(params.status.as_deref(), Some("running"));
        assert_eq!(params.strategy.as_deref(), Some("aggressive"));
        assert_eq!(params.limit, Some(5));
    }

    #[test]
    fn list_pentest_sessions_params_all_optional() {
        // No fields supplied: every filter stays unset.
        let params: ListPentestSessionsParams =
            serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(params.target_id.is_none());
        assert!(params.status.is_none());
        assert!(params.strategy.is_none());
        assert!(params.limit.is_none());
    }

    #[test]
    fn get_pentest_session_params_deserialize() {
        let params: GetPentestSessionParams =
            serde_json::from_value(serde_json::json!({ "id": "abc123" })).unwrap();
        assert_eq!(params.id, "abc123");
    }

    #[test]
    fn get_attack_chain_params_deserialize() {
        let params: GetAttackChainParams =
            serde_json::from_value(serde_json::json!({ "session_id": "s1", "limit": 20 }))
                .unwrap();
        assert_eq!(params.session_id, "s1");
        assert_eq!(params.limit, Some(20));
    }

    #[test]
    fn get_pentest_messages_params_deserialize() {
        // limit omitted: it should remain unset.
        let params: GetPentestMessagesParams =
            serde_json::from_value(serde_json::json!({ "session_id": "s2" })).unwrap();
        assert_eq!(params.session_id, "s2");
        assert!(params.limit.is_none());
    }

    #[test]
    fn pentest_stats_params_deserialize() {
        let with_target: PentestStatsParams =
            serde_json::from_value(serde_json::json!({ "target_id": "t1" })).unwrap();
        assert_eq!(with_target.target_id.as_deref(), Some("t1"));
        let without: PentestStatsParams = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(without.target_id.is_none());
    }
}
// ── List Pentest Sessions ──────────────────────────────────────
#[derive(Debug, Deserialize, JsonSchema)]

View File

@@ -12,6 +12,66 @@ fn cap_limit(limit: Option<i64>) -> i64 {
limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cap_limit_default() {
        assert_eq!(cap_limit(None), DEFAULT_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_high() {
        assert_eq!(cap_limit(Some(999)), MAX_LIMIT);
    }

    #[test]
    fn cap_limit_clamps_low() {
        // Zero and negative limits are raised to the minimum of 1.
        for low in [0_i64, -5].iter().copied() {
            assert_eq!(cap_limit(Some(low)), 1);
        }
    }

    #[test]
    fn cap_limit_normal() {
        assert_eq!(cap_limit(Some(75)), 75);
    }

    #[test]
    fn list_sbom_params_deserialize() {
        let params: ListSbomPackagesParams = serde_json::from_value(serde_json::json!({
            "repo_id": "repo1",
            "has_vulns": true,
            "package_manager": "npm",
            "license": "MIT",
            "limit": 30
        }))
        .unwrap();
        assert_eq!(params.repo_id.as_deref(), Some("repo1"));
        assert_eq!(params.has_vulns, Some(true));
        assert_eq!(params.package_manager.as_deref(), Some("npm"));
        assert_eq!(params.license.as_deref(), Some("MIT"));
        assert_eq!(params.limit, Some(30));
    }

    #[test]
    fn list_sbom_params_all_optional() {
        // An empty payload leaves every filter unset.
        let params: ListSbomPackagesParams = serde_json::from_value(serde_json::json!({})).unwrap();
        assert!(params.repo_id.is_none());
        assert!(params.has_vulns.is_none());
        assert!(params.package_manager.is_none());
        assert!(params.license.is_none());
        assert!(params.limit.is_none());
    }

    #[test]
    fn sbom_vuln_report_params_deserialize() {
        // repo_id is required here (not an Option).
        let params: SbomVulnReportParams =
            serde_json::from_value(serde_json::json!({ "repo_id": "my-repo" })).unwrap();
        assert_eq!(params.repo_id, "my-repo");
    }
}
#[derive(Debug, Deserialize, JsonSchema)]
pub struct ListSbomPackagesParams {
/// Filter by repository ID

View File

@@ -0,0 +1,4 @@
// Tests for MCP tool implementations.
//
// Test tool request/response formats, parameter validation,
// and database query construction.

16
fuzz/Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
# Cargo manifest for the libFuzzer-based fuzzing harness crate.
[package]
name = "compliance-fuzz"
version = "0.0.0"
# Never publish the fuzz harness to a registry.
publish = false
edition = "2021"
[dependencies]
libfuzzer-sys = "0.4"
compliance-core = { path = "../compliance-core" }
# Fuzz targets are defined below. Add new targets as [[bin]] entries.
[[bin]]
name = "fuzz_finding_dedup"
path = "fuzz_targets/fuzz_finding_dedup.rs"
# Skip rustdoc for the fuzz binary.
doc = false

View File

@@ -0,0 +1,12 @@
#![no_main]
use libfuzzer_sys::fuzz_target;
// Example fuzz target stub for finding deduplication logic.
// Replace with actual dedup function calls once ready.
//
// The fuzz_target! closure is invoked once per generated input and must
// tolerate arbitrary bytes without panicking.
fuzz_target!(|data: &[u8]| {
    if let Ok(s) = std::str::from_utf8(data) {
        // TODO: Call dedup/fingerprint functions with fuzzed input
        // Only inputs that are valid UTF-8 reach this branch; others are ignored.
        let _ = s.len();
    }
});