refactor: modularize codebase and add 404 unit tests
Some checks failed
CI / Format (push) Failing after 4s
CI / Format (pull_request) Failing after 4s
CI / Clippy (pull_request) Failing after 1m41s
CI / Security Audit (pull_request) Has been skipped
CI / Tests (pull_request) Has been skipped
CI / Clippy (push) Failing after 1m46s
CI / Security Audit (push) Has been skipped
CI / Tests (push) Has been skipped
CI / Detect Changes (push) Has been skipped
CI / Detect Changes (pull_request) Has been skipped
CI / Deploy Agent (push) Has been skipped
CI / Deploy Dashboard (push) Has been skipped
CI / Deploy Docs (push) Has been skipped
CI / Deploy MCP (push) Has been skipped
CI / Deploy Agent (pull_request) Has been skipped
CI / Deploy Dashboard (pull_request) Has been skipped
CI / Deploy Docs (pull_request) Has been skipped
CI / Deploy MCP (pull_request) Has been skipped

Split large files into focused modules across all crates while
maintaining API compatibility via re-exports. Add comprehensive
unit tests covering core models, pipeline parsers, LLM triage,
DAST security tools, graph algorithms, and MCP parameter validation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sharang Parnerkar
2026-03-12 16:59:05 +01:00
parent acc5b86aa4
commit 4e95fd7016
89 changed files with 11855 additions and 6032 deletions

View File

@@ -0,0 +1,481 @@
use compliance_core::models::TrackerType;
use serde::{Deserialize, Serialize};
use compliance_core::models::ScanRun;
#[derive(Deserialize)]
pub struct PaginationParams {
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
pub(crate) fn default_page() -> u64 {
1
}
pub(crate) fn default_limit() -> i64 {
50
}
#[derive(Deserialize)]
pub struct FindingsFilter {
#[serde(default)]
pub repo_id: Option<String>,
#[serde(default)]
pub severity: Option<String>,
#[serde(default)]
pub scan_type: Option<String>,
#[serde(default)]
pub status: Option<String>,
#[serde(default)]
pub q: Option<String>,
#[serde(default)]
pub sort_by: Option<String>,
#[serde(default)]
pub sort_order: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
#[derive(Serialize)]
pub struct ApiResponse<T: Serialize> {
pub data: T,
#[serde(skip_serializing_if = "Option::is_none")]
pub total: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub page: Option<u64>,
}
#[derive(Serialize)]
pub struct OverviewStats {
pub total_repositories: u64,
pub total_findings: u64,
pub critical_findings: u64,
pub high_findings: u64,
pub medium_findings: u64,
pub low_findings: u64,
pub total_sbom_entries: u64,
pub total_cve_alerts: u64,
pub total_issues: u64,
pub recent_scans: Vec<ScanRun>,
}
#[derive(Deserialize)]
pub struct AddRepositoryRequest {
pub name: String,
pub git_url: String,
#[serde(default = "default_branch")]
pub default_branch: String,
pub auth_token: Option<String>,
pub auth_username: Option<String>,
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
pub scan_schedule: Option<String>,
}
#[derive(Deserialize)]
pub struct UpdateRepositoryRequest {
pub name: Option<String>,
pub default_branch: Option<String>,
pub auth_token: Option<String>,
pub auth_username: Option<String>,
pub tracker_type: Option<TrackerType>,
pub tracker_owner: Option<String>,
pub tracker_repo: Option<String>,
pub tracker_token: Option<String>,
pub scan_schedule: Option<String>,
}
fn default_branch() -> String {
"main".to_string()
}
#[derive(Deserialize)]
pub struct UpdateStatusRequest {
pub status: String,
}
#[derive(Deserialize)]
pub struct BulkUpdateStatusRequest {
pub ids: Vec<String>,
pub status: String,
}
#[derive(Deserialize)]
pub struct UpdateFeedbackRequest {
pub feedback: String,
}
#[derive(Deserialize)]
pub struct SbomFilter {
#[serde(default)]
pub repo_id: Option<String>,
#[serde(default)]
pub package_manager: Option<String>,
#[serde(default)]
pub q: Option<String>,
#[serde(default)]
pub has_vulns: Option<bool>,
#[serde(default)]
pub license: Option<String>,
#[serde(default = "default_page")]
pub page: u64,
#[serde(default = "default_limit")]
pub limit: i64,
}
#[derive(Deserialize)]
pub struct SbomExportParams {
pub repo_id: String,
#[serde(default = "default_export_format")]
pub format: String,
}
fn default_export_format() -> String {
"cyclonedx".to_string()
}
#[derive(Deserialize)]
pub struct SbomDiffParams {
pub repo_a: String,
pub repo_b: String,
}
#[derive(Serialize)]
pub struct LicenseSummary {
pub license: String,
pub count: u64,
pub is_copyleft: bool,
pub packages: Vec<String>,
}
#[derive(Serialize)]
pub struct SbomDiffResult {
pub only_in_a: Vec<SbomDiffEntry>,
pub only_in_b: Vec<SbomDiffEntry>,
pub version_changed: Vec<SbomVersionDiff>,
pub common_count: u64,
}
#[derive(Serialize)]
pub struct SbomDiffEntry {
pub name: String,
pub version: String,
pub package_manager: String,
}
#[derive(Serialize)]
pub struct SbomVersionDiff {
pub name: String,
pub package_manager: String,
pub version_a: String,
pub version_b: String,
}
pub(crate) type AgentExt = axum::extract::Extension<std::sync::Arc<crate::agent::ComplianceAgent>>;
pub(crate) type ApiResult<T> = Result<axum::Json<ApiResponse<T>>, axum::http::StatusCode>;
pub(crate) async fn collect_cursor_async<T: serde::de::DeserializeOwned + Unpin + Send>(
mut cursor: mongodb::Cursor<T>,
) -> Vec<T> {
use futures_util::StreamExt;
let mut items = Vec::new();
while let Some(result) = cursor.next().await {
match result {
Ok(item) => items.push(item),
Err(e) => tracing::warn!("Failed to deserialize document: {e}"),
}
}
items
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ── PaginationParams ─────────────────────────────────────────
#[test]
fn pagination_params_defaults() {
let p: PaginationParams = serde_json::from_str("{}").unwrap();
assert_eq!(p.page, 1);
assert_eq!(p.limit, 50);
}
#[test]
fn pagination_params_custom_values() {
let p: PaginationParams = serde_json::from_str(r#"{"page":3,"limit":10}"#).unwrap();
assert_eq!(p.page, 3);
assert_eq!(p.limit, 10);
}
#[test]
fn pagination_params_partial_override() {
let p: PaginationParams = serde_json::from_str(r#"{"page":5}"#).unwrap();
assert_eq!(p.page, 5);
assert_eq!(p.limit, 50);
}
#[test]
fn pagination_params_zero_page() {
let p: PaginationParams = serde_json::from_str(r#"{"page":0}"#).unwrap();
assert_eq!(p.page, 0);
}
// ── FindingsFilter ───────────────────────────────────────────
#[test]
fn findings_filter_all_defaults() {
let f: FindingsFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.severity.is_none());
assert!(f.scan_type.is_none());
assert!(f.status.is_none());
assert!(f.q.is_none());
assert!(f.sort_by.is_none());
assert!(f.sort_order.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn findings_filter_with_all_fields() {
let f: FindingsFilter = serde_json::from_str(
r#"{
"repo_id": "abc",
"severity": "high",
"scan_type": "sast",
"status": "open",
"q": "sql injection",
"sort_by": "severity",
"sort_order": "desc",
"page": 2,
"limit": 25
}"#,
)
.unwrap();
assert_eq!(f.repo_id.as_deref(), Some("abc"));
assert_eq!(f.severity.as_deref(), Some("high"));
assert_eq!(f.scan_type.as_deref(), Some("sast"));
assert_eq!(f.status.as_deref(), Some("open"));
assert_eq!(f.q.as_deref(), Some("sql injection"));
assert_eq!(f.sort_by.as_deref(), Some("severity"));
assert_eq!(f.sort_order.as_deref(), Some("desc"));
assert_eq!(f.page, 2);
assert_eq!(f.limit, 25);
}
#[test]
fn findings_filter_empty_string_fields() {
let f: FindingsFilter = serde_json::from_str(r#"{"repo_id":"","severity":""}"#).unwrap();
assert_eq!(f.repo_id.as_deref(), Some(""));
assert_eq!(f.severity.as_deref(), Some(""));
}
// ── ApiResponse ──────────────────────────────────────────────
#[test]
fn api_response_serializes_with_all_fields() {
let resp = ApiResponse {
data: vec!["a", "b"],
total: Some(100),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], json!(["a", "b"]));
assert_eq!(v["total"], 100);
assert_eq!(v["page"], 1);
}
#[test]
fn api_response_skips_none_fields() {
let resp = ApiResponse {
data: "hello",
total: None,
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"], "hello");
assert!(v.get("total").is_none());
assert!(v.get("page").is_none());
}
#[test]
fn api_response_with_nested_struct() {
#[derive(Serialize)]
struct Item {
id: u32,
}
let resp = ApiResponse {
data: Item { id: 42 },
total: Some(1),
page: None,
};
let v = serde_json::to_value(&resp).unwrap();
assert_eq!(v["data"]["id"], 42);
assert_eq!(v["total"], 1);
assert!(v.get("page").is_none());
}
#[test]
fn api_response_empty_vec() {
let resp: ApiResponse<Vec<String>> = ApiResponse {
data: vec![],
total: Some(0),
page: Some(1),
};
let v = serde_json::to_value(&resp).unwrap();
assert!(v["data"].as_array().unwrap().is_empty());
}
// ── SbomFilter ───────────────────────────────────────────────
#[test]
fn sbom_filter_defaults() {
let f: SbomFilter = serde_json::from_str("{}").unwrap();
assert!(f.repo_id.is_none());
assert!(f.package_manager.is_none());
assert!(f.q.is_none());
assert!(f.has_vulns.is_none());
assert!(f.license.is_none());
assert_eq!(f.page, 1);
assert_eq!(f.limit, 50);
}
#[test]
fn sbom_filter_has_vulns_bool() {
let f: SbomFilter = serde_json::from_str(r#"{"has_vulns": true}"#).unwrap();
assert_eq!(f.has_vulns, Some(true));
}
// ── SbomExportParams ─────────────────────────────────────────
#[test]
fn sbom_export_params_default_format() {
let p: SbomExportParams = serde_json::from_str(r#"{"repo_id":"r1"}"#).unwrap();
assert_eq!(p.repo_id, "r1");
assert_eq!(p.format, "cyclonedx");
}
#[test]
fn sbom_export_params_custom_format() {
let p: SbomExportParams =
serde_json::from_str(r#"{"repo_id":"r1","format":"spdx"}"#).unwrap();
assert_eq!(p.format, "spdx");
}
// ── AddRepositoryRequest ─────────────────────────────────────
#[test]
fn add_repository_request_defaults() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "my-repo",
"git_url": "https://github.com/x/y.git"
}"#,
)
.unwrap();
assert_eq!(r.name, "my-repo");
assert_eq!(r.git_url, "https://github.com/x/y.git");
assert_eq!(r.default_branch, "main");
assert!(r.auth_token.is_none());
assert!(r.tracker_type.is_none());
assert!(r.scan_schedule.is_none());
}
#[test]
fn add_repository_request_custom_branch() {
let r: AddRepositoryRequest = serde_json::from_str(
r#"{
"name": "repo",
"git_url": "url",
"default_branch": "develop"
}"#,
)
.unwrap();
assert_eq!(r.default_branch, "develop");
}
// ── UpdateStatusRequest / BulkUpdateStatusRequest ────────────
#[test]
fn update_status_request() {
let r: UpdateStatusRequest = serde_json::from_str(r#"{"status":"resolved"}"#).unwrap();
assert_eq!(r.status, "resolved");
}
#[test]
fn bulk_update_status_request() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":["a","b"],"status":"dismissed"}"#).unwrap();
assert_eq!(r.ids, vec!["a", "b"]);
assert_eq!(r.status, "dismissed");
}
#[test]
fn bulk_update_status_empty_ids() {
let r: BulkUpdateStatusRequest =
serde_json::from_str(r#"{"ids":[],"status":"x"}"#).unwrap();
assert!(r.ids.is_empty());
}
// ── SbomDiffResult serialization ─────────────────────────────
#[test]
fn sbom_diff_result_serializes() {
let r = SbomDiffResult {
only_in_a: vec![SbomDiffEntry {
name: "pkg-a".to_string(),
version: "1.0".to_string(),
package_manager: "npm".to_string(),
}],
only_in_b: vec![],
version_changed: vec![SbomVersionDiff {
name: "shared".to_string(),
package_manager: "cargo".to_string(),
version_a: "0.1".to_string(),
version_b: "0.2".to_string(),
}],
common_count: 10,
};
let v = serde_json::to_value(&r).unwrap();
assert_eq!(v["only_in_a"].as_array().unwrap().len(), 1);
assert_eq!(v["only_in_b"].as_array().unwrap().len(), 0);
assert_eq!(v["version_changed"][0]["version_a"], "0.1");
assert_eq!(v["common_count"], 10);
}
// ── LicenseSummary ───────────────────────────────────────────
#[test]
fn license_summary_serializes() {
let ls = LicenseSummary {
license: "MIT".to_string(),
count: 42,
is_copyleft: false,
packages: vec!["serde".to_string()],
};
let v = serde_json::to_value(&ls).unwrap();
assert_eq!(v["license"], "MIT");
assert_eq!(v["is_copyleft"], false);
assert_eq!(v["count"], 42);
}
// ── Default helper functions ─────────────────────────────────
#[test]
fn default_page_returns_1() {
assert_eq!(default_page(), 1);
}
#[test]
fn default_limit_returns_50() {
assert_eq!(default_limit(), 50);
}
}

View File

@@ -0,0 +1,172 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::Finding;
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, severity = ?filter.severity, scan_type = ?filter.scan_type))]
pub async fn list_findings(
Extension(agent): AgentExt,
Query(filter): Query<FindingsFilter>,
) -> ApiResult<Vec<Finding>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(severity) = &filter.severity {
query.insert("severity", severity);
}
if let Some(scan_type) = &filter.scan_type {
query.insert("scan_type", scan_type);
}
if let Some(status) = &filter.status {
query.insert("status", status);
}
// Text search across title, description, file_path, rule_id
if let Some(q) = &filter.q {
if !q.is_empty() {
let regex = doc! { "$regex": q, "$options": "i" };
query.insert(
"$or",
mongodb::bson::bson!([
{ "title": regex.clone() },
{ "description": regex.clone() },
{ "file_path": regex.clone() },
{ "rule_id": regex },
]),
);
}
}
// Dynamic sort
let sort_field = filter.sort_by.as_deref().unwrap_or("created_at");
let sort_dir: i32 = match filter.sort_order.as_deref() {
Some("asc") => 1,
_ => -1,
};
let sort_doc = doc! { sort_field: sort_dir };
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.findings()
.count_documents(query.clone())
.await
.unwrap_or(0);
let findings = match db
.findings()
.find(query)
.sort(sort_doc)
.skip(skip)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch findings: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: findings,
total: Some(total),
page: Some(filter.page),
}))
}
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn get_finding(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Finding>>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let finding = agent
.db
.findings()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
Ok(Json(ApiResponse {
data: finding,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all, fields(finding_id = %id))]
pub async fn update_finding_status(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}
#[tracing::instrument(skip_all)]
pub async fn bulk_update_finding_status(
Extension(agent): AgentExt,
Json(req): Json<BulkUpdateStatusRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oids: Vec<mongodb::bson::oid::ObjectId> = req
.ids
.iter()
.filter_map(|id| mongodb::bson::oid::ObjectId::parse_str(id).ok())
.collect();
if oids.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
let result = agent
.db
.findings()
.update_many(
doc! { "_id": { "$in": oids } },
doc! { "$set": { "status": &req.status, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(
serde_json::json!({ "status": "updated", "modified_count": result.modified_count }),
))
}
#[tracing::instrument(skip_all)]
pub async fn update_finding_feedback(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateFeedbackRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
agent
.db
.findings()
.update_one(
doc! { "_id": oid },
doc! { "$set": { "developer_feedback": &req.feedback, "updated_at": mongodb::bson::DateTime::now() } },
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(serde_json::json!({ "status": "updated" })))
}

View File

@@ -0,0 +1,84 @@
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
#[tracing::instrument(skip_all)]
pub async fn health() -> Json<serde_json::Value> {
Json(serde_json::json!({ "status": "ok" }))
}
#[tracing::instrument(skip_all)]
pub async fn stats_overview(axum::extract::Extension(agent): AgentExt) -> ApiResult<OverviewStats> {
let db = &agent.db;
let total_repositories = db
.repositories()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_findings = db.findings().count_documents(doc! {}).await.unwrap_or(0);
let critical_findings = db
.findings()
.count_documents(doc! { "severity": "critical" })
.await
.unwrap_or(0);
let high_findings = db
.findings()
.count_documents(doc! { "severity": "high" })
.await
.unwrap_or(0);
let medium_findings = db
.findings()
.count_documents(doc! { "severity": "medium" })
.await
.unwrap_or(0);
let low_findings = db
.findings()
.count_documents(doc! { "severity": "low" })
.await
.unwrap_or(0);
let total_sbom_entries = db
.sbom_entries()
.count_documents(doc! {})
.await
.unwrap_or(0);
let total_cve_alerts = db.cve_alerts().count_documents(doc! {}).await.unwrap_or(0);
let total_issues = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
let recent_scans: Vec<ScanRun> = match db
.scan_runs()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.limit(10)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch recent scans: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: OverviewStats {
total_repositories,
total_findings,
critical_findings,
high_findings,
medium_findings,
low_findings,
total_sbom_entries,
total_cve_alerts,
total_issues,
recent_scans,
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,41 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::TrackerIssue;
#[tracing::instrument(skip_all)]
pub async fn list_issues(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackerIssue>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.tracker_issues()
.count_documents(doc! {})
.await
.unwrap_or(0);
let issues = match db
.tracker_issues()
.find(doc! {})
.sort(doc! { "created_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch tracker issues: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: issues,
total: Some(total),
page: Some(params.page),
}))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,131 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
#[derive(Deserialize)]
pub struct ExportBody {
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
if body.password.len() < 8 {
return Err((
StatusCode::BAD_REQUEST,
"Password must be at least 8 characters".to_string(),
));
}
// Fetch session
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
.flatten()
} else {
None
};
let target_name = target
.as_ref()
.map(|t| t.name.clone())
.unwrap_or_else(|| "Unknown Target".to_string());
let target_url = target
.as_ref()
.map(|t| t.base_url.clone())
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch DAST findings for this session
let findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let ctx = crate::pentest::report::ReportContext {
session,
target_name,
target_url,
findings,
attack_chain: nodes,
requester_name: if body.requester_name.is_empty() {
"Unknown".to_string()
} else {
body.requester_name
},
requester_email: body.requester_email,
};
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
let response = serde_json::json!({
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
"sha256": report.sha256,
"filename": format!("pentest-report-{id}.zip"),
});
Ok(Json(response).into_response())
}

View File

@@ -0,0 +1,9 @@
mod export;
mod session;
mod stats;
mod stream;
pub use export::*;
pub use session::*;
pub use stats::*;
pub use stream::*;

View File

@@ -2,20 +2,16 @@ use std::sync::Arc;
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use axum::response::IntoResponse;
use axum::Json;
use futures_util::stream;
use mongodb::bson::doc;
use serde::Deserialize;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use crate::pentest::PentestOrchestrator;
use super::{collect_cursor_async, ApiResponse, PaginationParams};
use super::super::dto::{collect_cursor_async, ApiResponse, PaginationParams};
type AgentExt = Extension<Arc<ComplianceAgent>>;
@@ -160,8 +156,7 @@ pub async fn get_session(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<PentestSession>>, StatusCode> {
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let session = agent
.db
@@ -210,13 +205,12 @@ pub async fn send_message(
}
// Look up the target
let target_oid =
mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target_oid = mongodb::bson::oid::ObjectId::parse_str(&session.target_id).map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Invalid target_id in session".to_string(),
)
})?;
let target = agent
.db
@@ -261,106 +255,6 @@ pub async fn send_message(
}))
}
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>, StatusCode>
{
let oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}
/// POST /api/v1/pentest/sessions/:id/stop — Stop a running pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn stop_session(
@@ -375,7 +269,12 @@ pub async fn stop_session(
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
if session.status != PentestStatus::Running {
@@ -397,15 +296,30 @@ pub async fn stop_session(
}},
)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?;
let updated = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found after update".to_string()))?;
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Database error: {e}"),
)
})?
.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
"Session not found after update".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: updated,
@@ -420,9 +334,7 @@ pub async fn get_attack_chain(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<ApiResponse<Vec<AttackChainNode>>>, StatusCode> {
// Verify the session ID is valid
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let nodes = match agent
.db
@@ -453,8 +365,7 @@ pub async fn get_messages(
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<PentestMessage>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -487,95 +398,14 @@ pub async fn get_messages(
}))
}
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let critical = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" })
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" })
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" })
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" })
.await
.unwrap_or(0) as u32;
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}
/// GET /api/v1/pentest/sessions/:id/findings — Get DAST findings for a pentest session
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn get_session_findings(
Extension(agent): AgentExt,
Path(id): Path<String>,
Query(params): Query<PaginationParams>,
) -> Result<Json<ApiResponse<Vec<DastFinding>>>, StatusCode> {
let _oid =
mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
) -> Result<Json<ApiResponse<Vec<compliance_core::models::dast::DastFinding>>>, StatusCode> {
let _oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = agent
@@ -607,112 +437,3 @@ pub async fn get_session_findings(
page: Some(params.page),
}))
}
#[derive(Deserialize)]
pub struct ExportBody {
pub password: String,
/// Requester display name (from auth)
#[serde(default)]
pub requester_name: String,
/// Requester email (from auth)
#[serde(default)]
pub requester_email: String,
}
/// POST /api/v1/pentest/sessions/:id/export — Export an encrypted pentest report archive
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn export_session_report(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(body): Json<ExportBody>,
) -> Result<axum::response::Response, (StatusCode, String)> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID".to_string()))?;
if body.password.len() < 8 {
return Err((
StatusCode::BAD_REQUEST,
"Password must be at least 8 characters".to_string(),
));
}
// Fetch session
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Database error: {e}")))?
.ok_or_else(|| (StatusCode::NOT_FOUND, "Session not found".to_string()))?;
// Resolve target name
let target = if let Ok(tid) = mongodb::bson::oid::ObjectId::parse_str(&session.target_id) {
agent
.db
.dast_targets()
.find_one(doc! { "_id": tid })
.await
.ok()
.flatten()
} else {
None
};
let target_name = target
.as_ref()
.map(|t| t.name.clone())
.unwrap_or_else(|| "Unknown Target".to_string());
let target_url = target
.as_ref()
.map(|t| t.base_url.clone())
.unwrap_or_default();
// Fetch attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch DAST findings for this session
let findings: Vec<DastFinding> = match agent
.db
.dast_findings()
.find(doc! { "session_id": &id })
.sort(doc! { "severity": -1, "created_at": -1 })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let ctx = crate::pentest::report::ReportContext {
session,
target_name,
target_url,
findings,
attack_chain: nodes,
requester_name: if body.requester_name.is_empty() {
"Unknown".to_string()
} else {
body.requester_name
},
requester_email: body.requester_email,
};
let report = crate::pentest::generate_encrypted_report(&ctx, &body.password)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
let response = serde_json::json!({
"archive_base64": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &report.archive),
"sha256": report.sha256,
"filename": format!("pentest-report-{id}.zip"),
});
Ok(Json(response).into_response())
}

View File

@@ -0,0 +1,102 @@
use std::sync::Arc;
use axum::extract::Extension;
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::{collect_cursor_async, ApiResponse};
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/stats — Aggregated pentest statistics
#[tracing::instrument(skip_all)]
pub async fn pentest_stats(
Extension(agent): AgentExt,
) -> Result<Json<ApiResponse<PentestStats>>, StatusCode> {
let db = &agent.db;
let running_sessions = db
.pentest_sessions()
.count_documents(doc! { "status": "running" })
.await
.unwrap_or(0) as u32;
// Count DAST findings from pentest sessions
let total_vulnerabilities = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null } })
.await
.unwrap_or(0) as u32;
// Aggregate tool invocations from all sessions
let sessions: Vec<PentestSession> = match db.pentest_sessions().find(doc! {}).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
let total_tool_invocations: u32 = sessions.iter().map(|s| s.tool_invocations).sum();
let total_successes: u32 = sessions.iter().map(|s| s.tool_successes).sum();
let tool_success_rate = if total_tool_invocations == 0 {
100.0
} else {
(total_successes as f64 / total_tool_invocations as f64) * 100.0
};
// Severity distribution from pentest-related DAST findings
let critical = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "critical" },
)
.await
.unwrap_or(0) as u32;
let high = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "high" },
)
.await
.unwrap_or(0) as u32;
let medium = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "medium" },
)
.await
.unwrap_or(0) as u32;
let low = db
.dast_findings()
.count_documents(doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "low" })
.await
.unwrap_or(0) as u32;
let info = db
.dast_findings()
.count_documents(
doc! { "session_id": { "$exists": true, "$ne": null }, "severity": "info" },
)
.await
.unwrap_or(0) as u32;
Ok(Json(ApiResponse {
data: PentestStats {
running_sessions,
total_vulnerabilities,
total_tool_invocations,
tool_success_rate,
severity_distribution: SeverityDistribution {
critical,
high,
medium,
low,
info,
},
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,116 @@
use std::sync::Arc;
use axum::extract::{Extension, Path};
use axum::http::StatusCode;
use axum::response::sse::{Event, Sse};
use futures_util::stream;
use mongodb::bson::doc;
use compliance_core::models::pentest::*;
use crate::agent::ComplianceAgent;
use super::super::dto::collect_cursor_async;
type AgentExt = Extension<Arc<ComplianceAgent>>;
/// GET /api/v1/pentest/sessions/:id/stream — SSE endpoint for real-time events
///
/// Returns recent messages as SSE events (polling approach).
/// True real-time streaming with broadcast channels will be added in a future iteration.
#[tracing::instrument(skip_all, fields(session_id = %id))]
pub async fn session_stream(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<
Sse<impl futures_util::Stream<Item = Result<Event, std::convert::Infallible>>>,
StatusCode,
> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
// Verify session exists
let _session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
// Fetch recent messages for this session
let messages: Vec<PentestMessage> = match agent
.db
.pentest_messages()
.find(doc! { "session_id": &id })
.sort(doc! { "created_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Fetch recent attack chain nodes
let nodes: Vec<AttackChainNode> = match agent
.db
.attack_chain_nodes()
.find(doc! { "session_id": &id })
.sort(doc! { "started_at": 1 })
.limit(100)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(_) => Vec::new(),
};
// Build SSE events from stored data
let mut events: Vec<Result<Event, std::convert::Infallible>> = Vec::new();
for msg in &messages {
let event_data = serde_json::json!({
"type": "message",
"role": msg.role,
"content": msg.content,
"created_at": msg.created_at.to_rfc3339(),
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("message").data(data)));
}
}
for node in &nodes {
let event_data = serde_json::json!({
"type": "tool_execution",
"node_id": node.node_id,
"tool_name": node.tool_name,
"status": node.status,
"findings_produced": node.findings_produced,
});
if let Ok(data) = serde_json::to_string(&event_data) {
events.push(Ok(Event::default().event("tool").data(data)));
}
}
// Add session status event
let session = agent
.db
.pentest_sessions()
.find_one(doc! { "_id": oid })
.await
.ok()
.flatten();
if let Some(s) = session {
let status_data = serde_json::json!({
"type": "status",
"status": s.status,
"findings_count": s.findings_count,
"tool_invocations": s.tool_invocations,
});
if let Ok(data) = serde_json::to_string(&status_data) {
events.push(Ok(Event::default().event("status").data(data)));
}
}
Ok(Sse::new(stream::iter(events)))
}

View File

@@ -0,0 +1,241 @@
use axum::extract::{Extension, Path, Query};
use axum::http::StatusCode;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::*;
#[tracing::instrument(skip_all)]
pub async fn list_repositories(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<TrackedRepository>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db
.repositories()
.count_documents(doc! {})
.await
.unwrap_or(0);
let repos = match db
.repositories()
.find(doc! {})
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch repositories: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: repos,
total: Some(total),
page: Some(params.page),
}))
}
#[tracing::instrument(skip_all)]
pub async fn add_repository(
Extension(agent): AgentExt,
Json(req): Json<AddRepositoryRequest>,
) -> Result<Json<ApiResponse<TrackedRepository>>, (StatusCode, String)> {
// Validate repository access before saving
let creds = crate::pipeline::git::RepoCredentials {
ssh_key_path: Some(agent.config.ssh_key_path.clone()),
auth_token: req.auth_token.clone(),
auth_username: req.auth_username.clone(),
};
if let Err(e) = crate::pipeline::git::GitOps::test_access(&req.git_url, &creds) {
return Err((
StatusCode::BAD_REQUEST,
format!("Cannot access repository: {e}"),
));
}
let mut repo = TrackedRepository::new(req.name, req.git_url);
repo.default_branch = req.default_branch;
repo.auth_token = req.auth_token;
repo.auth_username = req.auth_username;
repo.tracker_type = req.tracker_type;
repo.tracker_owner = req.tracker_owner;
repo.tracker_repo = req.tracker_repo;
repo.tracker_token = req.tracker_token;
repo.scan_schedule = req.scan_schedule;
agent
.db
.repositories()
.insert_one(&repo)
.await
.map_err(|_| {
(
StatusCode::CONFLICT,
"Repository already exists".to_string(),
)
})?;
Ok(Json(ApiResponse {
data: repo,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn update_repository(
Extension(agent): AgentExt,
Path(id): Path<String>,
Json(req): Json<UpdateRepositoryRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let mut set_doc = doc! { "updated_at": mongodb::bson::DateTime::now() };
if let Some(name) = &req.name {
set_doc.insert("name", name);
}
if let Some(branch) = &req.default_branch {
set_doc.insert("default_branch", branch);
}
if let Some(token) = &req.auth_token {
set_doc.insert("auth_token", token);
}
if let Some(username) = &req.auth_username {
set_doc.insert("auth_username", username);
}
if let Some(tracker_type) = &req.tracker_type {
set_doc.insert("tracker_type", tracker_type.to_string());
}
if let Some(owner) = &req.tracker_owner {
set_doc.insert("tracker_owner", owner);
}
if let Some(repo) = &req.tracker_repo {
set_doc.insert("tracker_repo", repo);
}
if let Some(token) = &req.tracker_token {
set_doc.insert("tracker_token", token);
}
if let Some(schedule) = &req.scan_schedule {
set_doc.insert("scan_schedule", schedule);
}
let result = agent
.db
.repositories()
.update_one(doc! { "_id": oid }, doc! { "$set": set_doc })
.await
.map_err(|e| {
tracing::warn!("Failed to update repository: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
if result.matched_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
Ok(Json(serde_json::json!({ "status": "updated" })))
}
#[tracing::instrument(skip_all)]
pub async fn get_ssh_public_key(
Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
let public_path = format!("{}.pub", agent.config.ssh_key_path);
let public_key = std::fs::read_to_string(&public_path).map_err(|_| StatusCode::NOT_FOUND)?;
Ok(Json(serde_json::json!({ "public_key": public_key.trim() })))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn trigger_scan(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let agent_clone = (*agent).clone();
tokio::spawn(async move {
if let Err(e) = agent_clone.run_scan(&id, ScanTrigger::Manual).await {
tracing::error!("Manual scan failed for {id}: {e}");
}
});
Ok(Json(serde_json::json!({ "status": "scan_triggered" })))
}
/// Return the webhook secret for a repository (used by dashboard to display it)
pub async fn get_webhook_config(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let repo = agent
.db
.repositories()
.find_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
let tracker_type = repo
.tracker_type
.as_ref()
.map(|t| t.to_string())
.unwrap_or_else(|| "gitea".to_string());
Ok(Json(serde_json::json!({
"webhook_secret": repo.webhook_secret,
"tracker_type": tracker_type,
})))
}
#[tracing::instrument(skip_all, fields(repo_id = %id))]
pub async fn delete_repository(
Extension(agent): AgentExt,
Path(id): Path<String>,
) -> Result<Json<serde_json::Value>, StatusCode> {
let oid = mongodb::bson::oid::ObjectId::parse_str(&id).map_err(|_| StatusCode::BAD_REQUEST)?;
let db = &agent.db;
// Delete the repository
let result = db
.repositories()
.delete_one(doc! { "_id": oid })
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if result.deleted_count == 0 {
return Err(StatusCode::NOT_FOUND);
}
// Cascade delete all related data
let _ = db.findings().delete_many(doc! { "repo_id": &id }).await;
let _ = db.sbom_entries().delete_many(doc! { "repo_id": &id }).await;
let _ = db.scan_runs().delete_many(doc! { "repo_id": &id }).await;
let _ = db.cve_alerts().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.tracker_issues()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db.graph_nodes().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_edges().delete_many(doc! { "repo_id": &id }).await;
let _ = db.graph_builds().delete_many(doc! { "repo_id": &id }).await;
let _ = db
.impact_analyses()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.code_embeddings()
.delete_many(doc! { "repo_id": &id })
.await;
let _ = db
.embedding_builds()
.delete_many(doc! { "repo_id": &id })
.await;
Ok(Json(serde_json::json!({ "status": "deleted" })))
}

View File

@@ -0,0 +1,379 @@
use axum::extract::{Extension, Query};
use axum::http::{header, StatusCode};
use axum::response::IntoResponse;
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::SbomEntry;
const COPYLEFT_LICENSES: &[&str] = &[
"GPL-2.0",
"GPL-2.0-only",
"GPL-2.0-or-later",
"GPL-3.0",
"GPL-3.0-only",
"GPL-3.0-or-later",
"AGPL-3.0",
"AGPL-3.0-only",
"AGPL-3.0-or-later",
"LGPL-2.1",
"LGPL-2.1-only",
"LGPL-2.1-or-later",
"LGPL-3.0",
"LGPL-3.0-only",
"LGPL-3.0-or-later",
"MPL-2.0",
];
#[tracing::instrument(skip_all)]
pub async fn sbom_filters(
Extension(agent): AgentExt,
) -> Result<Json<serde_json::Value>, StatusCode> {
let db = &agent.db;
let managers: Vec<String> = db
.sbom_entries()
.distinct("package_manager", doc! {})
.await
.unwrap_or_default()
.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.filter(|s| !s.is_empty() && s != "unknown" && s != "file")
.collect();
let licenses: Vec<String> = db
.sbom_entries()
.distinct("license", doc! {})
.await
.unwrap_or_default()
.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.filter(|s| !s.is_empty())
.collect();
Ok(Json(serde_json::json!({
"package_managers": managers,
"licenses": licenses,
})))
}
#[tracing::instrument(skip_all, fields(repo_id = ?filter.repo_id, package_manager = ?filter.package_manager))]
pub async fn list_sbom(
Extension(agent): AgentExt,
Query(filter): Query<SbomFilter>,
) -> ApiResult<Vec<SbomEntry>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &filter.repo_id {
query.insert("repo_id", repo_id);
}
if let Some(pm) = &filter.package_manager {
query.insert("package_manager", pm);
}
if let Some(q) = &filter.q {
if !q.is_empty() {
query.insert("name", doc! { "$regex": q, "$options": "i" });
}
}
if let Some(has_vulns) = filter.has_vulns {
if has_vulns {
query.insert("known_vulnerabilities", doc! { "$exists": true, "$ne": [] });
} else {
query.insert("known_vulnerabilities", doc! { "$size": 0 });
}
}
if let Some(license) = &filter.license {
query.insert("license", license);
}
let skip = (filter.page.saturating_sub(1)) * filter.limit as u64;
let total = db
.sbom_entries()
.count_documents(query.clone())
.await
.unwrap_or(0);
let entries = match db
.sbom_entries()
.find(query)
.sort(doc! { "name": 1 })
.skip(skip)
.limit(filter.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: entries,
total: Some(total),
page: Some(filter.page),
}))
}
#[tracing::instrument(skip_all)]
pub async fn export_sbom(
Extension(agent): AgentExt,
Query(params): Query<SbomExportParams>,
) -> Result<impl IntoResponse, StatusCode> {
let db = &agent.db;
let entries: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_id })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for export: {e}");
Vec::new()
}
};
let body = if params.format == "spdx" {
// SPDX 2.3 format
let packages: Vec<serde_json::Value> = entries
.iter()
.enumerate()
.map(|(i, e)| {
serde_json::json!({
"SPDXID": format!("SPDXRef-Package-{i}"),
"name": e.name,
"versionInfo": e.version,
"downloadLocation": "NOASSERTION",
"licenseConcluded": e.license.as_deref().unwrap_or("NOASSERTION"),
"externalRefs": e.purl.as_ref().map(|p| vec![serde_json::json!({
"referenceCategory": "PACKAGE-MANAGER",
"referenceType": "purl",
"referenceLocator": p,
})]).unwrap_or_default(),
})
})
.collect();
serde_json::json!({
"spdxVersion": "SPDX-2.3",
"dataLicense": "CC0-1.0",
"SPDXID": "SPDXRef-DOCUMENT",
"name": format!("sbom-{}", params.repo_id),
"documentNamespace": format!("https://compliance-scanner/sbom/{}", params.repo_id),
"packages": packages,
})
} else {
// CycloneDX 1.5 format
let components: Vec<serde_json::Value> = entries
.iter()
.map(|e| {
let mut comp = serde_json::json!({
"type": "library",
"name": e.name,
"version": e.version,
"group": e.package_manager,
});
if let Some(purl) = &e.purl {
comp["purl"] = serde_json::Value::String(purl.clone());
}
if let Some(license) = &e.license {
comp["licenses"] = serde_json::json!([{ "license": { "id": license } }]);
}
if !e.known_vulnerabilities.is_empty() {
comp["vulnerabilities"] = serde_json::json!(
e.known_vulnerabilities.iter().map(|v| serde_json::json!({
"id": v.id,
"source": { "name": v.source },
"ratings": v.severity.as_ref().map(|s| vec![serde_json::json!({"severity": s})]).unwrap_or_default(),
})).collect::<Vec<_>>()
);
}
comp
})
.collect();
serde_json::json!({
"bomFormat": "CycloneDX",
"specVersion": "1.5",
"version": 1,
"metadata": {
"component": {
"type": "application",
"name": format!("repo-{}", params.repo_id),
}
},
"components": components,
})
};
let json_str =
serde_json::to_string_pretty(&body).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let filename = if params.format == "spdx" {
format!("sbom-{}-spdx.json", params.repo_id)
} else {
format!("sbom-{}-cyclonedx.json", params.repo_id)
};
let disposition = format!("attachment; filename=\"{filename}\"");
Ok((
[
(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
),
(
header::CONTENT_DISPOSITION,
header::HeaderValue::from_str(&disposition)
.unwrap_or_else(|_| header::HeaderValue::from_static("attachment")),
),
],
json_str,
))
}
#[tracing::instrument(skip_all)]
pub async fn license_summary(
Extension(agent): AgentExt,
Query(params): Query<SbomFilter>,
) -> ApiResult<Vec<LicenseSummary>> {
let db = &agent.db;
let mut query = doc! {};
if let Some(repo_id) = &params.repo_id {
query.insert("repo_id", repo_id);
}
let entries: Vec<SbomEntry> = match db.sbom_entries().find(query).await {
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for license summary: {e}");
Vec::new()
}
};
let mut license_map: std::collections::HashMap<String, Vec<String>> =
std::collections::HashMap::new();
for entry in &entries {
let lic = entry.license.as_deref().unwrap_or("Unknown").to_string();
license_map.entry(lic).or_default().push(entry.name.clone());
}
let mut summaries: Vec<LicenseSummary> = license_map
.into_iter()
.map(|(license, packages)| {
let is_copyleft = COPYLEFT_LICENSES
.iter()
.any(|c| license.to_uppercase().contains(&c.to_uppercase()));
LicenseSummary {
license,
count: packages.len() as u64,
is_copyleft,
packages,
}
})
.collect();
summaries.sort_by(|a, b| b.count.cmp(&a.count));
Ok(Json(ApiResponse {
data: summaries,
total: None,
page: None,
}))
}
#[tracing::instrument(skip_all)]
pub async fn sbom_diff(
Extension(agent): AgentExt,
Query(params): Query<SbomDiffParams>,
) -> ApiResult<SbomDiffResult> {
let db = &agent.db;
let entries_a: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_a })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for repo_a: {e}");
Vec::new()
}
};
let entries_b: Vec<SbomEntry> = match db
.sbom_entries()
.find(doc! { "repo_id": &params.repo_b })
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch SBOM entries for repo_b: {e}");
Vec::new()
}
};
// Build maps by (name, package_manager) -> version
let map_a: std::collections::HashMap<(String, String), String> = entries_a
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let map_b: std::collections::HashMap<(String, String), String> = entries_b
.iter()
.map(|e| {
(
(e.name.clone(), e.package_manager.clone()),
e.version.clone(),
)
})
.collect();
let mut only_in_a = Vec::new();
let mut version_changed = Vec::new();
let mut common_count: u64 = 0;
for (key, ver_a) in &map_a {
match map_b.get(key) {
None => only_in_a.push(SbomDiffEntry {
name: key.0.clone(),
version: ver_a.clone(),
package_manager: key.1.clone(),
}),
Some(ver_b) if ver_a != ver_b => {
version_changed.push(SbomVersionDiff {
name: key.0.clone(),
package_manager: key.1.clone(),
version_a: ver_a.clone(),
version_b: ver_b.clone(),
});
}
Some(_) => common_count += 1,
}
}
let only_in_b: Vec<SbomDiffEntry> = map_b
.iter()
.filter(|(key, _)| !map_a.contains_key(key))
.map(|(key, ver)| SbomDiffEntry {
name: key.0.clone(),
version: ver.clone(),
package_manager: key.1.clone(),
})
.collect();
Ok(Json(ApiResponse {
data: SbomDiffResult {
only_in_a,
only_in_b,
version_changed,
common_count,
},
total: None,
page: None,
}))
}

View File

@@ -0,0 +1,37 @@
use axum::extract::{Extension, Query};
use axum::Json;
use mongodb::bson::doc;
use super::dto::*;
use compliance_core::models::ScanRun;
#[tracing::instrument(skip_all)]
pub async fn list_scan_runs(
Extension(agent): AgentExt,
Query(params): Query<PaginationParams>,
) -> ApiResult<Vec<ScanRun>> {
let db = &agent.db;
let skip = (params.page.saturating_sub(1)) * params.limit as u64;
let total = db.scan_runs().count_documents(doc! {}).await.unwrap_or(0);
let scans = match db
.scan_runs()
.find(doc! {})
.sort(doc! { "started_at": -1 })
.skip(skip)
.limit(params.limit)
.await
{
Ok(cursor) => collect_cursor_async(cursor).await,
Err(e) => {
tracing::warn!("Failed to fetch scan runs: {e}");
Vec::new()
}
};
Ok(Json(ApiResponse {
data: scans,
total: Some(total),
page: Some(params.page),
}))
}

View File

@@ -136,7 +136,10 @@ pub fn build_router() -> Router {
"/api/v1/pentest/sessions/{id}/export",
post(handlers::pentest::export_session_report),
)
.route("/api/v1/pentest/stats", get(handlers::pentest::pentest_stats))
.route(
"/api/v1/pentest/stats",
get(handlers::pentest::pentest_stats),
)
// Webhook endpoints (proxied through dashboard)
.route(
"/webhook/github/{repo_id}",

View File

@@ -1,147 +1,17 @@
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize};
use super::types::*;
use crate::error::AgentError;
#[derive(Clone)]
pub struct LlmClient {
base_url: String,
api_key: SecretString,
model: String,
embed_model: String,
http: reqwest::Client,
pub(crate) base_url: String,
pub(crate) api_key: SecretString,
pub(crate) model: String,
pub(crate) embed_model: String,
pub(crate) http: reqwest::Client,
}
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
struct ChatCompletionRequest {
model: String,
messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
struct ToolDefinitionPayload {
r#type: String,
function: ToolFunctionPayload,
}
#[derive(Serialize)]
struct ToolFunctionPayload {
name: String,
description: String,
parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
struct ChatCompletionResponse {
choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
struct ChatChoice {
message: ChatResponseMessage,
}
#[derive(Deserialize)]
struct ChatResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Deserialize)]
struct ToolCallResponse {
id: String,
function: ToolCallFunction,
}
#[derive(Deserialize)]
struct ToolCallFunction {
name: String,
arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
/// Tool calls with optional reasoning text from the LLM
ToolCalls { calls: Vec<LlmToolCall>, reasoning: String },
}
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Implementation ─────────────────────────────────────────────
impl LlmClient {
pub fn new(
base_url: String,
@@ -158,18 +28,14 @@ impl LlmClient {
}
}
pub fn embed_model(&self) -> &str {
&self.embed_model
}
fn chat_url(&self) -> String {
pub(crate) fn chat_url(&self) -> String {
format!(
"{}/v1/chat/completions",
self.base_url.trim_end_matches('/')
)
}
fn auth_header(&self) -> Option<String> {
pub(crate) fn auth_header(&self) -> Option<String> {
let key = self.api_key.expose_secret();
if key.is_empty() {
None
@@ -241,12 +107,12 @@ impl LlmClient {
tools: None,
};
self.send_chat_request(&request_body).await.map(|resp| {
match resp {
self.send_chat_request(&request_body)
.await
.map(|resp| match resp {
LlmResponse::Content(c) => c,
LlmResponse::ToolCalls { .. } => String::new(),
}
})
})
}
/// Chat with tool definitions — returns either content or tool calls.
@@ -345,54 +211,7 @@ impl LlmClient {
}
// Otherwise return content
let content = choice
.message
.content
.clone()
.unwrap_or_default();
let content = choice.message.content.clone().unwrap_or_default();
Ok(LlmResponse::Content(content))
}
/// Generate embeddings for a batch of texts
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: self.embed_model.clone(),
input: texts,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"Embedding API returned {status}: {body}"
)));
}
let body: EmbeddingResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
let mut data = body.data;
data.sort_by_key(|d| d.index);
Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -0,0 +1,74 @@
use serde::{Deserialize, Serialize};
use super::client::LlmClient;
use crate::error::AgentError;
// ── Embedding types ────────────────────────────────────────────
#[derive(Serialize)]
struct EmbeddingRequest {
model: String,
input: Vec<String>,
}
#[derive(Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingData>,
}
#[derive(Deserialize)]
struct EmbeddingData {
embedding: Vec<f64>,
index: usize,
}
// ── Embedding implementation ───────────────────────────────────
impl LlmClient {
pub fn embed_model(&self) -> &str {
&self.embed_model
}
/// Generate embeddings for a batch of texts
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f64>>, AgentError> {
let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
let request_body = EmbeddingRequest {
model: self.embed_model.clone(),
input: texts,
};
let mut req = self
.http
.post(&url)
.header("content-type", "application/json")
.json(&request_body);
if let Some(auth) = self.auth_header() {
req = req.header("Authorization", auth);
}
let resp = req
.send()
.await
.map_err(|e| AgentError::Other(format!("Embedding request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(AgentError::Other(format!(
"Embedding API returned {status}: {body}"
)));
}
let body: EmbeddingResponse = resp
.json()
.await
.map_err(|e| AgentError::Other(format!("Failed to parse embedding response: {e}")))?;
let mut data = body.data;
data.sort_by_key(|d| d.index);
Ok(data.into_iter().map(|d| d.embedding).collect())
}
}

View File

@@ -1,11 +1,16 @@
pub mod client;
#[allow(dead_code)]
pub mod descriptions;
pub mod embedding;
#[allow(dead_code)]
pub mod fixes;
#[allow(dead_code)]
pub mod pr_review;
pub mod review_prompts;
pub mod triage;
pub mod types;
pub use client::LlmClient;
pub use types::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};

View File

@@ -278,3 +278,220 @@ struct TriageResult {
fn default_action() -> String {
"confirm".to_string()
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::Severity;
// ── classify_file_path ───────────────────────────────────────
#[test]
fn classify_none_path() {
assert_eq!(classify_file_path(None), "unknown");
}
#[test]
fn classify_production_path() {
assert_eq!(classify_file_path(Some("src/main.rs")), "production");
assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production");
}
#[test]
fn classify_test_paths() {
assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test");
assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test");
assert_eq!(classify_file_path(Some("foo_test.go")), "test");
assert_eq!(classify_file_path(Some("bar.test.js")), "test");
assert_eq!(classify_file_path(Some("baz.spec.ts")), "test");
assert_eq!(
classify_file_path(Some("data/fixtures/sample.json")),
"test"
);
assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test");
}
#[test]
fn classify_example_paths() {
assert_eq!(
classify_file_path(Some("docs/examples/basic.rs")),
"example"
);
// /example matches because contains("/example")
assert_eq!(classify_file_path(Some("src/example/main.py")), "example");
assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example");
assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example");
}
#[test]
fn classify_generated_paths() {
assert_eq!(
classify_file_path(Some("src/generated/api.rs")),
"generated"
);
assert_eq!(
classify_file_path(Some("proto/gen/service.go")),
"generated"
);
assert_eq!(classify_file_path(Some("api.generated.ts")), "generated");
assert_eq!(classify_file_path(Some("service.pb.go")), "generated");
assert_eq!(classify_file_path(Some("model_generated.rs")), "generated");
}
#[test]
fn classify_vendored_paths() {
// Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes)
assert_eq!(
classify_file_path(Some("src/vendor/lib/foo.go")),
"vendored"
);
assert_eq!(
classify_file_path(Some("src/node_modules/pkg/index.js")),
"vendored"
);
assert_eq!(
classify_file_path(Some("src/third_party/lib.c")),
"vendored"
);
}
#[test]
fn classify_is_case_insensitive() {
assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test");
assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored");
assert_eq!(
classify_file_path(Some("src/GENERATED/foo.ts")),
"generated"
);
}
// ── adjust_confidence ────────────────────────────────────────
#[test]
fn adjust_confidence_production() {
assert_eq!(adjust_confidence(8.0, "production"), 8.0);
}
#[test]
fn adjust_confidence_test() {
assert_eq!(adjust_confidence(10.0, "test"), 5.0);
}
#[test]
fn adjust_confidence_example() {
assert_eq!(adjust_confidence(10.0, "example"), 6.0);
}
#[test]
fn adjust_confidence_generated() {
assert_eq!(adjust_confidence(10.0, "generated"), 3.0);
}
#[test]
fn adjust_confidence_vendored() {
assert_eq!(adjust_confidence(10.0, "vendored"), 4.0);
}
#[test]
fn adjust_confidence_unknown_classification() {
assert_eq!(adjust_confidence(7.0, "unknown"), 7.0);
assert_eq!(adjust_confidence(7.0, "something_else"), 7.0);
}
#[test]
fn adjust_confidence_zero() {
assert_eq!(adjust_confidence(0.0, "test"), 0.0);
assert_eq!(adjust_confidence(0.0, "production"), 0.0);
}
// ── downgrade_severity ───────────────────────────────────────
#[test]
fn downgrade_severity_all_levels() {
assert_eq!(downgrade_severity(&Severity::Critical), Severity::High);
assert_eq!(downgrade_severity(&Severity::High), Severity::Medium);
assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low);
assert_eq!(downgrade_severity(&Severity::Low), Severity::Info);
assert_eq!(downgrade_severity(&Severity::Info), Severity::Info);
}
#[test]
fn downgrade_severity_info_is_floor() {
// Downgrading Info twice should still be Info
let s = downgrade_severity(&Severity::Info);
assert_eq!(downgrade_severity(&s), Severity::Info);
}
// ── upgrade_severity ─────────────────────────────────────────
#[test]
fn upgrade_severity_all_levels() {
assert_eq!(upgrade_severity(&Severity::Info), Severity::Low);
assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium);
assert_eq!(upgrade_severity(&Severity::Medium), Severity::High);
assert_eq!(upgrade_severity(&Severity::High), Severity::Critical);
assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical);
}
#[test]
fn upgrade_severity_critical_is_ceiling() {
let s = upgrade_severity(&Severity::Critical);
assert_eq!(upgrade_severity(&s), Severity::Critical);
}
// ── upgrade/downgrade roundtrip ──────────────────────────────
#[test]
fn upgrade_then_downgrade_is_identity_for_middle_values() {
for sev in [Severity::Low, Severity::Medium, Severity::High] {
assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev);
}
}
// ── TriageResult deserialization ─────────────────────────────
#[test]
fn triage_result_full() {
let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "dismiss");
assert_eq!(r.confidence, 8.5);
assert_eq!(r.rationale, "false positive");
assert_eq!(r.remediation.as_deref(), Some("remove code"));
}
#[test]
fn triage_result_defaults() {
let json = r#"{}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "confirm");
assert_eq!(r.confidence, 0.0);
assert_eq!(r.rationale, "");
assert!(r.remediation.is_none());
}
#[test]
fn triage_result_partial() {
let json = r#"{"action":"downgrade","confidence":6.0}"#;
let r: TriageResult = serde_json::from_str(json).unwrap();
assert_eq!(r.action, "downgrade");
assert_eq!(r.confidence, 6.0);
assert_eq!(r.rationale, "");
assert!(r.remediation.is_none());
}
#[test]
fn triage_result_with_markdown_fences() {
// Simulate LLM wrapping response in markdown code fences
let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```";
let cleaned = raw
.trim()
.trim_start_matches("```json")
.trim_start_matches("```")
.trim_end_matches("```")
.trim();
let r: TriageResult = serde_json::from_str(cleaned).unwrap();
assert_eq!(r.action, "upgrade");
assert_eq!(r.confidence, 9.0);
}
}

View File

@@ -0,0 +1,369 @@
use serde::{Deserialize, Serialize};
// ── Request types ──────────────────────────────────────────────
#[derive(Serialize, Clone, Debug)]
pub struct ChatMessage {
pub role: String,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallRequest>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Serialize)]
pub(crate) struct ChatCompletionRequest {
pub(crate) model: String,
pub(crate) messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) tools: Option<Vec<ToolDefinitionPayload>>,
}
#[derive(Serialize)]
pub(crate) struct ToolDefinitionPayload {
pub(crate) r#type: String,
pub(crate) function: ToolFunctionPayload,
}
#[derive(Serialize)]
pub(crate) struct ToolFunctionPayload {
pub(crate) name: String,
pub(crate) description: String,
pub(crate) parameters: serde_json::Value,
}
// ── Response types ─────────────────────────────────────────────
#[derive(Deserialize)]
pub(crate) struct ChatCompletionResponse {
pub(crate) choices: Vec<ChatChoice>,
}
#[derive(Deserialize)]
pub(crate) struct ChatChoice {
pub(crate) message: ChatResponseMessage,
}
#[derive(Deserialize)]
pub(crate) struct ChatResponseMessage {
#[serde(default)]
pub(crate) content: Option<String>,
#[serde(default)]
pub(crate) tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Deserialize)]
pub(crate) struct ToolCallResponse {
pub(crate) id: String,
pub(crate) function: ToolCallFunction,
}
#[derive(Deserialize)]
pub(crate) struct ToolCallFunction {
pub(crate) name: String,
pub(crate) arguments: String,
}
// ── Public types for tool calling ──────────────────────────────
/// Definition of a tool that the LLM can invoke
#[derive(Debug, Clone, Serialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool call request from the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// A tool call in the request message format (for sending back tool_calls in assistant messages)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
pub id: String,
pub r#type: String,
pub function: ToolCallRequestFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequestFunction {
pub name: String,
pub arguments: String,
}
/// Response from the LLM — either content or tool calls
#[derive(Debug, Clone)]
pub enum LlmResponse {
Content(String),
/// Tool calls with optional reasoning text from the LLM
ToolCalls {
calls: Vec<LlmToolCall>,
reasoning: String,
},
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ── ChatMessage ──────────────────────────────────────────────
#[test]
fn chat_message_serializes_minimal() {
let msg = ChatMessage {
role: "user".to_string(),
content: Some("hello".to_string()),
tool_calls: None,
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert_eq!(v["role"], "user");
assert_eq!(v["content"], "hello");
// None fields with skip_serializing_if should be absent
assert!(v.get("tool_calls").is_none());
assert!(v.get("tool_call_id").is_none());
}
#[test]
fn chat_message_serializes_with_tool_calls() {
let msg = ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: Some(vec![ToolCallRequest {
id: "call_1".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "get_weather".to_string(),
arguments: r#"{"city":"NYC"}"#.to_string(),
},
}]),
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert!(v["tool_calls"].is_array());
assert_eq!(v["tool_calls"][0]["function"]["name"], "get_weather");
}
#[test]
fn chat_message_content_null_when_none() {
let msg = ChatMessage {
role: "assistant".to_string(),
content: None,
tool_calls: None,
tool_call_id: None,
};
let v = serde_json::to_value(&msg).unwrap();
assert!(v["content"].is_null());
}
// ── ToolDefinition ───────────────────────────────────────────
#[test]
fn tool_definition_serializes() {
let td = ToolDefinition {
name: "search".to_string(),
description: "Search the web".to_string(),
parameters: json!({"type": "object", "properties": {"q": {"type": "string"}}}),
};
let v = serde_json::to_value(&td).unwrap();
assert_eq!(v["name"], "search");
assert_eq!(v["parameters"]["type"], "object");
}
#[test]
fn tool_definition_empty_parameters() {
let td = ToolDefinition {
name: "noop".to_string(),
description: "".to_string(),
parameters: json!({}),
};
let v = serde_json::to_value(&td).unwrap();
assert_eq!(v["parameters"], json!({}));
}
// ── LlmToolCall ──────────────────────────────────────────────
#[test]
fn llm_tool_call_roundtrip() {
let call = LlmToolCall {
id: "tc_42".to_string(),
name: "run_scan".to_string(),
arguments: json!({"path": "/tmp", "verbose": true}),
};
let serialized = serde_json::to_string(&call).unwrap();
let deserialized: LlmToolCall = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized.id, "tc_42");
assert_eq!(deserialized.name, "run_scan");
assert_eq!(deserialized.arguments["path"], "/tmp");
assert_eq!(deserialized.arguments["verbose"], true);
}
#[test]
fn llm_tool_call_empty_arguments() {
let call = LlmToolCall {
id: "tc_0".to_string(),
name: "noop".to_string(),
arguments: json!({}),
};
let rt: LlmToolCall = serde_json::from_str(&serde_json::to_string(&call).unwrap()).unwrap();
assert!(rt.arguments.as_object().unwrap().is_empty());
}
// ── ToolCallRequest / ToolCallRequestFunction ────────────────
#[test]
fn tool_call_request_roundtrip() {
let req = ToolCallRequest {
id: "call_abc".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "my_func".to_string(),
arguments: r#"{"x":1}"#.to_string(),
},
};
let json_str = serde_json::to_string(&req).unwrap();
let back: ToolCallRequest = serde_json::from_str(&json_str).unwrap();
assert_eq!(back.id, "call_abc");
assert_eq!(back.r#type, "function");
assert_eq!(back.function.name, "my_func");
assert_eq!(back.function.arguments, r#"{"x":1}"#);
}
#[test]
fn tool_call_request_type_field_serializes_as_type() {
let req = ToolCallRequest {
id: "id".to_string(),
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: "f".to_string(),
arguments: "{}".to_string(),
},
};
let v = serde_json::to_value(&req).unwrap();
// The field should be "type" in JSON, not "r#type"
assert!(v.get("type").is_some());
assert!(v.get("r#type").is_none());
}
// ── ChatCompletionRequest ────────────────────────────────────
#[test]
fn chat_completion_request_skips_none_fields() {
let req = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![],
temperature: None,
max_tokens: None,
tools: None,
};
let v = serde_json::to_value(&req).unwrap();
assert_eq!(v["model"], "gpt-4");
assert!(v.get("temperature").is_none());
assert!(v.get("max_tokens").is_none());
assert!(v.get("tools").is_none());
}
#[test]
fn chat_completion_request_includes_set_fields() {
let req = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![],
temperature: Some(0.7),
max_tokens: Some(1024),
tools: Some(vec![]),
};
let v = serde_json::to_value(&req).unwrap();
assert_eq!(v["temperature"], 0.7);
assert_eq!(v["max_tokens"], 1024);
assert!(v["tools"].is_array());
}
// ── ChatCompletionResponse deserialization ───────────────────
#[test]
fn chat_completion_response_deserializes_content() {
let json_str = r#"{"choices":[{"message":{"content":"Hello!"}}]}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.content.as_deref(), Some("Hello!"));
assert!(resp.choices[0].message.tool_calls.is_none());
}
#[test]
fn chat_completion_response_deserializes_tool_calls() {
let json_str = r#"{
"choices": [{
"message": {
"tool_calls": [{
"id": "call_1",
"function": {"name": "search", "arguments": "{\"q\":\"rust\"}"}
}]
}
}]
}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
let tc = resp.choices[0].message.tool_calls.as_ref().unwrap();
assert_eq!(tc.len(), 1);
assert_eq!(tc[0].id, "call_1");
assert_eq!(tc[0].function.name, "search");
}
#[test]
fn chat_completion_response_defaults_missing_fields() {
// content and tool_calls are both missing — should default to None
let json_str = r#"{"choices":[{"message":{}}]}"#;
let resp: ChatCompletionResponse = serde_json::from_str(json_str).unwrap();
assert!(resp.choices[0].message.content.is_none());
assert!(resp.choices[0].message.tool_calls.is_none());
}
// ── LlmResponse ─────────────────────────────────────────────
#[test]
fn llm_response_content_variant() {
let resp = LlmResponse::Content("answer".to_string());
match resp {
LlmResponse::Content(s) => assert_eq!(s, "answer"),
_ => panic!("expected Content variant"),
}
}
#[test]
fn llm_response_tool_calls_variant() {
let resp = LlmResponse::ToolCalls {
calls: vec![LlmToolCall {
id: "1".to_string(),
name: "f".to_string(),
arguments: json!({}),
}],
reasoning: "because".to_string(),
};
match resp {
LlmResponse::ToolCalls { calls, reasoning } => {
assert_eq!(calls.len(), 1);
assert_eq!(reasoning, "because");
}
_ => panic!("expected ToolCalls variant"),
}
}
#[test]
fn llm_response_empty_content() {
let resp = LlmResponse::Content(String::new());
match resp {
LlmResponse::Content(s) => assert!(s.is_empty()),
_ => panic!("expected Content variant"),
}
}
}

View File

@@ -0,0 +1,150 @@
use futures_util::StreamExt;
use mongodb::bson::doc;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::Finding;
use compliance_core::models::pentest::CodeContextHint;
use compliance_core::models::sbom::SbomEntry;
use super::orchestrator::PentestOrchestrator;
impl PentestOrchestrator {
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
/// for the repo linked to this DAST target.
pub(crate) async fn gather_repo_context(
&self,
target: &DastTarget,
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
let Some(repo_id) = &target.repo_id else {
return (Vec::new(), Vec::new(), Vec::new());
};
let sast_findings = self.fetch_sast_findings(repo_id).await;
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
tracing::info!(
repo_id,
sast_findings = sast_findings.len(),
vulnerable_deps = sbom_entries.len(),
code_hints = code_context.len(),
"Gathered code-awareness context for pentest"
);
(sast_findings, sbom_entries, code_context)
}
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
let cursor = self
.db
.findings()
.find(doc! {
"repo_id": repo_id,
"status": { "$in": ["open", "triaged"] },
})
.sort(doc! { "severity": -1 })
.limit(100)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(f)) = c.next().await {
results.push(f);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
Vec::new()
}
}
}
/// Fetch SBOM entries that have known vulnerabilities
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
let cursor = self
.db
.sbom_entries()
.find(doc! {
"repo_id": repo_id,
"known_vulnerabilities": { "$exists": true, "$ne": [] },
})
.limit(50)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(e)) = c.next().await {
results.push(e);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
Vec::new()
}
}
}
/// Build CodeContextHint objects from the code knowledge graph.
/// Maps entry points to their source files and links SAST findings.
async fn fetch_code_context(
&self,
repo_id: &str,
sast_findings: &[Finding],
) -> Vec<CodeContextHint> {
// Get entry point nodes from the code graph
let cursor = self
.db
.graph_nodes()
.find(doc! {
"repo_id": repo_id,
"is_entry_point": true,
})
.limit(50)
.await;
let nodes = match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(n)) = c.next().await {
results.push(n);
}
results
}
Err(_) => return Vec::new(),
};
// Build hints by matching graph nodes to SAST findings by file path
nodes
.into_iter()
.map(|node| {
// Find SAST findings in the same file
let linked_vulns: Vec<String> = sast_findings
.iter()
.filter(|f| f.file_path.as_deref() == Some(&node.file_path))
.map(|f| {
format!(
"[{}] {}: {} (line {})",
f.severity,
f.scanner,
f.title,
f.line_number.unwrap_or(0)
)
})
.collect();
CodeContextHint {
endpoint_pattern: node.qualified_name.clone(),
handler_function: node.name.clone(),
file_path: node.file_path.clone(),
code_snippet: String::new(), // Could fetch from embeddings
known_vulnerabilities: linked_vulns,
}
})
.collect()
}
}

View File

@@ -1,4 +1,6 @@
mod context;
pub mod orchestrator;
mod prompt_builder;
pub mod report;
pub use orchestrator::PentestOrchestrator;

View File

@@ -1,31 +1,27 @@
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use mongodb::bson::doc;
use tokio::sync::broadcast;
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use compliance_core::traits::pentest_tool::PentestToolContext;
use compliance_dast::ToolRegistry;
use crate::database::Database;
use crate::llm::client::{
ChatMessage, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
use crate::llm::{
ChatMessage, LlmClient, LlmResponse, ToolCallRequest, ToolCallRequestFunction, ToolDefinition,
};
use crate::llm::LlmClient;
/// Maximum duration for a single pentest session before timeout
const SESSION_TIMEOUT: Duration = Duration::from_secs(30 * 60); // 30 minutes
pub struct PentestOrchestrator {
tool_registry: ToolRegistry,
llm: Arc<LlmClient>,
db: Database,
event_tx: broadcast::Sender<PentestEvent>,
pub(crate) tool_registry: ToolRegistry,
pub(crate) llm: Arc<LlmClient>,
pub(crate) db: Database,
pub(crate) event_tx: broadcast::Sender<PentestEvent>,
}
impl PentestOrchestrator {
@@ -111,18 +107,20 @@ impl PentestOrchestrator {
target: &DastTarget,
initial_message: &str,
) -> Result<(), crate::error::AgentError> {
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_default();
let session_id = session.id.map(|oid| oid.to_hex()).unwrap_or_default();
// Gather code-awareness context from linked repo
let (sast_findings, sbom_entries, code_context) =
self.gather_repo_context(target).await;
let (sast_findings, sbom_entries, code_context) = self.gather_repo_context(target).await;
// Build system prompt with code context
let system_prompt = self
.build_system_prompt(session, target, &sast_findings, &sbom_entries, &code_context)
.build_system_prompt(
session,
target,
&sast_findings,
&sbom_entries,
&code_context,
)
.await;
// Build tool definitions for LLM
@@ -182,8 +180,7 @@ impl PentestOrchestrator {
match response {
LlmResponse::Content(content) => {
let msg =
PentestMessage::assistant(session_id.clone(), content.clone());
let msg = PentestMessage::assistant(session_id.clone(), content.clone());
let _ = self.db.pentest_messages().insert_one(&msg).await;
let _ = self.event_tx.send(PentestEvent::Message {
content: content.clone(),
@@ -213,7 +210,10 @@ impl PentestOrchestrator {
}
break;
}
LlmResponse::ToolCalls { calls: tool_calls, reasoning } => {
LlmResponse::ToolCalls {
calls: tool_calls,
reasoning,
} => {
let tc_requests: Vec<ToolCallRequest> = tool_calls
.iter()
.map(|tc| ToolCallRequest {
@@ -221,15 +221,18 @@ impl PentestOrchestrator {
r#type: "function".to_string(),
function: ToolCallRequestFunction {
name: tc.name.clone(),
arguments: serde_json::to_string(&tc.arguments)
.unwrap_or_default(),
arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
},
})
.collect();
messages.push(ChatMessage {
role: "assistant".to_string(),
content: if reasoning.is_empty() { None } else { Some(reasoning.clone()) },
content: if reasoning.is_empty() {
None
} else {
Some(reasoning.clone())
},
tool_calls: Some(tc_requests),
tool_call_id: None,
});
@@ -274,24 +277,30 @@ impl PentestOrchestrator {
let insert_result =
self.db.dast_findings().insert_one(&finding).await;
if let Ok(res) = &insert_result {
finding_ids.push(res.inserted_id.as_object_id().map(|oid| oid.to_hex()).unwrap_or_default());
}
let _ =
self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
finding_ids.push(
res.inserted_id
.as_object_id()
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
);
}
let _ = self.event_tx.send(PentestEvent::Finding {
finding_id: finding
.id
.map(|oid| oid.to_hex())
.unwrap_or_default(),
title: finding.title.clone(),
severity: finding.severity.to_string(),
});
}
// Compute risk score based on findings severity
let risk_score: Option<u8> = if findings_count > 0 {
Some(std::cmp::min(
100,
(findings_count as u8).saturating_mul(15).saturating_add(20),
(findings_count as u8)
.saturating_mul(15)
.saturating_add(20),
))
} else {
None
@@ -415,347 +424,4 @@ impl PentestOrchestrator {
Ok(())
}
// ── Code-Awareness: Gather context from linked repo ─────────
/// Fetch SAST findings, SBOM entries (with CVEs), and code graph entry points
/// for the repo linked to this DAST target.
async fn gather_repo_context(
&self,
target: &DastTarget,
) -> (Vec<Finding>, Vec<SbomEntry>, Vec<CodeContextHint>) {
let Some(repo_id) = &target.repo_id else {
return (Vec::new(), Vec::new(), Vec::new());
};
let sast_findings = self.fetch_sast_findings(repo_id).await;
let sbom_entries = self.fetch_vulnerable_sbom(repo_id).await;
let code_context = self.fetch_code_context(repo_id, &sast_findings).await;
tracing::info!(
repo_id,
sast_findings = sast_findings.len(),
vulnerable_deps = sbom_entries.len(),
code_hints = code_context.len(),
"Gathered code-awareness context for pentest"
);
(sast_findings, sbom_entries, code_context)
}
/// Fetch open/triaged SAST findings for the repo (not false positives or resolved)
async fn fetch_sast_findings(&self, repo_id: &str) -> Vec<Finding> {
let cursor = self
.db
.findings()
.find(doc! {
"repo_id": repo_id,
"status": { "$in": ["open", "triaged"] },
})
.sort(doc! { "severity": -1 })
.limit(100)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(f)) = c.next().await {
results.push(f);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch SAST findings for pentest: {e}");
Vec::new()
}
}
}
/// Fetch SBOM entries that have known vulnerabilities
async fn fetch_vulnerable_sbom(&self, repo_id: &str) -> Vec<SbomEntry> {
let cursor = self
.db
.sbom_entries()
.find(doc! {
"repo_id": repo_id,
"known_vulnerabilities": { "$exists": true, "$ne": [] },
})
.limit(50)
.await;
match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(e)) = c.next().await {
results.push(e);
}
results
}
Err(e) => {
tracing::warn!("Failed to fetch vulnerable SBOM entries: {e}");
Vec::new()
}
}
}
/// Build CodeContextHint objects from the code knowledge graph.
/// Maps entry points to their source files and links SAST findings.
async fn fetch_code_context(
&self,
repo_id: &str,
sast_findings: &[Finding],
) -> Vec<CodeContextHint> {
// Get entry point nodes from the code graph
let cursor = self
.db
.graph_nodes()
.find(doc! {
"repo_id": repo_id,
"is_entry_point": true,
})
.limit(50)
.await;
let nodes = match cursor {
Ok(mut c) => {
let mut results = Vec::new();
while let Some(Ok(n)) = c.next().await {
results.push(n);
}
results
}
Err(_) => return Vec::new(),
};
// Build hints by matching graph nodes to SAST findings by file path
nodes
.into_iter()
.map(|node| {
// Find SAST findings in the same file
let linked_vulns: Vec<String> = sast_findings
.iter()
.filter(|f| {
f.file_path.as_deref() == Some(&node.file_path)
})
.map(|f| {
format!(
"[{}] {}: {} (line {})",
f.severity,
f.scanner,
f.title,
f.line_number.unwrap_or(0)
)
})
.collect();
CodeContextHint {
endpoint_pattern: node.qualified_name.clone(),
handler_function: node.name.clone(),
file_path: node.file_path.clone(),
code_snippet: String::new(), // Could fetch from embeddings
known_vulnerabilities: linked_vulns,
}
})
.collect()
}
// ── System Prompt Builder ───────────────────────────────────
async fn build_system_prompt(
&self,
session: &PentestSession,
target: &DastTarget,
sast_findings: &[Finding],
sbom_entries: &[SbomEntry],
code_context: &[CodeContextHint],
) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let strategy_guidance = match session.strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
};
// Build SAST findings section
let sast_section = if sast_findings.is_empty() {
String::from("No SAST findings available for this target.")
} else {
let critical = sast_findings
.iter()
.filter(|f| f.severity == Severity::Critical)
.count();
let high = sast_findings
.iter()
.filter(|f| f.severity == Severity::High)
.count();
let mut section = format!(
"{} open findings ({} critical, {} high):\n",
sast_findings.len(),
critical,
high
);
// List the most important findings (critical/high first, up to 20)
for f in sast_findings.iter().take(20) {
let file_info = f
.file_path
.as_ref()
.map(|p| {
format!(
" in {}:{}",
p,
f.line_number.unwrap_or(0)
)
})
.unwrap_or_default();
let status_note = match f.status {
FindingStatus::Triaged => " [TRIAGED]",
_ => "",
};
section.push_str(&format!(
"- [{sev}] {title}{file}{status}\n",
sev = f.severity,
title = f.title,
file = file_info,
status = status_note,
));
if let Some(cwe) = &f.cwe {
section.push_str(&format!(" CWE: {cwe}\n"));
}
}
if sast_findings.len() > 20 {
section.push_str(&format!(
"... and {} more findings\n",
sast_findings.len() - 20
));
}
section
};
// Build SBOM/CVE section
let sbom_section = if sbom_entries.is_empty() {
String::from("No vulnerable dependencies identified.")
} else {
let mut section = format!(
"{} dependencies with known vulnerabilities:\n",
sbom_entries.len()
);
for entry in sbom_entries.iter().take(15) {
let cve_ids: Vec<&str> = entry
.known_vulnerabilities
.iter()
.map(|v| v.id.as_str())
.collect();
section.push_str(&format!(
"- {} {} ({}): {}\n",
entry.name,
entry.version,
entry.package_manager,
cve_ids.join(", ")
));
}
if sbom_entries.len() > 15 {
section.push_str(&format!(
"... and {} more vulnerable dependencies\n",
sbom_entries.len() - 15
));
}
section
};
// Build code context section
let code_section = if code_context.is_empty() {
String::from("No code knowledge graph available for this target.")
} else {
let with_vulns = code_context
.iter()
.filter(|c| !c.known_vulnerabilities.is_empty())
.count();
let mut section = format!(
"{} entry points identified ({} with linked SAST findings):\n",
code_context.len(),
with_vulns
);
for hint in code_context.iter().take(20) {
section.push_str(&format!(
"- {} ({})\n",
hint.endpoint_pattern, hint.file_path
));
for vuln in &hint.known_vulnerabilities {
section.push_str(&format!(" SAST: {vuln}\n"));
}
}
section
};
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
)
}
}

View File

@@ -0,0 +1,504 @@
use compliance_core::models::dast::DastTarget;
use compliance_core::models::finding::{Finding, FindingStatus, Severity};
use compliance_core::models::pentest::*;
use compliance_core::models::sbom::SbomEntry;
use super::orchestrator::PentestOrchestrator;
/// Return strategy guidance text for the given strategy.
fn strategy_guidance(strategy: &PentestStrategy) -> &'static str {
match strategy {
PentestStrategy::Quick => {
"Focus on the most common and impactful vulnerabilities. Run a quick recon, then target the highest-risk areas."
}
PentestStrategy::Comprehensive => {
"Perform a thorough assessment covering all vulnerability types. Start with recon, then systematically test each attack surface."
}
PentestStrategy::Targeted => {
"Focus specifically on areas highlighted by SAST findings and known CVEs. Prioritize exploiting known weaknesses."
}
PentestStrategy::Aggressive => {
"Use all available tools aggressively. Test with maximum payloads and attempt full exploitation."
}
PentestStrategy::Stealth => {
"Minimize noise. Use fewer requests, avoid aggressive payloads. Focus on passive analysis and targeted probes."
}
}
}
/// Build the SAST findings section for the system prompt.
fn build_sast_section(sast_findings: &[Finding]) -> String {
if sast_findings.is_empty() {
return String::from("No SAST findings available for this target.");
}
let critical = sast_findings
.iter()
.filter(|f| f.severity == Severity::Critical)
.count();
let high = sast_findings
.iter()
.filter(|f| f.severity == Severity::High)
.count();
let mut section = format!(
"{} open findings ({} critical, {} high):\n",
sast_findings.len(),
critical,
high
);
// List the most important findings (critical/high first, up to 20)
for f in sast_findings.iter().take(20) {
let file_info = f
.file_path
.as_ref()
.map(|p| format!(" in {}:{}", p, f.line_number.unwrap_or(0)))
.unwrap_or_default();
let status_note = match f.status {
FindingStatus::Triaged => " [TRIAGED]",
_ => "",
};
section.push_str(&format!(
"- [{sev}] {title}{file}{status}\n",
sev = f.severity,
title = f.title,
file = file_info,
status = status_note,
));
if let Some(cwe) = &f.cwe {
section.push_str(&format!(" CWE: {cwe}\n"));
}
}
if sast_findings.len() > 20 {
section.push_str(&format!(
"... and {} more findings\n",
sast_findings.len() - 20
));
}
section
}
/// Build the SBOM/CVE section for the system prompt.
fn build_sbom_section(sbom_entries: &[SbomEntry]) -> String {
if sbom_entries.is_empty() {
return String::from("No vulnerable dependencies identified.");
}
let mut section = format!(
"{} dependencies with known vulnerabilities:\n",
sbom_entries.len()
);
for entry in sbom_entries.iter().take(15) {
let cve_ids: Vec<&str> = entry
.known_vulnerabilities
.iter()
.map(|v| v.id.as_str())
.collect();
section.push_str(&format!(
"- {} {} ({}): {}\n",
entry.name,
entry.version,
entry.package_manager,
cve_ids.join(", ")
));
}
if sbom_entries.len() > 15 {
section.push_str(&format!(
"... and {} more vulnerable dependencies\n",
sbom_entries.len() - 15
));
}
section
}
/// Build the code context section for the system prompt.
fn build_code_section(code_context: &[CodeContextHint]) -> String {
if code_context.is_empty() {
return String::from("No code knowledge graph available for this target.");
}
let with_vulns = code_context
.iter()
.filter(|c| !c.known_vulnerabilities.is_empty())
.count();
let mut section = format!(
"{} entry points identified ({} with linked SAST findings):\n",
code_context.len(),
with_vulns
);
for hint in code_context.iter().take(20) {
section.push_str(&format!(
"- {} ({})\n",
hint.endpoint_pattern, hint.file_path
));
for vuln in &hint.known_vulnerabilities {
section.push_str(&format!(" SAST: {vuln}\n"));
}
}
section
}
impl PentestOrchestrator {
pub(crate) async fn build_system_prompt(
&self,
session: &PentestSession,
target: &DastTarget,
sast_findings: &[Finding],
sbom_entries: &[SbomEntry],
code_context: &[CodeContextHint],
) -> String {
let tool_names = self.tool_registry.list_names().join(", ");
let guidance = strategy_guidance(&session.strategy);
let sast_section = build_sast_section(sast_findings);
let sbom_section = build_sbom_section(sbom_entries);
let code_section = build_code_section(code_context);
format!(
r#"You are an expert penetration tester conducting an authorized security assessment.
## Target
- **Name**: {target_name}
- **URL**: {base_url}
- **Type**: {target_type}
- **Rate Limit**: {rate_limit} req/s
- **Destructive Tests Allowed**: {allow_destructive}
- **Linked Repository**: {repo_linked}
## Strategy
{strategy_guidance}
## SAST Findings (Static Analysis)
{sast_section}
## Vulnerable Dependencies (SBOM)
{sbom_section}
## Code Entry Points (Knowledge Graph)
{code_section}
## Available Tools
{tool_names}
## Instructions
1. Start by running reconnaissance (recon tool) to fingerprint the target and discover technologies.
2. Run the OpenAPI parser to discover API endpoints from specs.
3. Check infrastructure: DNS, DMARC, TLS, security headers, cookies, CSP, CORS.
4. Based on SAST findings, prioritize testing endpoints where vulnerabilities were found in code.
5. For each vulnerability type found in SAST, use the corresponding DAST tool to verify exploitability.
6. If vulnerable dependencies are listed, try to trigger known CVE conditions against the running application.
7. Test rate limiting on critical endpoints (login, API).
8. Check for console.log leakage in frontend JavaScript.
9. Analyze tool results and chain findings — if one vulnerability enables others, explore the chain.
10. When testing is complete, provide a structured summary with severity and remediation.
11. Always explain your reasoning before invoking each tool.
12. When done, say "Testing complete" followed by a final summary.
## Important
- This is an authorized penetration test. All testing is permitted within the target scope.
- Respect the rate limit of {rate_limit} requests per second.
- Only use destructive tests if explicitly allowed ({allow_destructive}).
- Use SAST findings to guide your testing — they tell you WHERE in the code vulnerabilities exist.
- Use SBOM data to understand what technologies and versions the target runs.
"#,
target_name = target.name,
base_url = target.base_url,
target_type = target.target_type,
rate_limit = target.rate_limit,
allow_destructive = target.allow_destructive,
repo_linked = target.repo_id.as_deref().unwrap_or("None"),
strategy_guidance = guidance,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::finding::Severity;
use compliance_core::models::sbom::VulnRef;
use compliance_core::models::scan::ScanType;
fn make_finding(
severity: Severity,
title: &str,
file_path: Option<&str>,
line: Option<u32>,
status: FindingStatus,
cwe: Option<&str>,
) -> Finding {
let mut f = Finding::new(
"repo-1".into(),
format!("fp-{title}"),
"semgrep".into(),
ScanType::Sast,
title.into(),
"desc".into(),
severity,
);
f.file_path = file_path.map(|s| s.to_string());
f.line_number = line;
f.status = status;
f.cwe = cwe.map(|s| s.to_string());
f
}
fn make_sbom_entry(name: &str, version: &str, cves: &[&str]) -> SbomEntry {
let mut entry = SbomEntry::new("repo-1".into(), name.into(), version.into(), "npm".into());
entry.known_vulnerabilities = cves
.iter()
.map(|id| VulnRef {
id: id.to_string(),
source: "nvd".into(),
severity: None,
url: None,
})
.collect();
entry
}
fn make_code_hint(endpoint: &str, file: &str, vulns: Vec<String>) -> CodeContextHint {
CodeContextHint {
endpoint_pattern: endpoint.into(),
handler_function: "handler".into(),
file_path: file.into(),
code_snippet: String::new(),
known_vulnerabilities: vulns,
}
}
// ── strategy_guidance ────────────────────────────────────────────
#[test]
fn strategy_guidance_quick() {
let g = strategy_guidance(&PentestStrategy::Quick);
assert!(g.contains("most common"));
assert!(g.contains("quick recon"));
}
#[test]
fn strategy_guidance_comprehensive() {
let g = strategy_guidance(&PentestStrategy::Comprehensive);
assert!(g.contains("thorough assessment"));
}
#[test]
fn strategy_guidance_targeted() {
let g = strategy_guidance(&PentestStrategy::Targeted);
assert!(g.contains("SAST findings"));
assert!(g.contains("known CVEs"));
}
#[test]
fn strategy_guidance_aggressive() {
let g = strategy_guidance(&PentestStrategy::Aggressive);
assert!(g.contains("aggressively"));
assert!(g.contains("full exploitation"));
}
#[test]
fn strategy_guidance_stealth() {
let g = strategy_guidance(&PentestStrategy::Stealth);
assert!(g.contains("Minimize noise"));
assert!(g.contains("passive analysis"));
}
// ── build_sast_section ───────────────────────────────────────────
#[test]
fn sast_section_empty() {
let section = build_sast_section(&[]);
assert_eq!(section, "No SAST findings available for this target.");
}
#[test]
fn sast_section_single_critical() {
let findings = vec![make_finding(
Severity::Critical,
"SQL Injection",
Some("src/db.rs"),
Some(42),
FindingStatus::Open,
Some("CWE-89"),
)];
let section = build_sast_section(&findings);
assert!(section.contains("1 open findings (1 critical, 0 high)"));
assert!(section.contains("[critical] SQL Injection in src/db.rs:42"));
assert!(section.contains("CWE: CWE-89"));
}
#[test]
fn sast_section_triaged_finding_shows_marker() {
let findings = vec![make_finding(
Severity::High,
"XSS",
None,
None,
FindingStatus::Triaged,
None,
)];
let section = build_sast_section(&findings);
assert!(section.contains("[TRIAGED]"));
}
#[test]
fn sast_section_no_file_path_omits_location() {
let findings = vec![make_finding(
Severity::Medium,
"Open Redirect",
None,
None,
FindingStatus::Open,
None,
)];
let section = build_sast_section(&findings);
assert!(section.contains("- [medium] Open Redirect\n"));
assert!(!section.contains(" in "));
}
#[test]
fn sast_section_counts_critical_and_high() {
let findings = vec![
make_finding(
Severity::Critical,
"F1",
None,
None,
FindingStatus::Open,
None,
),
make_finding(
Severity::Critical,
"F2",
None,
None,
FindingStatus::Open,
None,
),
make_finding(Severity::High, "F3", None, None, FindingStatus::Open, None),
make_finding(
Severity::Medium,
"F4",
None,
None,
FindingStatus::Open,
None,
),
];
let section = build_sast_section(&findings);
assert!(section.contains("4 open findings (2 critical, 1 high)"));
}
#[test]
fn sast_section_truncates_at_20() {
let findings: Vec<Finding> = (0..25)
.map(|i| {
make_finding(
Severity::Low,
&format!("Finding {i}"),
None,
None,
FindingStatus::Open,
None,
)
})
.collect();
let section = build_sast_section(&findings);
assert!(section.contains("... and 5 more findings"));
// Should contain Finding 19 (the 20th) but not Finding 20 (the 21st)
assert!(section.contains("Finding 19"));
assert!(!section.contains("Finding 20"));
}
// ── build_sbom_section ───────────────────────────────────────────
#[test]
fn sbom_section_empty() {
let section = build_sbom_section(&[]);
assert_eq!(section, "No vulnerable dependencies identified.");
}
#[test]
fn sbom_section_single_entry() {
let entries = vec![make_sbom_entry("lodash", "4.17.20", &["CVE-2021-23337"])];
let section = build_sbom_section(&entries);
assert!(section.contains("1 dependencies with known vulnerabilities"));
assert!(section.contains("- lodash 4.17.20 (npm): CVE-2021-23337"));
}
#[test]
fn sbom_section_multiple_cves() {
let entries = vec![make_sbom_entry(
"openssl",
"1.1.1",
&["CVE-2022-0001", "CVE-2022-0002"],
)];
let section = build_sbom_section(&entries);
assert!(section.contains("CVE-2022-0001, CVE-2022-0002"));
}
#[test]
fn sbom_section_truncates_at_15() {
let entries: Vec<SbomEntry> = (0..18)
.map(|i| make_sbom_entry(&format!("pkg-{i}"), "1.0.0", &["CVE-2024-0001"]))
.collect();
let section = build_sbom_section(&entries);
assert!(section.contains("... and 3 more vulnerable dependencies"));
assert!(section.contains("pkg-14"));
assert!(!section.contains("pkg-15"));
}
// ── build_code_section ───────────────────────────────────────────
#[test]
fn code_section_empty() {
let section = build_code_section(&[]);
assert_eq!(
section,
"No code knowledge graph available for this target."
);
}
#[test]
fn code_section_single_entry_no_vulns() {
let hints = vec![make_code_hint("GET /api/users", "src/routes.rs", vec![])];
let section = build_code_section(&hints);
assert!(section.contains("1 entry points identified (0 with linked SAST findings)"));
assert!(section.contains("- GET /api/users (src/routes.rs)"));
}
#[test]
fn code_section_with_linked_vulns() {
let hints = vec![make_code_hint(
"POST /login",
"src/auth.rs",
vec!["[critical] semgrep: SQL Injection (line 15)".into()],
)];
let section = build_code_section(&hints);
assert!(section.contains("1 entry points identified (1 with linked SAST findings)"));
assert!(section.contains("SAST: [critical] semgrep: SQL Injection (line 15)"));
}
#[test]
fn code_section_counts_entries_with_vulns() {
let hints = vec![
make_code_hint("GET /a", "a.rs", vec!["vuln1".into()]),
make_code_hint("GET /b", "b.rs", vec![]),
make_code_hint("GET /c", "c.rs", vec!["vuln2".into(), "vuln3".into()]),
];
let section = build_code_section(&hints);
assert!(section.contains("3 entry points identified (2 with linked SAST findings)"));
}
#[test]
fn code_section_truncates_at_20() {
let hints: Vec<CodeContextHint> = (0..25)
.map(|i| make_code_hint(&format!("GET /ep{i}"), &format!("f{i}.rs"), vec![]))
.collect();
let section = build_code_section(&hints);
assert!(section.contains("GET /ep19"));
assert!(!section.contains("GET /ep20"));
}
}

View File

@@ -0,0 +1,43 @@
use std::io::{Cursor, Write};
use zip::write::SimpleFileOptions;
use zip::AesMode;
use super::ReportContext;
pub(super) fn build_zip(
ctx: &ReportContext,
password: &str,
html: &str,
pdf: &[u8],
) -> Result<Vec<u8>, zip::result::ZipError> {
let buf = Cursor::new(Vec::new());
let mut zip = zip::ZipWriter::new(buf);
let options = SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Deflated)
.with_aes_encryption(AesMode::Aes256, password);
// report.pdf (primary)
zip.start_file("report.pdf", options.clone())?;
zip.write_all(pdf)?;
// report.html (fallback)
zip.start_file("report.html", options.clone())?;
zip.write_all(html.as_bytes())?;
// findings.json
let findings_json =
serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
zip.start_file("findings.json", options.clone())?;
zip.write_all(findings_json.as_bytes())?;
// attack-chain.json
let chain_json =
serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());
zip.start_file("attack-chain.json", options)?;
zip.write_all(chain_json.as_bytes())?;
let cursor = zip.finish()?;
Ok(cursor.into_inner())
}

View File

@@ -1,193 +1,49 @@
use std::io::{Cursor, Write};
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
use sha2::{Digest, Sha256};
use zip::write::SimpleFileOptions;
use zip::AesMode;
use compliance_core::models::pentest::AttackChainNode;
/// Report archive with metadata
pub struct ReportArchive {
/// The password-protected ZIP bytes
pub archive: Vec<u8>,
/// SHA-256 hex digest of the archive
pub sha256: String,
}
use super::ReportContext;
/// Report context gathered from the database
pub struct ReportContext {
pub session: PentestSession,
pub target_name: String,
pub target_url: String,
pub findings: Vec<DastFinding>,
pub attack_chain: Vec<AttackChainNode>,
pub requester_name: String,
pub requester_email: String,
}
/// Generate a password-protected ZIP archive containing the pentest report.
///
/// The archive contains:
/// - `report.pdf` — Professional pentest report (PDF)
/// - `report.html` — HTML source (fallback)
/// - `findings.json` — Raw findings data
/// - `attack-chain.json` — Attack chain timeline
///
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
pub async fn generate_encrypted_report(
ctx: &ReportContext,
password: &str,
) -> Result<ReportArchive, String> {
let html = build_html_report(ctx);
// Convert HTML to PDF via headless Chrome
let pdf_bytes = html_to_pdf(&html).await?;
let zip_bytes = build_zip(ctx, password, &html, &pdf_bytes)
.map_err(|e| format!("Failed to create archive: {e}"))?;
let mut hasher = Sha256::new();
hasher.update(&zip_bytes);
let sha256 = hex::encode(hasher.finalize());
Ok(ReportArchive { archive: zip_bytes, sha256 })
}
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
let tmp_dir = std::env::temp_dir();
let run_id = uuid::Uuid::new_v4().to_string();
let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
// Write HTML to temp file
std::fs::write(&html_path, html)
.map_err(|e| format!("Failed to write temp HTML: {e}"))?;
// Find Chrome/Chromium binary
let chrome_bin = find_chrome_binary()
.ok_or_else(|| "Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports.".to_string())?;
tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
let html_url = format!("file://{}", html_path.display());
let output = tokio::process::Command::new(&chrome_bin)
.args([
"--headless",
"--disable-gpu",
"--no-sandbox",
"--disable-software-rasterizer",
"--run-all-compositor-stages-before-draw",
"--disable-dev-shm-usage",
&format!("--print-to-pdf={}", pdf_path.display()),
"--no-pdf-header-footer",
&html_url,
])
.output()
.await
.map_err(|e| format!("Failed to run Chrome: {e}"))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// Clean up temp files
let _ = std::fs::remove_file(&html_path);
let _ = std::fs::remove_file(&pdf_path);
return Err(format!("Chrome PDF generation failed: {stderr}"));
}
let pdf_bytes = std::fs::read(&pdf_path)
.map_err(|e| format!("Failed to read generated PDF: {e}"))?;
// Clean up temp files
let _ = std::fs::remove_file(&html_path);
let _ = std::fs::remove_file(&pdf_path);
if pdf_bytes.is_empty() {
return Err("Chrome produced an empty PDF".to_string());
}
tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
Ok(pdf_bytes)
}
/// Search for Chrome/Chromium binary on the system.
fn find_chrome_binary() -> Option<String> {
let candidates = [
"google-chrome-stable",
"google-chrome",
"chromium-browser",
"chromium",
];
for name in &candidates {
if let Ok(output) = std::process::Command::new("which").arg(name).output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Some(path);
}
}
}
}
None
}
fn build_zip(
ctx: &ReportContext,
password: &str,
html: &str,
pdf: &[u8],
) -> Result<Vec<u8>, zip::result::ZipError> {
let buf = Cursor::new(Vec::new());
let mut zip = zip::ZipWriter::new(buf);
let options = SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Deflated)
.with_aes_encryption(AesMode::Aes256, password);
// report.pdf (primary)
zip.start_file("report.pdf", options.clone())?;
zip.write_all(pdf)?;
// report.html (fallback)
zip.start_file("report.html", options.clone())?;
zip.write_all(html.as_bytes())?;
// findings.json
let findings_json =
serde_json::to_string_pretty(&ctx.findings).unwrap_or_else(|_| "[]".to_string());
zip.start_file("findings.json", options.clone())?;
zip.write_all(findings_json.as_bytes())?;
// attack-chain.json
let chain_json =
serde_json::to_string_pretty(&ctx.attack_chain).unwrap_or_else(|_| "[]".to_string());
zip.start_file("attack-chain.json", options)?;
zip.write_all(chain_json.as_bytes())?;
let cursor = zip.finish()?;
Ok(cursor.into_inner())
}
fn build_html_report(ctx: &ReportContext) -> String {
pub(super) fn build_html_report(ctx: &ReportContext) -> String {
let session = &ctx.session;
let session_id = session
.id
.map(|oid| oid.to_hex())
.unwrap_or_else(|| "-".to_string());
let date_str = session.started_at.format("%B %d, %Y at %H:%M UTC").to_string();
let date_str = session
.started_at
.format("%B %d, %Y at %H:%M UTC")
.to_string();
let date_short = session.started_at.format("%B %d, %Y").to_string();
let completed_str = session
.completed_at
.map(|d| d.format("%B %d, %Y at %H:%M UTC").to_string())
.unwrap_or_else(|| "In Progress".to_string());
let critical = ctx.findings.iter().filter(|f| f.severity.to_string() == "critical").count();
let high = ctx.findings.iter().filter(|f| f.severity.to_string() == "high").count();
let medium = ctx.findings.iter().filter(|f| f.severity.to_string() == "medium").count();
let low = ctx.findings.iter().filter(|f| f.severity.to_string() == "low").count();
let info = ctx.findings.iter().filter(|f| f.severity.to_string() == "info").count();
let critical = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "critical")
.count();
let high = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "high")
.count();
let medium = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "medium")
.count();
let low = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "low")
.count();
let info = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == "info")
.count();
let exploitable = ctx.findings.iter().filter(|f| f.exploitable).count();
let total = ctx.findings.len();
@@ -247,7 +103,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
if high > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-high" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(high_pct, 4), high
std::cmp::max(high_pct, 4),
high
));
}
if medium > 0 {
@@ -259,22 +116,38 @@ fn build_html_report(ctx: &ReportContext) -> String {
if low > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-low" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(low_pct, 4), low
std::cmp::max(low_pct, 4),
low
));
}
if info > 0 {
bar.push_str(&format!(
r#"<div class="sev-bar-seg sev-bar-info" style="width:{}%"><span>{}</span></div>"#,
std::cmp::max(info_pct, 4), info
std::cmp::max(info_pct, 4),
info
));
}
bar.push_str("</div>");
bar.push_str(r#"<div class="sev-bar-legend">"#);
if critical > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#); }
if high > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#); }
if medium > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#); }
if low > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#); }
if info > 0 { bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#); }
if critical > 0 {
bar.push_str(
r#"<span><i class="sev-dot" style="background:#991b1b"></i> Critical</span>"#,
);
}
if high > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#c2410c"></i> High</span>"#);
}
if medium > 0 {
bar.push_str(
r#"<span><i class="sev-dot" style="background:#a16207"></i> Medium</span>"#,
);
}
if low > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#1d4ed8"></i> Low</span>"#);
}
if info > 0 {
bar.push_str(r#"<span><i class="sev-dot" style="background:#4b5563"></i> Info</span>"#);
}
bar.push_str("</div>");
bar
} else {
@@ -322,7 +195,12 @@ fn build_html_report(ctx: &ReportContext) -> String {
let param_row = f
.parameter
.as_deref()
.map(|p| format!("<tr><td>Parameter</td><td><code>{}</code></td></tr>", html_escape(p)))
.map(|p| {
format!(
"<tr><td>Parameter</td><td><code>{}</code></td></tr>",
html_escape(p)
)
})
.unwrap_or_default();
let remediation = f
.remediation
@@ -332,7 +210,9 @@ fn build_html_report(ctx: &ReportContext) -> String {
let evidence_html = if f.evidence.is_empty() {
String::new()
} else {
let mut eh = String::from(r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#);
let mut eh = String::from(
r#"<div class="evidence-block"><div class="evidence-title">Evidence</div><table class="evidence-table"><thead><tr><th>Request</th><th>Status</th><th>Details</th></tr></thead><tbody>"#,
);
for ev in &f.evidence {
let payload_info = ev
.payload
@@ -402,7 +282,8 @@ fn build_html_report(ctx: &ReportContext) -> String {
let mut chain_html = String::new();
if !ctx.attack_chain.is_empty() {
// Compute phases via BFS from root nodes
let mut phase_map: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
let mut phase_map: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
let mut queue: std::collections::VecDeque<String> = std::collections::VecDeque::new();
for node in &ctx.attack_chain {
@@ -438,7 +319,13 @@ fn build_html_report(ctx: &ReportContext) -> String {
// Group nodes by phase
let max_phase = phase_map.values().copied().max().unwrap_or(0);
let phase_labels = ["Reconnaissance", "Enumeration", "Exploitation", "Validation", "Post-Exploitation"];
let phase_labels = [
"Reconnaissance",
"Enumeration",
"Exploitation",
"Validation",
"Post-Exploitation",
];
for phase_idx in 0..=max_phase {
let phase_nodes: Vec<&AttackChainNode> = ctx
@@ -485,15 +372,28 @@ fn build_html_report(ctx: &ReportContext) -> String {
format!(
r#"<span class="step-findings">{} finding{}</span>"#,
node.findings_produced.len(),
if node.findings_produced.len() == 1 { "" } else { "s" },
if node.findings_produced.len() == 1 {
""
} else {
"s"
},
)
} else {
String::new()
};
let risk_badge = node.risk_score.map(|r| {
let risk_class = if r >= 70 { "risk-high" } else if r >= 40 { "risk-med" } else { "risk-low" };
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
}).unwrap_or_default();
let risk_badge = node
.risk_score
.map(|r| {
let risk_class = if r >= 70 {
"risk-high"
} else if r >= 40 {
"risk-med"
} else {
"risk-low"
};
format!(r#"<span class="step-risk {risk_class}">Risk: {r}</span>"#)
})
.unwrap_or_default();
let reasoning_html = if node.llm_reasoning.is_empty() {
String::new()
@@ -548,9 +448,19 @@ fn build_html_report(ctx: &ReportContext) -> String {
let mut sub = String::new();
let mut fnum = 0usize;
for (si, &sev_key) in severity_order.iter().enumerate() {
let count = ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key).count();
if count == 0 { continue; }
for f in ctx.findings.iter().filter(|f| f.severity.to_string() == sev_key) {
let count = ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == sev_key)
.count();
if count == 0 {
continue;
}
for f in ctx
.findings
.iter()
.filter(|f| f.severity.to_string() == sev_key)
{
fnum += 1;
sub.push_str(&format!(
r#"<div class="toc-sub">F-{:03} — {}</div>"#,
@@ -1577,19 +1487,49 @@ table.tools-table td:first-child {{
fn tool_category(tool_name: &str) -> &'static str {
let name = tool_name.to_lowercase();
if name.contains("nmap") || name.contains("port") { return "Network Reconnaissance"; }
if name.contains("nikto") || name.contains("header") { return "Web Server Analysis"; }
if name.contains("zap") || name.contains("spider") || name.contains("crawl") { return "Web Application Scanning"; }
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") { return "SQL Injection Testing"; }
if name.contains("xss") || name.contains("cross-site") { return "Cross-Site Scripting Testing"; }
if name.contains("dir") || name.contains("brute") || name.contains("fuzz") || name.contains("gobuster") { return "Directory Enumeration"; }
if name.contains("ssl") || name.contains("tls") || name.contains("cert") { return "SSL/TLS Analysis"; }
if name.contains("api") || name.contains("endpoint") { return "API Security Testing"; }
if name.contains("auth") || name.contains("login") || name.contains("credential") { return "Authentication Testing"; }
if name.contains("cors") { return "CORS Testing"; }
if name.contains("csrf") { return "CSRF Testing"; }
if name.contains("nuclei") || name.contains("template") { return "Vulnerability Scanning"; }
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") { return "Technology Fingerprinting"; }
if name.contains("nmap") || name.contains("port") {
return "Network Reconnaissance";
}
if name.contains("nikto") || name.contains("header") {
return "Web Server Analysis";
}
if name.contains("zap") || name.contains("spider") || name.contains("crawl") {
return "Web Application Scanning";
}
if name.contains("sqlmap") || name.contains("sqli") || name.contains("sql") {
return "SQL Injection Testing";
}
if name.contains("xss") || name.contains("cross-site") {
return "Cross-Site Scripting Testing";
}
if name.contains("dir")
|| name.contains("brute")
|| name.contains("fuzz")
|| name.contains("gobuster")
{
return "Directory Enumeration";
}
if name.contains("ssl") || name.contains("tls") || name.contains("cert") {
return "SSL/TLS Analysis";
}
if name.contains("api") || name.contains("endpoint") {
return "API Security Testing";
}
if name.contains("auth") || name.contains("login") || name.contains("credential") {
return "Authentication Testing";
}
if name.contains("cors") {
return "CORS Testing";
}
if name.contains("csrf") {
return "CSRF Testing";
}
if name.contains("nuclei") || name.contains("template") {
return "Vulnerability Scanning";
}
if name.contains("whatweb") || name.contains("tech") || name.contains("wappalyzer") {
return "Technology Fingerprinting";
}
"Security Testing"
}
@@ -1599,3 +1539,314 @@ fn html_escape(s: &str) -> String {
.replace('>', "&gt;")
.replace('"', "&quot;")
}
#[cfg(test)]
mod tests {
use super::*;
use compliance_core::models::dast::{DastFinding, DastVulnType};
use compliance_core::models::finding::Severity;
use compliance_core::models::pentest::{
AttackChainNode, AttackNodeStatus, PentestSession, PentestStrategy,
};
// ── html_escape ──────────────────────────────────────────────────
#[test]
fn html_escape_handles_ampersand() {
assert_eq!(html_escape("a & b"), "a &amp; b");
}
#[test]
fn html_escape_handles_angle_brackets() {
assert_eq!(html_escape("<script>"), "&lt;script&gt;");
}
#[test]
fn html_escape_handles_quotes() {
assert_eq!(html_escape(r#"key="val""#), "key=&quot;val&quot;");
}
#[test]
fn html_escape_handles_all_special_chars() {
assert_eq!(
html_escape(r#"<a href="x">&y</a>"#),
"&lt;a href=&quot;x&quot;&gt;&amp;y&lt;/a&gt;"
);
}
#[test]
fn html_escape_no_change_for_plain_text() {
assert_eq!(html_escape("hello world"), "hello world");
}
#[test]
fn html_escape_empty_string() {
assert_eq!(html_escape(""), "");
}
// ── tool_category ────────────────────────────────────────────────
#[test]
fn tool_category_nmap() {
assert_eq!(tool_category("nmap_scan"), "Network Reconnaissance");
}
#[test]
fn tool_category_port_scanner() {
assert_eq!(tool_category("port_scanner"), "Network Reconnaissance");
}
#[test]
fn tool_category_nikto() {
assert_eq!(tool_category("nikto"), "Web Server Analysis");
}
#[test]
fn tool_category_header_check() {
assert_eq!(
tool_category("security_header_check"),
"Web Server Analysis"
);
}
#[test]
fn tool_category_zap_spider() {
assert_eq!(tool_category("zap_spider"), "Web Application Scanning");
}
#[test]
fn tool_category_sqlmap() {
assert_eq!(tool_category("sqlmap"), "SQL Injection Testing");
}
#[test]
fn tool_category_xss_scanner() {
assert_eq!(tool_category("xss_scanner"), "Cross-Site Scripting Testing");
}
#[test]
fn tool_category_dir_bruteforce() {
assert_eq!(tool_category("dir_bruteforce"), "Directory Enumeration");
}
#[test]
fn tool_category_gobuster() {
assert_eq!(tool_category("gobuster"), "Directory Enumeration");
}
#[test]
fn tool_category_ssl_check() {
assert_eq!(tool_category("ssl_check"), "SSL/TLS Analysis");
}
#[test]
fn tool_category_tls_scan() {
assert_eq!(tool_category("tls_scan"), "SSL/TLS Analysis");
}
#[test]
fn tool_category_api_test() {
assert_eq!(tool_category("api_endpoint_test"), "API Security Testing");
}
#[test]
fn tool_category_auth_bypass() {
assert_eq!(tool_category("auth_bypass_check"), "Authentication Testing");
}
#[test]
fn tool_category_cors() {
assert_eq!(tool_category("cors_check"), "CORS Testing");
}
#[test]
fn tool_category_csrf() {
assert_eq!(tool_category("csrf_scanner"), "CSRF Testing");
}
#[test]
fn tool_category_nuclei() {
assert_eq!(tool_category("nuclei"), "Vulnerability Scanning");
}
#[test]
fn tool_category_whatweb() {
assert_eq!(tool_category("whatweb"), "Technology Fingerprinting");
}
#[test]
fn tool_category_unknown_defaults_to_security_testing() {
assert_eq!(tool_category("custom_tool"), "Security Testing");
}
#[test]
fn tool_category_is_case_insensitive() {
assert_eq!(tool_category("NMAP_Scanner"), "Network Reconnaissance");
assert_eq!(tool_category("SQLMap"), "SQL Injection Testing");
}
// ── build_html_report ────────────────────────────────────────────
fn make_session(strategy: PentestStrategy) -> PentestSession {
let mut s = PentestSession::new("target-1".into(), strategy);
s.tool_invocations = 5;
s.tool_successes = 4;
s.findings_count = 2;
s.exploitable_count = 1;
s
}
fn make_finding(severity: Severity, title: &str, exploitable: bool) -> DastFinding {
let mut f = DastFinding::new(
"run-1".into(),
"target-1".into(),
DastVulnType::Xss,
title.into(),
"description".into(),
severity,
"https://example.com/test".into(),
"GET".into(),
);
f.exploitable = exploitable;
f
}
fn make_attack_node(tool_name: &str) -> AttackChainNode {
let mut node = AttackChainNode::new(
"session-1".into(),
"node-1".into(),
tool_name.into(),
serde_json::json!({}),
"Testing this tool".into(),
);
node.status = AttackNodeStatus::Completed;
node
}
fn make_report_context(
findings: Vec<DastFinding>,
chain: Vec<AttackChainNode>,
) -> ReportContext {
ReportContext {
session: make_session(PentestStrategy::Comprehensive),
target_name: "Test App".into(),
target_url: "https://example.com".into(),
findings,
attack_chain: chain,
requester_name: "Alice".into(),
requester_email: "alice@example.com".into(),
}
}
#[test]
fn report_contains_target_info() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Test App"));
assert!(html.contains("https://example.com"));
}
#[test]
fn report_contains_requester_info() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Alice"));
assert!(html.contains("alice@example.com"));
}
#[test]
fn report_shows_informational_risk_when_no_findings() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("INFORMATIONAL"));
}
#[test]
fn report_shows_critical_risk_with_critical_finding() {
let findings = vec![make_finding(Severity::Critical, "Critical XSS", true)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("CRITICAL"));
}
#[test]
fn report_shows_high_risk_without_critical() {
let findings = vec![make_finding(Severity::High, "High SQLi", false)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// Should show HIGH, not CRITICAL
assert!(html.contains("HIGH"));
}
#[test]
fn report_shows_medium_risk_level() {
let findings = vec![make_finding(Severity::Medium, "Medium Issue", false)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("MEDIUM"));
}
#[test]
fn report_includes_finding_title() {
let findings = vec![make_finding(
Severity::High,
"Reflected XSS in /search",
true,
)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("Reflected XSS in /search"));
}
#[test]
fn report_shows_exploitable_badge() {
let findings = vec![make_finding(Severity::Critical, "SQLi", true)];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// The report should mark exploitable findings
assert!(html.contains("EXPLOITABLE"));
}
#[test]
fn report_includes_attack_chain_tool_names() {
let chain = vec![make_attack_node("nmap_scan"), make_attack_node("sqlmap")];
let ctx = make_report_context(vec![], chain);
let html = build_html_report(&ctx);
assert!(html.contains("nmap_scan"));
assert!(html.contains("sqlmap"));
}
#[test]
fn report_is_valid_html_structure() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
assert!(html.contains("<!DOCTYPE html>") || html.contains("<html"));
assert!(html.contains("</html>"));
}
#[test]
fn report_strategy_appears() {
let ctx = make_report_context(vec![], vec![]);
let html = build_html_report(&ctx);
// PentestStrategy::Comprehensive => "comprehensive"
assert!(html.contains("comprehensive") || html.contains("Comprehensive"));
}
#[test]
fn report_finding_count_is_correct() {
let findings = vec![
make_finding(Severity::Critical, "F1", true),
make_finding(Severity::High, "F2", false),
make_finding(Severity::Low, "F3", false),
];
let ctx = make_report_context(findings, vec![]);
let html = build_html_report(&ctx);
// The total count "3" should appear somewhere
assert!(
html.contains(">3<")
|| html.contains(">3 ")
|| html.contains("3 findings")
|| html.contains("3 Total")
);
}
}

View File

@@ -0,0 +1,58 @@
mod archive;
mod html;
mod pdf;
use compliance_core::models::dast::DastFinding;
use compliance_core::models::pentest::{AttackChainNode, PentestSession};
use sha2::{Digest, Sha256};
/// Report archive with metadata
pub struct ReportArchive {
/// The password-protected ZIP bytes
pub archive: Vec<u8>,
/// SHA-256 hex digest of the archive
pub sha256: String,
}
/// Report context gathered from the database
pub struct ReportContext {
pub session: PentestSession,
pub target_name: String,
pub target_url: String,
pub findings: Vec<DastFinding>,
pub attack_chain: Vec<AttackChainNode>,
pub requester_name: String,
pub requester_email: String,
}
/// Generate a password-protected ZIP archive containing the pentest report.
///
/// The archive contains:
/// - `report.pdf` — Professional pentest report (PDF)
/// - `report.html` — HTML source (fallback)
/// - `findings.json` — Raw findings data
/// - `attack-chain.json` — Attack chain timeline
///
/// Files are encrypted with AES-256 inside the ZIP (standard WinZip AES format,
/// supported by 7-Zip, WinRAR, macOS Archive Utility, etc.).
pub async fn generate_encrypted_report(
ctx: &ReportContext,
password: &str,
) -> Result<ReportArchive, String> {
let html = html::build_html_report(ctx);
// Convert HTML to PDF via headless Chrome
let pdf_bytes = pdf::html_to_pdf(&html).await?;
let zip_bytes = archive::build_zip(ctx, password, &html, &pdf_bytes)
.map_err(|e| format!("Failed to create archive: {e}"))?;
let mut hasher = Sha256::new();
hasher.update(&zip_bytes);
let sha256 = hex::encode(hasher.finalize());
Ok(ReportArchive {
archive: zip_bytes,
sha256,
})
}

View File

@@ -0,0 +1,79 @@
/// Convert HTML string to PDF bytes using headless Chrome/Chromium.
pub(super) async fn html_to_pdf(html: &str) -> Result<Vec<u8>, String> {
let tmp_dir = std::env::temp_dir();
let run_id = uuid::Uuid::new_v4().to_string();
let html_path = tmp_dir.join(format!("pentest-report-{run_id}.html"));
let pdf_path = tmp_dir.join(format!("pentest-report-{run_id}.pdf"));
// Write HTML to temp file
std::fs::write(&html_path, html).map_err(|e| format!("Failed to write temp HTML: {e}"))?;
// Find Chrome/Chromium binary
let chrome_bin = find_chrome_binary().ok_or_else(|| {
"Chrome/Chromium not found. Install google-chrome or chromium to generate PDF reports."
.to_string()
})?;
tracing::info!(chrome = %chrome_bin, "Generating PDF report via headless Chrome");
let html_url = format!("file://{}", html_path.display());
let output = tokio::process::Command::new(&chrome_bin)
.args([
"--headless",
"--disable-gpu",
"--no-sandbox",
"--disable-software-rasterizer",
"--run-all-compositor-stages-before-draw",
"--disable-dev-shm-usage",
&format!("--print-to-pdf={}", pdf_path.display()),
"--no-pdf-header-footer",
&html_url,
])
.output()
.await
.map_err(|e| format!("Failed to run Chrome: {e}"))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
// Clean up temp files
let _ = std::fs::remove_file(&html_path);
let _ = std::fs::remove_file(&pdf_path);
return Err(format!("Chrome PDF generation failed: {stderr}"));
}
let pdf_bytes =
std::fs::read(&pdf_path).map_err(|e| format!("Failed to read generated PDF: {e}"))?;
// Clean up temp files
let _ = std::fs::remove_file(&html_path);
let _ = std::fs::remove_file(&pdf_path);
if pdf_bytes.is_empty() {
return Err("Chrome produced an empty PDF".to_string());
}
tracing::info!(size_kb = pdf_bytes.len() / 1024, "PDF report generated");
Ok(pdf_bytes)
}
/// Search for Chrome/Chromium binary on the system.
fn find_chrome_binary() -> Option<String> {
let candidates = [
"google-chrome-stable",
"google-chrome",
"chromium-browser",
"chromium",
];
for name in &candidates {
if let Ok(output) = std::process::Command::new("which").arg(name).output() {
if output.status.success() {
let path = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !path.is_empty() {
return Some(path);
}
}
}
}
None
}

View File

@@ -8,3 +8,51 @@ pub fn compute_fingerprint(parts: &[&str]) -> String {
}
hex::encode(hasher.finalize())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn fingerprint_is_deterministic() {
let a = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
let b = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
assert_eq!(a, b);
}
#[test]
fn fingerprint_changes_with_different_input() {
let a = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "42"]);
let b = compute_fingerprint(&["repo1", "rule-x", "src/main.rs", "43"]);
assert_ne!(a, b);
}
#[test]
fn fingerprint_is_valid_hex_sha256() {
let fp = compute_fingerprint(&["hello"]);
assert_eq!(fp.len(), 64, "SHA-256 hex should be 64 chars");
assert!(fp.chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn fingerprint_empty_parts() {
let fp = compute_fingerprint(&[]);
// Should still produce a valid hash (of empty input)
assert_eq!(fp.len(), 64);
}
#[test]
fn fingerprint_order_matters() {
let a = compute_fingerprint(&["a", "b"]);
let b = compute_fingerprint(&["b", "a"]);
assert_ne!(a, b);
}
#[test]
fn fingerprint_separator_prevents_collision() {
// "ab" + "c" vs "a" + "bc" should differ because of the "|" separator
let a = compute_fingerprint(&["ab", "c"]);
let b = compute_fingerprint(&["a", "bc"]);
assert_ne!(a, b);
}
}

View File

@@ -129,3 +129,110 @@ struct GitleaksResult {
#[serde(rename = "Match")]
r#match: String,
}
#[cfg(test)]
mod tests {
use super::*;
// --- is_allowlisted tests ---
#[test]
fn allowlisted_env_example_files() {
assert!(is_allowlisted(".env.example"));
assert!(is_allowlisted("config/.env.sample"));
assert!(is_allowlisted("deploy/.ENV.TEMPLATE"));
}
#[test]
fn allowlisted_test_directories() {
assert!(is_allowlisted("src/test/config.json"));
assert!(is_allowlisted("src/tests/fixtures.rs"));
assert!(is_allowlisted("data/fixtures/secret.txt"));
assert!(is_allowlisted("pkg/testdata/key.pem"));
}
#[test]
fn allowlisted_mock_files() {
assert!(is_allowlisted("src/mock_service.py"));
assert!(is_allowlisted("lib/MockAuth.java"));
}
#[test]
fn allowlisted_test_suffixes() {
assert!(is_allowlisted("auth_test.go"));
assert!(is_allowlisted("auth.test.ts"));
assert!(is_allowlisted("auth.test.js"));
assert!(is_allowlisted("auth.spec.ts"));
assert!(is_allowlisted("auth.spec.js"));
}
#[test]
fn not_allowlisted_regular_files() {
assert!(!is_allowlisted("src/main.rs"));
assert!(!is_allowlisted("config/.env"));
assert!(!is_allowlisted("lib/auth.ts"));
assert!(!is_allowlisted("deploy/secrets.yaml"));
}
#[test]
fn not_allowlisted_partial_matches() {
// "test" as substring in a non-directory context should not match
assert!(!is_allowlisted("src/attestation.rs"));
assert!(!is_allowlisted("src/contest/data.json"));
}
// --- GitleaksResult deserialization tests ---
#[test]
fn deserialize_gitleaks_result() {
let json = r#"{
"Description": "AWS Access Key",
"RuleID": "aws-access-key",
"File": "src/config.rs",
"StartLine": 10,
"Match": "AKIAIOSFODNN7EXAMPLE"
}"#;
let result: GitleaksResult = serde_json::from_str(json).unwrap();
assert_eq!(result.description, "AWS Access Key");
assert_eq!(result.rule_id, "aws-access-key");
assert_eq!(result.file, "src/config.rs");
assert_eq!(result.start_line, 10);
assert_eq!(result.r#match, "AKIAIOSFODNN7EXAMPLE");
}
#[test]
fn deserialize_gitleaks_result_array() {
let json = r#"[
{
"Description": "Generic Secret",
"RuleID": "generic-secret",
"File": "app.py",
"StartLine": 5,
"Match": "password=hunter2"
}
]"#;
let results: Vec<GitleaksResult> = serde_json::from_str(json).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].rule_id, "generic-secret");
}
#[test]
fn severity_mapping_private_key() {
// Verify the severity logic from the scan method
let rule_id = "some-private-key-rule";
assert!(rule_id.contains("private-key"));
}
#[test]
fn severity_mapping_token_password_secret() {
for keyword in &["token", "password", "secret"] {
let rule_id = format!("some-{}-rule", keyword);
assert!(
rule_id.contains("token")
|| rule_id.contains("password")
|| rule_id.contains("secret"),
"Expected '{rule_id}' to match token/password/secret"
);
}
}
}

View File

@@ -0,0 +1,106 @@
use compliance_core::models::Finding;
use super::orchestrator::{GraphContext, PipelineOrchestrator};
use crate::error::AgentError;
impl PipelineOrchestrator {
/// Build the code knowledge graph for a repo and compute impact analyses
pub(super) async fn build_code_graph(
&self,
repo_path: &std::path::Path,
repo_id: &str,
findings: &[Finding],
) -> Result<GraphContext, AgentError> {
let graph_build_id = uuid::Uuid::new_v4().to_string();
let engine = compliance_graph::GraphEngine::new(50_000);
let (mut code_graph, build_run) =
engine
.build_graph(repo_path, repo_id, &graph_build_id)
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
// Apply community detection
compliance_graph::graph::community::apply_communities(&mut code_graph);
// Store graph in MongoDB
let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
store
.delete_repo_graph(repo_id)
.await
.map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
store
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
.await
.map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
// Compute impact analysis for each finding
let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
let mut impacts = Vec::new();
for finding in findings {
if let Some(file_path) = &finding.file_path {
let impact = analyzer.analyze(
repo_id,
&finding.fingerprint,
&graph_build_id,
file_path,
finding.line_number,
);
store
.store_impact(&impact)
.await
.map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
impacts.push(impact);
}
}
Ok(GraphContext {
node_count: build_run.node_count,
edge_count: build_run.edge_count,
community_count: build_run.community_count,
impacts,
})
}
/// Trigger DAST scan if a target is configured for this repo
pub(super) async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
use futures_util::TryStreamExt;
let filter = mongodb::bson::doc! { "repo_id": repo_id };
let targets: Vec<compliance_core::models::DastTarget> =
match self.db.dast_targets().find(filter).await {
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
Err(_) => return,
};
if targets.is_empty() {
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
return;
}
for target in targets {
let db = self.db.clone();
let scan_run_id = scan_run_id.to_string();
tokio::spawn(async move {
let orchestrator = compliance_dast::DastOrchestrator::new(100);
match orchestrator.run_scan(&target, Vec::new()).await {
Ok((mut scan_run, findings)) => {
scan_run.sast_scan_run_id = Some(scan_run_id);
if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
tracing::error!("Failed to store DAST scan run: {e}");
}
for finding in &findings {
if let Err(e) = db.dast_findings().insert_one(finding).await {
tracing::error!("Failed to store DAST finding: {e}");
}
}
tracing::info!("DAST scan complete: {} findings", findings.len());
}
Err(e) => {
tracing::error!("DAST scan failed: {e}");
}
}
});
}
}
}

View File

@@ -0,0 +1,259 @@
use mongodb::bson::doc;
use compliance_core::models::*;
use super::orchestrator::{extract_base_url, PipelineOrchestrator};
use super::tracker_dispatch::TrackerDispatch;
use crate::error::AgentError;
use crate::trackers;
impl PipelineOrchestrator {
/// Build an issue tracker client from a repository's tracker configuration.
/// Returns `None` if the repo has no tracker configured.
pub(super) fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
let tracker_type = repo.tracker_type.as_ref()?;
// Per-repo token takes precedence, fall back to global config
match tracker_type {
TrackerType::GitHub => {
let token = repo.tracker_token.clone().or_else(|| {
self.config.github_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
match trackers::github::GitHubTracker::new(&secret) {
Ok(t) => Some(TrackerDispatch::GitHub(t)),
Err(e) => {
tracing::warn!("Failed to build GitHub tracker: {e}");
None
}
}
}
TrackerType::GitLab => {
let base_url = self
.config
.gitlab_url
.clone()
.unwrap_or_else(|| "https://gitlab.com".to_string());
let token = repo.tracker_token.clone().or_else(|| {
self.config.gitlab_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::GitLab(
trackers::gitlab::GitLabTracker::new(base_url, secret),
))
}
TrackerType::Gitea => {
let token = repo.tracker_token.clone()?;
let base_url = extract_base_url(&repo.git_url)?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
base_url, secret,
)))
}
TrackerType::Jira => {
let base_url = self.config.jira_url.clone()?;
let email = self.config.jira_email.clone()?;
let project_key = self.config.jira_project_key.clone()?;
let token = repo.tracker_token.clone().or_else(|| {
self.config.jira_api_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
base_url,
email,
secret,
project_key,
)))
}
}
}
/// Create tracker issues for new findings (severity >= Medium).
/// Checks for duplicates via fingerprint search before creating.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub(super) async fn create_tracker_issues(
&self,
repo: &TrackedRepository,
repo_id: &str,
new_findings: &[Finding],
) -> Result<(), AgentError> {
let tracker = match self.build_tracker(repo) {
Some(t) => t,
None => {
tracing::info!("[{repo_id}] No issue tracker configured, skipping");
return Ok(());
}
};
let owner = match repo.tracker_owner.as_deref() {
Some(o) => o,
None => {
tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
return Ok(());
}
};
let tracker_repo_name = match repo.tracker_repo.as_deref() {
Some(r) => r,
None => {
tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
return Ok(());
}
};
// Only create issues for medium+ severity findings
let actionable: Vec<&Finding> = new_findings
.iter()
.filter(|f| {
matches!(
f.severity,
Severity::Medium | Severity::High | Severity::Critical
)
})
.collect();
if actionable.is_empty() {
tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
return Ok(());
}
tracing::info!(
"[{repo_id}] Creating issues for {} findings via {}",
actionable.len(),
tracker.name()
);
let mut created = 0u32;
for finding in actionable {
let title = format!(
"[{}] {}: {}",
finding.severity, finding.scanner, finding.title
);
// Check if an issue already exists by fingerprint first, then by title
let mut found_existing = false;
for search_term in [&finding.fingerprint, &title] {
match tracker
.find_existing_issue(owner, tracker_repo_name, search_term)
.await
{
Ok(Some(existing)) => {
tracing::debug!(
"[{repo_id}] Issue already exists for '{}': {}",
search_term,
existing.external_url
);
found_existing = true;
break;
}
Ok(None) => {}
Err(e) => {
tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
}
}
}
if found_existing {
continue;
}
let body = format_issue_body(finding);
let labels = vec![
format!("severity:{}", finding.severity),
format!("scanner:{}", finding.scanner),
"compliance-scanner".to_string(),
];
match tracker
.create_issue(owner, tracker_repo_name, &title, &body, &labels)
.await
{
Ok(mut issue) => {
issue.finding_id = finding
.id
.as_ref()
.map(|id| id.to_hex())
.unwrap_or_default();
// Update the finding with the issue URL
if let Some(finding_id) = &finding.id {
let _ = self
.db
.findings()
.update_one(
doc! { "_id": finding_id },
doc! { "$set": { "tracker_issue_url": &issue.external_url } },
)
.await;
}
// Store the tracker issue record
if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
}
created += 1;
}
Err(e) => {
tracing::warn!(
"[{repo_id}] Failed to create issue for {}: {e}",
finding.fingerprint
);
}
}
}
tracing::info!("[{repo_id}] Created {created} tracker issues");
Ok(())
}
}
/// Format a finding into a markdown issue body for the tracker.
pub(super) fn format_issue_body(finding: &Finding) -> String {
let mut body = String::new();
body.push_str(&format!("## {} Finding\n\n", finding.severity));
body.push_str(&format!("**Scanner:** {}\n", finding.scanner));
body.push_str(&format!("**Severity:** {}\n", finding.severity));
if let Some(rule) = &finding.rule_id {
body.push_str(&format!("**Rule:** {}\n", rule));
}
if let Some(cwe) = &finding.cwe {
body.push_str(&format!("**CWE:** {}\n", cwe));
}
body.push_str(&format!("\n### Description\n\n{}\n", finding.description));
if let Some(file_path) = &finding.file_path {
body.push_str(&format!("\n### Location\n\n**File:** `{}`", file_path));
if let Some(line) = finding.line_number {
body.push_str(&format!(" (line {})", line));
}
body.push('\n');
}
if let Some(snippet) = &finding.code_snippet {
body.push_str(&format!("\n### Code\n\n```\n{}\n```\n", snippet));
}
if let Some(remediation) = &finding.remediation {
body.push_str(&format!("\n### Remediation\n\n{}\n", remediation));
}
if let Some(fix) = &finding.suggested_fix {
body.push_str(&format!("\n### Suggested Fix\n\n```\n{}\n```\n", fix));
}
body.push_str(&format!(
"\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
finding.fingerprint
));
body
}

View File

@@ -1,366 +0,0 @@
use std::path::Path;
use std::time::Duration;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::traits::{ScanOutput, Scanner};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
/// Timeout for each individual lint command
const LINT_TIMEOUT: Duration = Duration::from_secs(120);
pub struct LintScanner;
impl Scanner for LintScanner {
fn name(&self) -> &str {
"lint"
}
fn scan_type(&self) -> ScanType {
ScanType::Lint
}
#[tracing::instrument(skip_all)]
async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
let mut all_findings = Vec::new();
// Detect which languages are present and run appropriate linters
if has_rust_project(repo_path) {
match run_clippy(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("Clippy failed: {e}"),
}
}
if has_js_project(repo_path) {
match run_eslint(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("ESLint failed: {e}"),
}
}
if has_python_project(repo_path) {
match run_ruff(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("Ruff failed: {e}"),
}
}
Ok(ScanOutput {
findings: all_findings,
sbom_entries: Vec::new(),
})
}
}
fn has_rust_project(repo_path: &Path) -> bool {
repo_path.join("Cargo.toml").exists()
}
fn has_js_project(repo_path: &Path) -> bool {
// Only run if eslint is actually installed in the project
repo_path.join("package.json").exists() && repo_path.join("node_modules/.bin/eslint").exists()
}
fn has_python_project(repo_path: &Path) -> bool {
repo_path.join("pyproject.toml").exists()
|| repo_path.join("setup.py").exists()
|| repo_path.join("requirements.txt").exists()
}
/// Run a command with a timeout, returning its output or an error
async fn run_with_timeout(
child: tokio::process::Child,
scanner_name: &str,
) -> Result<std::process::Output, CoreError> {
let result = tokio::time::timeout(LINT_TIMEOUT, child.wait_with_output()).await;
match result {
Ok(Ok(output)) => Ok(output),
Ok(Err(e)) => Err(CoreError::Scanner {
scanner: scanner_name.to_string(),
source: Box::new(e),
}),
Err(_) => {
// Process is dropped here which sends SIGKILL on Unix
Err(CoreError::Scanner {
scanner: scanner_name.to_string(),
source: Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
)),
})
}
}
}
// ── Clippy ──────────────────────────────────────────────
async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
let child = Command::new("cargo")
.args([
"clippy",
"--message-format=json",
"--quiet",
"--",
"-W",
"clippy::all",
])
.current_dir(repo_path)
.env("RUSTC_WRAPPER", "")
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "clippy".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "clippy").await?;
let stdout = String::from_utf8_lossy(&output.stdout);
let mut findings = Vec::new();
for line in stdout.lines() {
let msg: serde_json::Value = match serde_json::from_str(line) {
Ok(v) => v,
Err(_) => continue,
};
if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
continue;
}
let message = match msg.get("message") {
Some(m) => m,
None => continue,
};
let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
if level != "warning" && level != "error" {
continue;
}
let text = message
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let code = message
.get("code")
.and_then(|v| v.get("code"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if text.starts_with("aborting due to") || code.is_empty() {
continue;
}
let (file_path, line_number) = extract_primary_span(message);
let severity = if level == "error" {
Severity::High
} else {
Severity::Low
};
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"clippy",
&code,
&file_path,
&line_number.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"clippy".to_string(),
ScanType::Lint,
format!("[clippy] {text}"),
text,
severity,
);
finding.rule_id = Some(code);
if !file_path.is_empty() {
finding.file_path = Some(file_path);
}
if line_number > 0 {
finding.line_number = Some(line_number);
}
findings.push(finding);
}
Ok(findings)
}
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
let spans = match message.get("spans").and_then(|v| v.as_array()) {
Some(s) => s,
None => return (String::new(), 0),
};
for span in spans {
if span.get("is_primary").and_then(|v| v.as_bool()) == Some(true) {
let file = span
.get("file_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
return (file, line);
}
}
(String::new(), 0)
}
// ── ESLint ──────────────────────────────────────────────
async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
// Use the project-local eslint binary directly, not npx (which can hang downloading)
let eslint_bin = repo_path.join("node_modules/.bin/eslint");
let child = Command::new(eslint_bin)
.args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
.current_dir(repo_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "eslint".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "eslint").await?;
if output.stdout.is_empty() {
return Ok(Vec::new());
}
let results: Vec<EslintFileResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
let mut findings = Vec::new();
for file_result in results {
for msg in file_result.messages {
let severity = match msg.severity {
2 => Severity::Medium,
_ => Severity::Low,
};
let rule_id = msg.rule_id.unwrap_or_default();
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"eslint",
&rule_id,
&file_result.file_path,
&msg.line.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"eslint".to_string(),
ScanType::Lint,
format!("[eslint] {}", msg.message),
msg.message,
severity,
);
finding.rule_id = Some(rule_id);
finding.file_path = Some(file_result.file_path.clone());
finding.line_number = Some(msg.line);
findings.push(finding);
}
}
Ok(findings)
}
#[derive(serde::Deserialize)]
struct EslintFileResult {
#[serde(rename = "filePath")]
file_path: String,
messages: Vec<EslintMessage>,
}
#[derive(serde::Deserialize)]
struct EslintMessage {
#[serde(rename = "ruleId")]
rule_id: Option<String>,
severity: u8,
message: String,
line: u32,
}
// ── Ruff ────────────────────────────────────────────────
async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
let child = Command::new("ruff")
.args(["check", ".", "--output-format", "json", "--exit-zero"])
.current_dir(repo_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "ruff".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "ruff").await?;
if output.stdout.is_empty() {
return Ok(Vec::new());
}
let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
let findings = results
.into_iter()
.map(|r| {
let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
Severity::Medium
} else {
Severity::Low
};
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"ruff",
&r.code,
&r.filename,
&r.location.row.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"ruff".to_string(),
ScanType::Lint,
format!("[ruff] {}: {}", r.code, r.message),
r.message,
severity,
);
finding.rule_id = Some(r.code);
finding.file_path = Some(r.filename);
finding.line_number = Some(r.location.row);
finding
})
.collect();
Ok(findings)
}
#[derive(serde::Deserialize)]
struct RuffResult {
code: String,
message: String,
filename: String,
location: RuffLocation,
}
#[derive(serde::Deserialize)]
struct RuffLocation {
row: u32,
}

View File

@@ -0,0 +1,251 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
pub(super) async fn run_clippy(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
let child = Command::new("cargo")
.args([
"clippy",
"--message-format=json",
"--quiet",
"--",
"-W",
"clippy::all",
])
.current_dir(repo_path)
.env("RUSTC_WRAPPER", "")
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "clippy".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "clippy").await?;
let stdout = String::from_utf8_lossy(&output.stdout);
let mut findings = Vec::new();
for line in stdout.lines() {
let msg: serde_json::Value = match serde_json::from_str(line) {
Ok(v) => v,
Err(_) => continue,
};
if msg.get("reason").and_then(|v| v.as_str()) != Some("compiler-message") {
continue;
}
let message = match msg.get("message") {
Some(m) => m,
None => continue,
};
let level = message.get("level").and_then(|v| v.as_str()).unwrap_or("");
if level != "warning" && level != "error" {
continue;
}
let text = message
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let code = message
.get("code")
.and_then(|v| v.get("code"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
if text.starts_with("aborting due to") || code.is_empty() {
continue;
}
let (file_path, line_number) = extract_primary_span(message);
let severity = if level == "error" {
Severity::High
} else {
Severity::Low
};
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"clippy",
&code,
&file_path,
&line_number.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"clippy".to_string(),
ScanType::Lint,
format!("[clippy] {text}"),
text,
severity,
);
finding.rule_id = Some(code);
if !file_path.is_empty() {
finding.file_path = Some(file_path);
}
if line_number > 0 {
finding.line_number = Some(line_number);
}
findings.push(finding);
}
Ok(findings)
}
fn extract_primary_span(message: &serde_json::Value) -> (String, u32) {
let spans = match message.get("spans").and_then(|v| v.as_array()) {
Some(s) => s,
None => return (String::new(), 0),
};
for span in spans {
if span.get("is_primary").and_then(|v| v.as_bool()) == Some(true) {
let file = span
.get("file_name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let line = span.get("line_start").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
return (file, line);
}
}
(String::new(), 0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_primary_span_with_primary() {
let msg = serde_json::json!({
"spans": [
{
"file_name": "src/lib.rs",
"line_start": 42,
"is_primary": true
}
]
});
let (file, line) = extract_primary_span(&msg);
assert_eq!(file, "src/lib.rs");
assert_eq!(line, 42);
}
#[test]
fn extract_primary_span_no_primary() {
let msg = serde_json::json!({
"spans": [
{
"file_name": "src/lib.rs",
"line_start": 42,
"is_primary": false
}
]
});
let (file, line) = extract_primary_span(&msg);
assert_eq!(file, "");
assert_eq!(line, 0);
}
#[test]
fn extract_primary_span_multiple_spans() {
let msg = serde_json::json!({
"spans": [
{
"file_name": "src/other.rs",
"line_start": 10,
"is_primary": false
},
{
"file_name": "src/main.rs",
"line_start": 99,
"is_primary": true
}
]
});
let (file, line) = extract_primary_span(&msg);
assert_eq!(file, "src/main.rs");
assert_eq!(line, 99);
}
#[test]
fn extract_primary_span_no_spans() {
let msg = serde_json::json!({});
let (file, line) = extract_primary_span(&msg);
assert_eq!(file, "");
assert_eq!(line, 0);
}
#[test]
fn extract_primary_span_empty_spans() {
let msg = serde_json::json!({ "spans": [] });
let (file, line) = extract_primary_span(&msg);
assert_eq!(file, "");
assert_eq!(line, 0);
}
#[test]
fn parse_clippy_compiler_message_line() {
let line = r#"{"reason":"compiler-message","message":{"level":"warning","message":"unused variable","code":{"code":"unused_variables"},"spans":[{"file_name":"src/main.rs","line_start":5,"is_primary":true}]}}"#;
let msg: serde_json::Value = serde_json::from_str(line).unwrap();
assert_eq!(
msg.get("reason").and_then(|v| v.as_str()),
Some("compiler-message")
);
let message = msg.get("message").unwrap();
assert_eq!(
message.get("level").and_then(|v| v.as_str()),
Some("warning")
);
assert_eq!(
message.get("message").and_then(|v| v.as_str()),
Some("unused variable")
);
assert_eq!(
message
.get("code")
.and_then(|v| v.get("code"))
.and_then(|v| v.as_str()),
Some("unused_variables")
);
let (file, line_num) = extract_primary_span(message);
assert_eq!(file, "src/main.rs");
assert_eq!(line_num, 5);
}
#[test]
fn skip_non_compiler_message() {
let line = r#"{"reason":"build-script-executed","package_id":"foo 0.1.0"}"#;
let msg: serde_json::Value = serde_json::from_str(line).unwrap();
assert_ne!(
msg.get("reason").and_then(|v| v.as_str()),
Some("compiler-message")
);
}
#[test]
fn skip_aborting_message() {
let text = "aborting due to 3 previous errors";
assert!(text.starts_with("aborting due to"));
}
}

View File

@@ -0,0 +1,183 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
pub(super) async fn run_eslint(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
// Use the project-local eslint binary directly, not npx (which can hang downloading)
let eslint_bin = repo_path.join("node_modules/.bin/eslint");
let child = Command::new(eslint_bin)
.args([".", "--format", "json", "--no-error-on-unmatched-pattern"])
.current_dir(repo_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "eslint".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "eslint").await?;
if output.stdout.is_empty() {
return Ok(Vec::new());
}
let results: Vec<EslintFileResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
let mut findings = Vec::new();
for file_result in results {
for msg in file_result.messages {
let severity = match msg.severity {
2 => Severity::Medium,
_ => Severity::Low,
};
let rule_id = msg.rule_id.unwrap_or_default();
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"eslint",
&rule_id,
&file_result.file_path,
&msg.line.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"eslint".to_string(),
ScanType::Lint,
format!("[eslint] {}", msg.message),
msg.message,
severity,
);
finding.rule_id = Some(rule_id);
finding.file_path = Some(file_result.file_path.clone());
finding.line_number = Some(msg.line);
findings.push(finding);
}
}
Ok(findings)
}
#[derive(serde::Deserialize)]
struct EslintFileResult {
#[serde(rename = "filePath")]
file_path: String,
messages: Vec<EslintMessage>,
}
#[derive(serde::Deserialize)]
struct EslintMessage {
#[serde(rename = "ruleId")]
rule_id: Option<String>,
severity: u8,
message: String,
line: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_eslint_output() {
let json = r#"[
{
"filePath": "/home/user/project/src/app.js",
"messages": [
{
"ruleId": "no-unused-vars",
"severity": 2,
"message": "'x' is defined but never used.",
"line": 10
},
{
"ruleId": "semi",
"severity": 1,
"message": "Missing semicolon.",
"line": 15
}
]
}
]"#;
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].file_path, "/home/user/project/src/app.js");
assert_eq!(results[0].messages.len(), 2);
assert_eq!(
results[0].messages[0].rule_id,
Some("no-unused-vars".to_string())
);
assert_eq!(results[0].messages[0].severity, 2);
assert_eq!(results[0].messages[0].line, 10);
assert_eq!(results[0].messages[1].severity, 1);
}
#[test]
fn deserialize_eslint_null_rule_id() {
let json = r#"[
{
"filePath": "src/index.js",
"messages": [
{
"ruleId": null,
"severity": 2,
"message": "Parsing error: Unexpected token",
"line": 1
}
]
}
]"#;
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
assert_eq!(results[0].messages[0].rule_id, None);
}
#[test]
fn deserialize_eslint_empty_messages() {
let json = r#"[{"filePath": "src/clean.js", "messages": []}]"#;
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
assert_eq!(results[0].messages.len(), 0);
}
#[test]
fn deserialize_eslint_empty_array() {
let json = "[]";
let results: Vec<EslintFileResult> = serde_json::from_str(json).unwrap();
assert!(results.is_empty());
}
#[test]
fn eslint_severity_mapping() {
// severity 2 = error -> Medium, anything else -> Low
assert_eq!(
match 2u8 {
2 => "Medium",
_ => "Low",
},
"Medium"
);
assert_eq!(
match 1u8 {
2 => "Medium",
_ => "Low",
},
"Low"
);
assert_eq!(
match 0u8 {
2 => "Medium",
_ => "Low",
},
"Low"
);
}
}

View File

@@ -0,0 +1,97 @@
mod clippy;
mod eslint;
mod ruff;
use std::path::Path;
use std::time::Duration;
use compliance_core::models::ScanType;
use compliance_core::traits::{ScanOutput, Scanner};
use compliance_core::CoreError;
/// Timeout for each individual lint command
pub(crate) const LINT_TIMEOUT: Duration = Duration::from_secs(120);
pub struct LintScanner;
impl Scanner for LintScanner {
fn name(&self) -> &str {
"lint"
}
fn scan_type(&self) -> ScanType {
ScanType::Lint
}
#[tracing::instrument(skip_all)]
async fn scan(&self, repo_path: &Path, repo_id: &str) -> Result<ScanOutput, CoreError> {
let mut all_findings = Vec::new();
// Detect which languages are present and run appropriate linters
if has_rust_project(repo_path) {
match clippy::run_clippy(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("Clippy failed: {e}"),
}
}
if has_js_project(repo_path) {
match eslint::run_eslint(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("ESLint failed: {e}"),
}
}
if has_python_project(repo_path) {
match ruff::run_ruff(repo_path, repo_id).await {
Ok(findings) => all_findings.extend(findings),
Err(e) => tracing::warn!("Ruff failed: {e}"),
}
}
Ok(ScanOutput {
findings: all_findings,
sbom_entries: Vec::new(),
})
}
}
fn has_rust_project(repo_path: &Path) -> bool {
repo_path.join("Cargo.toml").exists()
}
fn has_js_project(repo_path: &Path) -> bool {
// Only run if eslint is actually installed in the project
repo_path.join("package.json").exists() && repo_path.join("node_modules/.bin/eslint").exists()
}
fn has_python_project(repo_path: &Path) -> bool {
repo_path.join("pyproject.toml").exists()
|| repo_path.join("setup.py").exists()
|| repo_path.join("requirements.txt").exists()
}
/// Run a command with a timeout, returning its output or an error
pub(crate) async fn run_with_timeout(
child: tokio::process::Child,
scanner_name: &str,
) -> Result<std::process::Output, CoreError> {
let result = tokio::time::timeout(LINT_TIMEOUT, child.wait_with_output()).await;
match result {
Ok(Ok(output)) => Ok(output),
Ok(Err(e)) => Err(CoreError::Scanner {
scanner: scanner_name.to_string(),
source: Box::new(e),
}),
Err(_) => {
// Process is dropped here which sends SIGKILL on Unix
Err(CoreError::Scanner {
scanner: scanner_name.to_string(),
source: Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("{scanner_name} timed out after {}s", LINT_TIMEOUT.as_secs()),
)),
})
}
}
}

View File

@@ -0,0 +1,150 @@
use std::path::Path;
use compliance_core::models::{Finding, ScanType, Severity};
use compliance_core::CoreError;
use tokio::process::Command;
use crate::pipeline::dedup;
use super::run_with_timeout;
pub(super) async fn run_ruff(repo_path: &Path, repo_id: &str) -> Result<Vec<Finding>, CoreError> {
let child = Command::new("ruff")
.args(["check", ".", "--output-format", "json", "--exit-zero"])
.current_dir(repo_path)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.spawn()
.map_err(|e| CoreError::Scanner {
scanner: "ruff".to_string(),
source: Box::new(e),
})?;
let output = run_with_timeout(child, "ruff").await?;
if output.stdout.is_empty() {
return Ok(Vec::new());
}
let results: Vec<RuffResult> = serde_json::from_slice(&output.stdout).unwrap_or_default();
let findings = results
.into_iter()
.map(|r| {
let severity = if r.code.starts_with('E') || r.code.starts_with('F') {
Severity::Medium
} else {
Severity::Low
};
let fingerprint = dedup::compute_fingerprint(&[
repo_id,
"ruff",
&r.code,
&r.filename,
&r.location.row.to_string(),
]);
let mut finding = Finding::new(
repo_id.to_string(),
fingerprint,
"ruff".to_string(),
ScanType::Lint,
format!("[ruff] {}: {}", r.code, r.message),
r.message,
severity,
);
finding.rule_id = Some(r.code);
finding.file_path = Some(r.filename);
finding.line_number = Some(r.location.row);
finding
})
.collect();
Ok(findings)
}
#[derive(serde::Deserialize)]
struct RuffResult {
code: String,
message: String,
filename: String,
location: RuffLocation,
}
#[derive(serde::Deserialize)]
struct RuffLocation {
row: u32,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_ruff_output() {
let json = r#"[
{
"code": "E501",
"message": "Line too long (120 > 79 characters)",
"filename": "src/main.py",
"location": {"row": 42}
},
{
"code": "F401",
"message": "`os` imported but unused",
"filename": "src/utils.py",
"location": {"row": 1}
}
]"#;
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].code, "E501");
assert_eq!(results[0].filename, "src/main.py");
assert_eq!(results[0].location.row, 42);
assert_eq!(results[1].code, "F401");
assert_eq!(results[1].location.row, 1);
}
#[test]
fn deserialize_ruff_empty() {
let json = "[]";
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
assert!(results.is_empty());
}
#[test]
fn ruff_severity_e_and_f_are_medium() {
for code in &["E501", "E302", "F401", "F811"] {
let is_medium = code.starts_with('E') || code.starts_with('F');
assert!(is_medium, "Expected {code} to be Medium severity");
}
}
#[test]
fn ruff_severity_others_are_low() {
for code in &["W291", "I001", "D100", "C901", "N801"] {
let is_medium = code.starts_with('E') || code.starts_with('F');
assert!(!is_medium, "Expected {code} to be Low severity");
}
}
#[test]
fn deserialize_ruff_with_extra_fields() {
// Ruff output may contain additional fields we don't use
let json = r#"[{
"code": "W291",
"message": "Trailing whitespace",
"filename": "app.py",
"location": {"row": 3, "column": 10},
"end_location": {"row": 3, "column": 11},
"fix": null,
"noqa_row": 3
}]"#;
let results: Vec<RuffResult> = serde_json::from_str(json).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].code, "W291");
}
}

View File

@@ -3,8 +3,12 @@ pub mod cve;
pub mod dedup;
pub mod git;
pub mod gitleaks;
mod graph_build;
mod issue_creation;
pub mod lint;
pub mod orchestrator;
pub mod patterns;
mod pr_review;
pub mod sbom;
pub mod semgrep;
mod tracker_dispatch;

View File

@@ -4,7 +4,6 @@ use mongodb::bson::doc;
use tracing::Instrument;
use compliance_core::models::*;
use compliance_core::traits::issue_tracker::IssueTracker;
use compliance_core::traits::Scanner;
use compliance_core::AgentConfig;
@@ -19,84 +18,6 @@ use crate::pipeline::lint::LintScanner;
use crate::pipeline::patterns::{GdprPatternScanner, OAuthPatternScanner};
use crate::pipeline::sbom::SbomScanner;
use crate::pipeline::semgrep::SemgrepScanner;
use crate::trackers;
/// Enum dispatch for issue trackers (async traits aren't dyn-compatible).
enum TrackerDispatch {
GitHub(trackers::github::GitHubTracker),
GitLab(trackers::gitlab::GitLabTracker),
Gitea(trackers::gitea::GiteaTracker),
Jira(trackers::jira::JiraTracker),
}
impl TrackerDispatch {
fn name(&self) -> &str {
match self {
Self::GitHub(t) => t.name(),
Self::GitLab(t) => t.name(),
Self::Gitea(t) => t.name(),
Self::Jira(t) => t.name(),
}
}
async fn create_issue(
&self,
owner: &str,
repo: &str,
title: &str,
body: &str,
labels: &[String],
) -> Result<TrackerIssue, compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
}
}
async fn find_existing_issue(
&self,
owner: &str,
repo: &str,
fingerprint: &str,
) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
}
}
async fn create_pr_review(
&self,
owner: &str,
repo: &str,
pr_number: u64,
body: &str,
comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
) -> Result<(), compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::GitLab(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::Gitea(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::Jira(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
}
}
}
/// Context from graph analysis passed to LLM triage for enhanced filtering
#[derive(Debug)]
@@ -109,10 +30,10 @@ pub struct GraphContext {
}
pub struct PipelineOrchestrator {
config: AgentConfig,
db: Database,
llm: Arc<LlmClient>,
http: reqwest::Client,
pub(super) config: AgentConfig,
pub(super) db: Database,
pub(super) llm: Arc<LlmClient>,
pub(super) http: reqwest::Client,
}
impl PipelineOrchestrator {
@@ -460,446 +381,7 @@ impl PipelineOrchestrator {
Ok(new_count)
}
/// Build the code knowledge graph for a repo and compute impact analyses
async fn build_code_graph(
&self,
repo_path: &std::path::Path,
repo_id: &str,
findings: &[Finding],
) -> Result<GraphContext, AgentError> {
let graph_build_id = uuid::Uuid::new_v4().to_string();
let engine = compliance_graph::GraphEngine::new(50_000);
let (mut code_graph, build_run) =
engine
.build_graph(repo_path, repo_id, &graph_build_id)
.map_err(|e| AgentError::Other(format!("Graph build error: {e}")))?;
// Apply community detection
compliance_graph::graph::community::apply_communities(&mut code_graph);
// Store graph in MongoDB
let store = compliance_graph::graph::persistence::GraphStore::new(self.db.inner());
store
.delete_repo_graph(repo_id)
.await
.map_err(|e| AgentError::Other(format!("Graph cleanup error: {e}")))?;
store
.store_graph(&build_run, &code_graph.nodes, &code_graph.edges)
.await
.map_err(|e| AgentError::Other(format!("Graph store error: {e}")))?;
// Compute impact analysis for each finding
let analyzer = compliance_graph::GraphEngine::impact_analyzer(&code_graph);
let mut impacts = Vec::new();
for finding in findings {
if let Some(file_path) = &finding.file_path {
let impact = analyzer.analyze(
repo_id,
&finding.fingerprint,
&graph_build_id,
file_path,
finding.line_number,
);
store
.store_impact(&impact)
.await
.map_err(|e| AgentError::Other(format!("Impact store error: {e}")))?;
impacts.push(impact);
}
}
Ok(GraphContext {
node_count: build_run.node_count,
edge_count: build_run.edge_count,
community_count: build_run.community_count,
impacts,
})
}
/// Trigger DAST scan if a target is configured for this repo
async fn maybe_trigger_dast(&self, repo_id: &str, scan_run_id: &str) {
use futures_util::TryStreamExt;
let filter = mongodb::bson::doc! { "repo_id": repo_id };
let targets: Vec<compliance_core::models::DastTarget> =
match self.db.dast_targets().find(filter).await {
Ok(cursor) => cursor.try_collect().await.unwrap_or_default(),
Err(_) => return,
};
if targets.is_empty() {
tracing::info!("[{repo_id}] No DAST targets configured, skipping");
return;
}
for target in targets {
let db = self.db.clone();
let scan_run_id = scan_run_id.to_string();
tokio::spawn(async move {
let orchestrator = compliance_dast::DastOrchestrator::new(100);
match orchestrator.run_scan(&target, Vec::new()).await {
Ok((mut scan_run, findings)) => {
scan_run.sast_scan_run_id = Some(scan_run_id);
if let Err(e) = db.dast_scan_runs().insert_one(&scan_run).await {
tracing::error!("Failed to store DAST scan run: {e}");
}
for finding in &findings {
if let Err(e) = db.dast_findings().insert_one(finding).await {
tracing::error!("Failed to store DAST finding: {e}");
}
}
tracing::info!("DAST scan complete: {} findings", findings.len());
}
Err(e) => {
tracing::error!("DAST scan failed: {e}");
}
}
});
}
}
/// Build an issue tracker client from a repository's tracker configuration.
/// Returns `None` if the repo has no tracker configured.
fn build_tracker(&self, repo: &TrackedRepository) -> Option<TrackerDispatch> {
let tracker_type = repo.tracker_type.as_ref()?;
// Per-repo token takes precedence, fall back to global config
match tracker_type {
TrackerType::GitHub => {
let token = repo.tracker_token.clone().or_else(|| {
self.config.github_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
match trackers::github::GitHubTracker::new(&secret) {
Ok(t) => Some(TrackerDispatch::GitHub(t)),
Err(e) => {
tracing::warn!("Failed to build GitHub tracker: {e}");
None
}
}
}
TrackerType::GitLab => {
let base_url = self
.config
.gitlab_url
.clone()
.unwrap_or_else(|| "https://gitlab.com".to_string());
let token = repo.tracker_token.clone().or_else(|| {
self.config.gitlab_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::GitLab(
trackers::gitlab::GitLabTracker::new(base_url, secret),
))
}
TrackerType::Gitea => {
let token = repo.tracker_token.clone()?;
let base_url = extract_base_url(&repo.git_url)?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::Gitea(trackers::gitea::GiteaTracker::new(
base_url, secret,
)))
}
TrackerType::Jira => {
let base_url = self.config.jira_url.clone()?;
let email = self.config.jira_email.clone()?;
let project_key = self.config.jira_project_key.clone()?;
let token = repo.tracker_token.clone().or_else(|| {
self.config.jira_api_token.as_ref().map(|t| {
use secrecy::ExposeSecret;
t.expose_secret().to_string()
})
})?;
let secret = secrecy::SecretString::from(token);
Some(TrackerDispatch::Jira(trackers::jira::JiraTracker::new(
base_url,
email,
secret,
project_key,
)))
}
}
}
/// Create tracker issues for new findings (severity >= Medium).
/// Checks for duplicates via fingerprint search before creating.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
async fn create_tracker_issues(
&self,
repo: &TrackedRepository,
repo_id: &str,
new_findings: &[Finding],
) -> Result<(), AgentError> {
let tracker = match self.build_tracker(repo) {
Some(t) => t,
None => {
tracing::info!("[{repo_id}] No issue tracker configured, skipping");
return Ok(());
}
};
let owner = match repo.tracker_owner.as_deref() {
Some(o) => o,
None => {
tracing::warn!("[{repo_id}] tracker_owner not set, skipping issue creation");
return Ok(());
}
};
let tracker_repo_name = match repo.tracker_repo.as_deref() {
Some(r) => r,
None => {
tracing::warn!("[{repo_id}] tracker_repo not set, skipping issue creation");
return Ok(());
}
};
// Only create issues for medium+ severity findings
let actionable: Vec<&Finding> = new_findings
.iter()
.filter(|f| {
matches!(
f.severity,
Severity::Medium | Severity::High | Severity::Critical
)
})
.collect();
if actionable.is_empty() {
tracing::info!("[{repo_id}] No medium+ findings, skipping issue creation");
return Ok(());
}
tracing::info!(
"[{repo_id}] Creating issues for {} findings via {}",
actionable.len(),
tracker.name()
);
let mut created = 0u32;
for finding in actionable {
let title = format!(
"[{}] {}: {}",
finding.severity, finding.scanner, finding.title
);
// Check if an issue already exists by fingerprint first, then by title
let mut found_existing = false;
for search_term in [&finding.fingerprint, &title] {
match tracker
.find_existing_issue(owner, tracker_repo_name, search_term)
.await
{
Ok(Some(existing)) => {
tracing::debug!(
"[{repo_id}] Issue already exists for '{}': {}",
search_term,
existing.external_url
);
found_existing = true;
break;
}
Ok(None) => {}
Err(e) => {
tracing::warn!("[{repo_id}] Failed to search for existing issue: {e}");
}
}
}
if found_existing {
continue;
}
let body = format_issue_body(finding);
let labels = vec![
format!("severity:{}", finding.severity),
format!("scanner:{}", finding.scanner),
"compliance-scanner".to_string(),
];
match tracker
.create_issue(owner, tracker_repo_name, &title, &body, &labels)
.await
{
Ok(mut issue) => {
issue.finding_id = finding
.id
.as_ref()
.map(|id| id.to_hex())
.unwrap_or_default();
// Update the finding with the issue URL
if let Some(finding_id) = &finding.id {
let _ = self
.db
.findings()
.update_one(
doc! { "_id": finding_id },
doc! { "$set": { "tracker_issue_url": &issue.external_url } },
)
.await;
}
// Store the tracker issue record
if let Err(e) = self.db.tracker_issues().insert_one(&issue).await {
tracing::warn!("[{repo_id}] Failed to store tracker issue: {e}");
}
created += 1;
}
Err(e) => {
tracing::warn!(
"[{repo_id}] Failed to create issue for {}: {e}",
finding.fingerprint
);
}
}
}
tracing::info!("[{repo_id}] Created {created} tracker issues");
Ok(())
}
/// Run an incremental scan on a PR diff and post review comments.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
pub async fn run_pr_review(
&self,
repo: &TrackedRepository,
repo_id: &str,
pr_number: u64,
base_sha: &str,
head_sha: &str,
) -> Result<(), AgentError> {
let tracker = match self.build_tracker(repo) {
Some(t) => t,
None => {
tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
return Ok(());
}
};
let owner = repo.tracker_owner.as_deref().unwrap_or("");
let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
if owner.is_empty() || tracker_repo_name.is_empty() {
tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
return Ok(());
}
// Clone/fetch the repo
let creds = GitOps::make_repo_credentials(&self.config, repo);
let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
// Get diff between base and head
let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
if diff_files.is_empty() {
tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
return Ok(());
}
// Run semgrep on the full repo but we'll filter findings to changed files
let changed_paths: std::collections::HashSet<String> =
diff_files.iter().map(|f| f.path.clone()).collect();
let mut pr_findings: Vec<Finding> = Vec::new();
// SAST scan (semgrep)
match SemgrepScanner.scan(&repo_path, repo_id).await {
Ok(output) => {
for f in output.findings {
if let Some(fp) = &f.file_path {
if changed_paths.contains(fp.as_str()) {
pr_findings.push(f);
}
}
}
}
Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
}
// LLM code review on the diff
let reviewer = CodeReviewScanner::new(self.llm.clone());
let review_output = reviewer
.review_diff(&repo_path, repo_id, base_sha, head_sha)
.await;
pr_findings.extend(review_output.findings);
if pr_findings.is_empty() {
// Post a clean review
if let Err(e) = tracker
.create_pr_review(
owner,
tracker_repo_name,
pr_number,
"Compliance scan: no issues found in this PR.",
Vec::new(),
)
.await
{
tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
}
return Ok(());
}
// Build review comments from findings
let mut review_comments = Vec::new();
for finding in &pr_findings {
if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
let comment_body = format!(
"**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
finding.severity,
finding.title,
finding.description,
finding.scanner,
finding
.cwe
.as_deref()
.map(|c| format!("CWE: {c}"))
.unwrap_or_default(),
);
review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
path: path.clone(),
line,
body: comment_body,
});
}
}
let summary = format!(
"Compliance scan found **{}** issue(s) in this PR:\n\n{}",
pr_findings.len(),
pr_findings
.iter()
.map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
.collect::<Vec<_>>()
.join("\n"),
);
if let Err(e) = tracker
.create_pr_review(
owner,
tracker_repo_name,
pr_number,
&summary,
review_comments,
)
.await
{
tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
} else {
tracing::info!(
"[{repo_id}] Posted PR review on #{pr_number} with {} findings",
pr_findings.len()
);
}
Ok(())
}
async fn update_phase(&self, scan_run_id: &str, phase: &str) {
pub(super) async fn update_phase(&self, scan_run_id: &str, phase: &str) {
if let Ok(oid) = mongodb::bson::oid::ObjectId::parse_str(scan_run_id) {
let _ = self
.db
@@ -917,9 +399,9 @@ impl PipelineOrchestrator {
}
/// Extract the scheme + host from a git URL.
/// e.g. "https://gitea.example.com/owner/repo.git" "https://gitea.example.com"
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" "https://gitea.example.com"
fn extract_base_url(git_url: &str) -> Option<String> {
/// e.g. "https://gitea.example.com/owner/repo.git" -> "https://gitea.example.com"
/// e.g. "ssh://git@gitea.example.com:22/owner/repo.git" -> "https://gitea.example.com"
pub(super) fn extract_base_url(git_url: &str) -> Option<String> {
if let Some(rest) = git_url.strip_prefix("https://") {
let host = rest.split('/').next()?;
Some(format!("https://{host}"))
@@ -927,7 +409,7 @@ fn extract_base_url(git_url: &str) -> Option<String> {
let host = rest.split('/').next()?;
Some(format!("http://{host}"))
} else if let Some(rest) = git_url.strip_prefix("ssh://") {
// ssh://git@host:port/path extract host
// ssh://git@host:port/path -> extract host
let after_at = rest.find('@').map(|i| &rest[i + 1..]).unwrap_or(rest);
let host = after_at.split(&[':', '/'][..]).next()?;
Some(format!("https://{host}"))
@@ -940,48 +422,3 @@ fn extract_base_url(git_url: &str) -> Option<String> {
None
}
}
/// Format a finding into a markdown issue body for the tracker.
fn format_issue_body(finding: &Finding) -> String {
let mut body = String::new();
body.push_str(&format!("## {} Finding\n\n", finding.severity));
body.push_str(&format!("**Scanner:** {}\n", finding.scanner));
body.push_str(&format!("**Severity:** {}\n", finding.severity));
if let Some(rule) = &finding.rule_id {
body.push_str(&format!("**Rule:** {}\n", rule));
}
if let Some(cwe) = &finding.cwe {
body.push_str(&format!("**CWE:** {}\n", cwe));
}
body.push_str(&format!("\n### Description\n\n{}\n", finding.description));
if let Some(file_path) = &finding.file_path {
body.push_str(&format!("\n### Location\n\n**File:** `{}`", file_path));
if let Some(line) = finding.line_number {
body.push_str(&format!(" (line {})", line));
}
body.push('\n');
}
if let Some(snippet) = &finding.code_snippet {
body.push_str(&format!("\n### Code\n\n```\n{}\n```\n", snippet));
}
if let Some(remediation) = &finding.remediation {
body.push_str(&format!("\n### Remediation\n\n{}\n", remediation));
}
if let Some(fix) = &finding.suggested_fix {
body.push_str(&format!("\n### Suggested Fix\n\n```\n{}\n```\n", fix));
}
body.push_str(&format!(
"\n---\n*Fingerprint:* `{}`\n*Generated by compliance-scanner*",
finding.fingerprint
));
body
}

View File

@@ -256,3 +256,159 @@ fn walkdir(path: &Path) -> Result<Vec<walkdir::DirEntry>, CoreError> {
Ok(entries)
}
#[cfg(test)]
mod tests {
use super::*;
// --- compile_regex tests ---
#[test]
fn compile_regex_valid_pattern() {
let re = compile_regex(r"\bfoo\b");
assert!(re.is_match("hello foo bar"));
assert!(!re.is_match("foobar"));
}
#[test]
fn compile_regex_invalid_pattern_returns_fallback() {
// An invalid regex should return the fallback "^$" that only matches empty strings
let re = compile_regex(r"[invalid");
assert!(re.is_match(""));
assert!(!re.is_match("anything"));
}
// --- GDPR pattern tests ---
#[test]
fn gdpr_pii_logging_matches() {
let scanner = GdprPatternScanner::new();
let pattern = &scanner.patterns[0]; // gdpr-pii-logging
// Regex: (log|print|console\.|logger\.|tracing::)\s*[\.(].*\b(pii_keyword)\b
assert!(pattern.pattern.is_match("console.log(email)"));
assert!(pattern.pattern.is_match("console.log(user.ssn)"));
assert!(pattern.pattern.is_match("print(phone_number)"));
assert!(pattern.pattern.is_match("tracing::(ip_addr)"));
assert!(pattern.pattern.is_match("log.debug(credit_card)"));
}
#[test]
fn gdpr_pii_logging_no_false_positive() {
let scanner = GdprPatternScanner::new();
let pattern = &scanner.patterns[0];
// Regular logging without PII fields should not match
assert!(!pattern
.pattern
.is_match("logger.info(\"request completed\")"));
assert!(!pattern.pattern.is_match("let email = user.email;"));
}
#[test]
fn gdpr_no_consent_matches() {
let scanner = GdprPatternScanner::new();
let pattern = &scanner.patterns[1]; // gdpr-no-consent
assert!(pattern.pattern.is_match("collect personal data"));
assert!(pattern.pattern.is_match("store user_data in db"));
assert!(pattern.pattern.is_match("save pii to disk"));
}
#[test]
fn gdpr_user_model_matches() {
let scanner = GdprPatternScanner::new();
let pattern = &scanner.patterns[2]; // gdpr-no-delete-endpoint
assert!(pattern.pattern.is_match("struct User {"));
assert!(pattern.pattern.is_match("class User(Model):"));
}
#[test]
fn gdpr_hardcoded_retention_matches() {
let scanner = GdprPatternScanner::new();
let pattern = &scanner.patterns[3]; // gdpr-hardcoded-retention
assert!(pattern.pattern.is_match("retention = 30"));
assert!(pattern.pattern.is_match("ttl: 3600"));
assert!(pattern.pattern.is_match("expire = 86400"));
}
// --- OAuth pattern tests ---
#[test]
fn oauth_implicit_grant_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[0]; // oauth-implicit-grant
assert!(pattern.pattern.is_match("response_type = \"token\""));
assert!(pattern.pattern.is_match("grant_type: implicit"));
assert!(pattern.pattern.is_match("response_type='token'"));
}
#[test]
fn oauth_implicit_grant_no_false_positive() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[0];
assert!(!pattern.pattern.is_match("response_type = \"code\""));
assert!(!pattern.pattern.is_match("grant_type: authorization_code"));
}
#[test]
fn oauth_authorization_code_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[1]; // oauth-missing-pkce
assert!(pattern.pattern.is_match("uses authorization_code flow"));
assert!(pattern.pattern.is_match("authorization code grant"));
}
#[test]
fn oauth_token_localstorage_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
assert!(pattern
.pattern
.is_match("localStorage.setItem('access_token', tok)"));
assert!(pattern
.pattern
.is_match("localStorage.getItem(\"refresh_token\")"));
}
#[test]
fn oauth_token_localstorage_no_false_positive() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2];
assert!(!pattern
.pattern
.is_match("localStorage.setItem('theme', 'dark')"));
assert!(!pattern
.pattern
.is_match("sessionStorage.setItem('token', t)"));
}
#[test]
fn oauth_token_url_matches() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[3]; // oauth-token-url
assert!(pattern.pattern.is_match("access_token = build_url(query)"));
assert!(pattern.pattern.is_match("bearer = url.param"));
}
// --- Pattern rule file extension filtering ---
#[test]
fn gdpr_patterns_cover_common_languages() {
let scanner = GdprPatternScanner::new();
for pattern in &scanner.patterns {
assert!(
pattern.file_extensions.contains(&"rs".to_string()),
"Pattern {} should cover .rs files",
pattern.id
);
}
}
#[test]
fn oauth_localstorage_only_js_ts() {
let scanner = OAuthPatternScanner::new();
let pattern = &scanner.patterns[2]; // oauth-token-localstorage
assert!(pattern.file_extensions.contains(&"js".to_string()));
assert!(pattern.file_extensions.contains(&"ts".to_string()));
assert!(!pattern.file_extensions.contains(&"rs".to_string()));
assert!(!pattern.file_extensions.contains(&"py".to_string()));
}
}

View File

@@ -0,0 +1,146 @@
use compliance_core::models::*;
use super::orchestrator::PipelineOrchestrator;
use crate::error::AgentError;
use crate::pipeline::code_review::CodeReviewScanner;
use crate::pipeline::git::GitOps;
use crate::pipeline::semgrep::SemgrepScanner;
use compliance_core::traits::Scanner;
impl PipelineOrchestrator {
/// Run an incremental scan on a PR diff and post review comments.
#[tracing::instrument(skip_all, fields(repo_id = %repo_id, pr_number))]
pub async fn run_pr_review(
&self,
repo: &TrackedRepository,
repo_id: &str,
pr_number: u64,
base_sha: &str,
head_sha: &str,
) -> Result<(), AgentError> {
let tracker = match self.build_tracker(repo) {
Some(t) => t,
None => {
tracing::warn!("[{repo_id}] No tracker configured, cannot post PR review");
return Ok(());
}
};
let owner = repo.tracker_owner.as_deref().unwrap_or("");
let tracker_repo_name = repo.tracker_repo.as_deref().unwrap_or("");
if owner.is_empty() || tracker_repo_name.is_empty() {
tracing::warn!("[{repo_id}] tracker_owner or tracker_repo not set");
return Ok(());
}
// Clone/fetch the repo
let creds = GitOps::make_repo_credentials(&self.config, repo);
let git_ops = GitOps::new(&self.config.git_clone_base_path, creds);
let repo_path = git_ops.clone_or_fetch(&repo.git_url, &repo.name)?;
// Get diff between base and head
let diff_files = GitOps::get_diff_content(&repo_path, base_sha, head_sha)?;
if diff_files.is_empty() {
tracing::info!("[{repo_id}] PR #{pr_number}: no diff files, skipping review");
return Ok(());
}
// Run semgrep on the full repo but we'll filter findings to changed files
let changed_paths: std::collections::HashSet<String> =
diff_files.iter().map(|f| f.path.clone()).collect();
let mut pr_findings: Vec<Finding> = Vec::new();
// SAST scan (semgrep)
match SemgrepScanner.scan(&repo_path, repo_id).await {
Ok(output) => {
for f in output.findings {
if let Some(fp) = &f.file_path {
if changed_paths.contains(fp.as_str()) {
pr_findings.push(f);
}
}
}
}
Err(e) => tracing::warn!("[{repo_id}] PR semgrep failed: {e}"),
}
// LLM code review on the diff
let reviewer = CodeReviewScanner::new(self.llm.clone());
let review_output = reviewer
.review_diff(&repo_path, repo_id, base_sha, head_sha)
.await;
pr_findings.extend(review_output.findings);
if pr_findings.is_empty() {
// Post a clean review
if let Err(e) = tracker
.create_pr_review(
owner,
tracker_repo_name,
pr_number,
"Compliance scan: no issues found in this PR.",
Vec::new(),
)
.await
{
tracing::warn!("[{repo_id}] Failed to post clean PR review: {e}");
}
return Ok(());
}
// Build review comments from findings
let mut review_comments = Vec::new();
for finding in &pr_findings {
if let (Some(path), Some(line)) = (&finding.file_path, finding.line_number) {
let comment_body = format!(
"**[{}] {}**\n\n{}\n\n*Scanner: {} | {}*",
finding.severity,
finding.title,
finding.description,
finding.scanner,
finding
.cwe
.as_deref()
.map(|c| format!("CWE: {c}"))
.unwrap_or_default(),
);
review_comments.push(compliance_core::traits::issue_tracker::ReviewComment {
path: path.clone(),
line,
body: comment_body,
});
}
}
let summary = format!(
"Compliance scan found **{}** issue(s) in this PR:\n\n{}",
pr_findings.len(),
pr_findings
.iter()
.map(|f| format!("- **[{}]** {}: {}", f.severity, f.scanner, f.title))
.collect::<Vec<_>>()
.join("\n"),
);
if let Err(e) = tracker
.create_pr_review(
owner,
tracker_repo_name,
pr_number,
&summary,
review_comments,
)
.await
{
tracing::warn!("[{repo_id}] Failed to post PR review: {e}");
} else {
tracing::info!(
"[{repo_id}] Posted PR review on #{pr_number} with {} findings",
pr_findings.len()
);
}
Ok(())
}
}

View File

@@ -0,0 +1,72 @@
use std::path::Path;
use compliance_core::CoreError;
pub(super) struct AuditVuln {
pub package: String,
pub id: String,
pub url: String,
}
#[tracing::instrument(skip_all)]
pub(super) async fn run_cargo_audit(
repo_path: &Path,
_repo_id: &str,
) -> Result<Vec<AuditVuln>, CoreError> {
let cargo_lock = repo_path.join("Cargo.lock");
if !cargo_lock.exists() {
return Ok(Vec::new());
}
let output = tokio::process::Command::new("cargo")
.args(["audit", "--json"])
.current_dir(repo_path)
.env("RUSTC_WRAPPER", "")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "cargo-audit".to_string(),
source: Box::new(e),
})?;
let result: CargoAuditOutput =
serde_json::from_slice(&output.stdout).unwrap_or_else(|_| CargoAuditOutput {
vulnerabilities: CargoAuditVulns { list: Vec::new() },
});
let vulns = result
.vulnerabilities
.list
.into_iter()
.map(|v| AuditVuln {
package: v.advisory.package,
id: v.advisory.id,
url: v.advisory.url,
})
.collect();
Ok(vulns)
}
// Cargo audit types
#[derive(serde::Deserialize)]
struct CargoAuditOutput {
vulnerabilities: CargoAuditVulns,
}
#[derive(serde::Deserialize)]
struct CargoAuditVulns {
list: Vec<CargoAuditEntry>,
}
#[derive(serde::Deserialize)]
struct CargoAuditEntry {
advisory: CargoAuditAdvisory,
}
#[derive(serde::Deserialize)]
struct CargoAuditAdvisory {
id: String,
package: String,
url: String,
}

View File

@@ -1,3 +1,6 @@
mod cargo_audit;
mod syft;
use std::path::Path;
use compliance_core::models::{SbomEntry, ScanType, VulnRef};
@@ -23,7 +26,7 @@ impl Scanner for SbomScanner {
generate_lockfiles(repo_path).await;
// Run syft for SBOM generation
match run_syft(repo_path, repo_id).await {
match syft::run_syft(repo_path, repo_id).await {
Ok(syft_entries) => entries.extend(syft_entries),
Err(e) => tracing::warn!("syft failed: {e}"),
}
@@ -32,7 +35,7 @@ impl Scanner for SbomScanner {
enrich_cargo_licenses(repo_path, &mut entries).await;
// Run cargo-audit for Rust-specific vulns
match run_cargo_audit(repo_path, repo_id).await {
match cargo_audit::run_cargo_audit(repo_path, repo_id).await {
Ok(vulns) => merge_audit_vulns(&mut entries, vulns),
Err(e) => tracing::warn!("cargo-audit skipped: {e}"),
}
@@ -186,95 +189,7 @@ async fn enrich_cargo_licenses(repo_path: &Path, entries: &mut [SbomEntry]) {
}
}
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
let output = tokio::process::Command::new("syft")
.arg(repo_path)
.args(["-o", "cyclonedx-json"])
// Enable remote license lookups for all ecosystems
.env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVA_USE_NETWORK", "true")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "syft".to_string(),
source: Box::new(e),
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(CoreError::Scanner {
scanner: "syft".to_string(),
source: format!("syft exited with {}: {stderr}", output.status).into(),
});
}
let cdx: CycloneDxBom = serde_json::from_slice(&output.stdout)?;
let entries = cdx
.components
.unwrap_or_default()
.into_iter()
.map(|c| {
let package_manager = c
.purl
.as_deref()
.and_then(extract_ecosystem_from_purl)
.unwrap_or_else(|| "unknown".to_string());
let mut entry = SbomEntry::new(
repo_id.to_string(),
c.name,
c.version.unwrap_or_else(|| "unknown".to_string()),
package_manager,
);
entry.purl = c.purl;
entry.license = c.licenses.and_then(|ls| extract_license(&ls));
entry
})
.collect();
Ok(entries)
}
#[tracing::instrument(skip_all)]
async fn run_cargo_audit(repo_path: &Path, _repo_id: &str) -> Result<Vec<AuditVuln>, CoreError> {
let cargo_lock = repo_path.join("Cargo.lock");
if !cargo_lock.exists() {
return Ok(Vec::new());
}
let output = tokio::process::Command::new("cargo")
.args(["audit", "--json"])
.current_dir(repo_path)
.env("RUSTC_WRAPPER", "")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "cargo-audit".to_string(),
source: Box::new(e),
})?;
let result: CargoAuditOutput =
serde_json::from_slice(&output.stdout).unwrap_or_else(|_| CargoAuditOutput {
vulnerabilities: CargoAuditVulns { list: Vec::new() },
});
let vulns = result
.vulnerabilities
.list
.into_iter()
.map(|v| AuditVuln {
package: v.advisory.package,
id: v.advisory.id,
url: v.advisory.url,
})
.collect();
Ok(vulns)
}
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<cargo_audit::AuditVuln>) {
for vuln in vulns {
if let Some(entry) = entries.iter_mut().find(|e| e.name == vuln.package) {
entry.known_vulnerabilities.push(VulnRef {
@@ -287,65 +202,6 @@ fn merge_audit_vulns(entries: &mut [SbomEntry], vulns: Vec<AuditVuln>) {
}
}
// CycloneDX JSON types
#[derive(serde::Deserialize)]
struct CycloneDxBom {
components: Option<Vec<CdxComponent>>,
}
#[derive(serde::Deserialize)]
struct CdxComponent {
name: String,
version: Option<String>,
#[serde(rename = "type")]
#[allow(dead_code)]
component_type: Option<String>,
purl: Option<String>,
licenses: Option<Vec<CdxLicenseWrapper>>,
}
#[derive(serde::Deserialize)]
struct CdxLicenseWrapper {
license: Option<CdxLicense>,
/// SPDX license expression (e.g. "MIT OR Apache-2.0")
expression: Option<String>,
}
#[derive(serde::Deserialize)]
struct CdxLicense {
id: Option<String>,
name: Option<String>,
}
// Cargo audit types
#[derive(serde::Deserialize)]
struct CargoAuditOutput {
vulnerabilities: CargoAuditVulns,
}
#[derive(serde::Deserialize)]
struct CargoAuditVulns {
list: Vec<CargoAuditEntry>,
}
#[derive(serde::Deserialize)]
struct CargoAuditEntry {
advisory: CargoAuditAdvisory,
}
#[derive(serde::Deserialize)]
struct CargoAuditAdvisory {
id: String,
package: String,
url: String,
}
struct AuditVuln {
package: String,
id: String,
url: String,
}
// Cargo metadata types
#[derive(serde::Deserialize)]
struct CargoMetadata {
@@ -358,49 +214,3 @@ struct CargoPackage {
version: String,
license: Option<String>,
}
/// Extract the best license string from CycloneDX license entries.
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
for entry in entries {
if let Some(ref expr) = entry.expression {
if !expr.is_empty() {
return Some(expr.clone());
}
}
}
// Second pass: collect license.id or license.name from all entries
let parts: Vec<String> = entries
.iter()
.filter_map(|e| {
e.license.as_ref().and_then(|lic| {
lic.id
.clone()
.or_else(|| lic.name.clone())
.filter(|s| !s.is_empty())
})
})
.collect();
if parts.is_empty() {
return None;
}
Some(parts.join(" OR "))
}
/// Extract the ecosystem/package-manager from a PURL string.
/// e.g. "pkg:npm/lodash@4.17.21" → "npm", "pkg:cargo/serde@1.0" → "cargo"
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
let rest = purl.strip_prefix("pkg:")?;
let ecosystem = rest.split('/').next()?;
if ecosystem.is_empty() {
return None;
}
// Normalise common PURL types to user-friendly names
let normalised = match ecosystem {
"golang" => "go",
"pypi" => "pip",
_ => ecosystem,
};
Some(normalised.to_string())
}

View File

@@ -0,0 +1,355 @@
use std::path::Path;
use compliance_core::models::SbomEntry;
use compliance_core::CoreError;
#[tracing::instrument(skip_all, fields(repo_id = %repo_id))]
pub(super) async fn run_syft(repo_path: &Path, repo_id: &str) -> Result<Vec<SbomEntry>, CoreError> {
let output = tokio::process::Command::new("syft")
.arg(repo_path)
.args(["-o", "cyclonedx-json"])
// Enable remote license lookups for all ecosystems
.env("SYFT_GOLANG_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVASCRIPT_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_PYTHON_SEARCH_REMOTE_LICENSES", "true")
.env("SYFT_JAVA_USE_NETWORK", "true")
.output()
.await
.map_err(|e| CoreError::Scanner {
scanner: "syft".to_string(),
source: Box::new(e),
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(CoreError::Scanner {
scanner: "syft".to_string(),
source: format!("syft exited with {}: {stderr}", output.status).into(),
});
}
let cdx: CycloneDxBom = serde_json::from_slice(&output.stdout)?;
let entries = cdx
.components
.unwrap_or_default()
.into_iter()
.map(|c| {
let package_manager = c
.purl
.as_deref()
.and_then(extract_ecosystem_from_purl)
.unwrap_or_else(|| "unknown".to_string());
let mut entry = SbomEntry::new(
repo_id.to_string(),
c.name,
c.version.unwrap_or_else(|| "unknown".to_string()),
package_manager,
);
entry.purl = c.purl;
entry.license = c.licenses.and_then(|ls| extract_license(&ls));
entry
})
.collect();
Ok(entries)
}
// CycloneDX JSON types
#[derive(serde::Deserialize)]
struct CycloneDxBom {
components: Option<Vec<CdxComponent>>,
}
#[derive(serde::Deserialize)]
struct CdxComponent {
name: String,
version: Option<String>,
#[serde(rename = "type")]
#[allow(dead_code)]
component_type: Option<String>,
purl: Option<String>,
licenses: Option<Vec<CdxLicenseWrapper>>,
}
#[derive(serde::Deserialize)]
struct CdxLicenseWrapper {
license: Option<CdxLicense>,
/// SPDX license expression (e.g. "MIT OR Apache-2.0")
expression: Option<String>,
}
#[derive(serde::Deserialize)]
struct CdxLicense {
id: Option<String>,
name: Option<String>,
}
/// Extract the best license string from CycloneDX license entries.
/// Handles three formats: expression ("MIT OR Apache-2.0"), license.id ("MIT"), license.name ("MIT License").
fn extract_license(entries: &[CdxLicenseWrapper]) -> Option<String> {
// First pass: look for SPDX expressions (most precise for dual-licensed packages)
for entry in entries {
if let Some(ref expr) = entry.expression {
if !expr.is_empty() {
return Some(expr.clone());
}
}
}
// Second pass: collect license.id or license.name from all entries
let parts: Vec<String> = entries
.iter()
.filter_map(|e| {
e.license.as_ref().and_then(|lic| {
lic.id
.clone()
.or_else(|| lic.name.clone())
.filter(|s| !s.is_empty())
})
})
.collect();
if parts.is_empty() {
return None;
}
Some(parts.join(" OR "))
}
/// Extract the ecosystem/package-manager from a PURL string.
/// e.g. "pkg:npm/lodash@4.17.21" -> "npm", "pkg:cargo/serde@1.0" -> "cargo"
fn extract_ecosystem_from_purl(purl: &str) -> Option<String> {
let rest = purl.strip_prefix("pkg:")?;
let ecosystem = rest.split('/').next()?;
if ecosystem.is_empty() {
return None;
}
// Normalise common PURL types to user-friendly names
let normalised = match ecosystem {
"golang" => "go",
"pypi" => "pip",
_ => ecosystem,
};
Some(normalised.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
// --- extract_ecosystem_from_purl tests ---
#[test]
fn purl_npm() {
assert_eq!(
extract_ecosystem_from_purl("pkg:npm/lodash@4.17.21"),
Some("npm".to_string())
);
}
#[test]
fn purl_cargo() {
assert_eq!(
extract_ecosystem_from_purl("pkg:cargo/serde@1.0.197"),
Some("cargo".to_string())
);
}
#[test]
fn purl_golang_normalised() {
assert_eq!(
extract_ecosystem_from_purl("pkg:golang/github.com/gin-gonic/gin@1.9.1"),
Some("go".to_string())
);
}
#[test]
fn purl_pypi_normalised() {
assert_eq!(
extract_ecosystem_from_purl("pkg:pypi/requests@2.31.0"),
Some("pip".to_string())
);
}
#[test]
fn purl_maven() {
assert_eq!(
extract_ecosystem_from_purl("pkg:maven/org.apache.commons/commons-lang3@3.14.0"),
Some("maven".to_string())
);
}
#[test]
fn purl_missing_prefix() {
assert_eq!(extract_ecosystem_from_purl("npm/lodash@4.17.21"), None);
}
#[test]
fn purl_empty_ecosystem() {
assert_eq!(extract_ecosystem_from_purl("pkg:/lodash@4.17.21"), None);
}
#[test]
fn purl_empty_string() {
assert_eq!(extract_ecosystem_from_purl(""), None);
}
#[test]
fn purl_just_prefix() {
assert_eq!(extract_ecosystem_from_purl("pkg:"), None);
}
// --- extract_license tests ---
#[test]
fn license_from_expression() {
let entries = vec![CdxLicenseWrapper {
license: None,
expression: Some("MIT OR Apache-2.0".to_string()),
}];
assert_eq!(
extract_license(&entries),
Some("MIT OR Apache-2.0".to_string())
);
}
#[test]
fn license_from_id() {
let entries = vec![CdxLicenseWrapper {
license: Some(CdxLicense {
id: Some("MIT".to_string()),
name: None,
}),
expression: None,
}];
assert_eq!(extract_license(&entries), Some("MIT".to_string()));
}
#[test]
fn license_from_name_fallback() {
let entries = vec![CdxLicenseWrapper {
license: Some(CdxLicense {
id: None,
name: Some("MIT License".to_string()),
}),
expression: None,
}];
assert_eq!(extract_license(&entries), Some("MIT License".to_string()));
}
#[test]
fn license_expression_preferred_over_id() {
let entries = vec![
CdxLicenseWrapper {
license: Some(CdxLicense {
id: Some("MIT".to_string()),
name: None,
}),
expression: None,
},
CdxLicenseWrapper {
license: None,
expression: Some("MIT AND Apache-2.0".to_string()),
},
];
// Expression should be preferred (first pass finds it)
assert_eq!(
extract_license(&entries),
Some("MIT AND Apache-2.0".to_string())
);
}
#[test]
fn license_multiple_ids_joined() {
let entries = vec![
CdxLicenseWrapper {
license: Some(CdxLicense {
id: Some("MIT".to_string()),
name: None,
}),
expression: None,
},
CdxLicenseWrapper {
license: Some(CdxLicense {
id: Some("Apache-2.0".to_string()),
name: None,
}),
expression: None,
},
];
assert_eq!(
extract_license(&entries),
Some("MIT OR Apache-2.0".to_string())
);
}
#[test]
fn license_empty_entries() {
let entries: Vec<CdxLicenseWrapper> = vec![];
assert_eq!(extract_license(&entries), None);
}
#[test]
fn license_all_empty_strings() {
let entries = vec![CdxLicenseWrapper {
license: Some(CdxLicense {
id: Some(String::new()),
name: Some(String::new()),
}),
expression: Some(String::new()),
}];
assert_eq!(extract_license(&entries), None);
}
#[test]
fn license_none_fields() {
let entries = vec![CdxLicenseWrapper {
license: None,
expression: None,
}];
assert_eq!(extract_license(&entries), None);
}
// --- CycloneDX deserialization tests ---
#[test]
fn deserialize_cyclonedx_bom() {
let json = r#"{
"components": [
{
"name": "serde",
"version": "1.0.197",
"type": "library",
"purl": "pkg:cargo/serde@1.0.197",
"licenses": [
{"expression": "MIT OR Apache-2.0"}
]
}
]
}"#;
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
let components = bom.components.unwrap();
assert_eq!(components.len(), 1);
assert_eq!(components[0].name, "serde");
assert_eq!(components[0].version, Some("1.0.197".to_string()));
assert_eq!(
components[0].purl,
Some("pkg:cargo/serde@1.0.197".to_string())
);
}
#[test]
fn deserialize_cyclonedx_no_components() {
let json = r#"{}"#;
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
assert!(bom.components.is_none());
}
#[test]
fn deserialize_cyclonedx_minimal_component() {
let json = r#"{"components": [{"name": "foo"}]}"#;
let bom: CycloneDxBom = serde_json::from_str(json).unwrap();
let c = &bom.components.unwrap()[0];
assert_eq!(c.name, "foo");
assert!(c.version.is_none());
assert!(c.purl.is_none());
assert!(c.licenses.is_none());
}
}

View File

@@ -108,3 +108,124 @@ struct SemgrepExtra {
#[serde(default)]
metadata: Option<serde_json::Value>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_semgrep_output() {
let json = r#"{
"results": [
{
"check_id": "python.lang.security.audit.exec-detected",
"path": "src/main.py",
"start": {"line": 15},
"extra": {
"message": "Detected use of exec()",
"severity": "ERROR",
"lines": "exec(user_input)",
"metadata": {"cwe": "CWE-78"}
}
}
]
}"#;
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
assert_eq!(output.results.len(), 1);
let r = &output.results[0];
assert_eq!(r.check_id, "python.lang.security.audit.exec-detected");
assert_eq!(r.path, "src/main.py");
assert_eq!(r.start.line, 15);
assert_eq!(r.extra.message, "Detected use of exec()");
assert_eq!(r.extra.severity, "ERROR");
assert_eq!(r.extra.lines, "exec(user_input)");
assert_eq!(
r.extra
.metadata
.as_ref()
.unwrap()
.get("cwe")
.unwrap()
.as_str(),
Some("CWE-78")
);
}
#[test]
fn deserialize_semgrep_empty_results() {
let json = r#"{"results": []}"#;
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
assert!(output.results.is_empty());
}
#[test]
fn deserialize_semgrep_no_metadata() {
let json = r#"{
"results": [
{
"check_id": "rule-1",
"path": "app.py",
"start": {"line": 1},
"extra": {
"message": "found something",
"severity": "WARNING",
"lines": "import os"
}
}
]
}"#;
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
assert!(output.results[0].extra.metadata.is_none());
}
#[test]
fn semgrep_severity_mapping() {
let cases = vec![
("ERROR", "High"),
("WARNING", "Medium"),
("INFO", "Low"),
("UNKNOWN", "Info"),
];
for (input, expected) in cases {
let result = match input {
"ERROR" => "High",
"WARNING" => "Medium",
"INFO" => "Low",
_ => "Info",
};
assert_eq!(result, expected, "Severity for '{input}'");
}
}
#[test]
fn deserialize_semgrep_multiple_results() {
let json = r#"{
"results": [
{
"check_id": "rule-a",
"path": "a.py",
"start": {"line": 1},
"extra": {
"message": "msg a",
"severity": "ERROR",
"lines": "line a"
}
},
{
"check_id": "rule-b",
"path": "b.py",
"start": {"line": 99},
"extra": {
"message": "msg b",
"severity": "INFO",
"lines": "line b"
}
}
]
}"#;
let output: SemgrepOutput = serde_json::from_str(json).unwrap();
assert_eq!(output.results.len(), 2);
assert_eq!(output.results[1].start.line, 99);
}
}

View File

@@ -0,0 +1,81 @@
use compliance_core::models::TrackerIssue;
use compliance_core::traits::issue_tracker::IssueTracker;
use crate::trackers;
/// Enum dispatch for issue trackers (async traits aren't dyn-compatible).
pub(crate) enum TrackerDispatch {
GitHub(trackers::github::GitHubTracker),
GitLab(trackers::gitlab::GitLabTracker),
Gitea(trackers::gitea::GiteaTracker),
Jira(trackers::jira::JiraTracker),
}
impl TrackerDispatch {
pub(crate) fn name(&self) -> &str {
match self {
Self::GitHub(t) => t.name(),
Self::GitLab(t) => t.name(),
Self::Gitea(t) => t.name(),
Self::Jira(t) => t.name(),
}
}
pub(crate) async fn create_issue(
&self,
owner: &str,
repo: &str,
title: &str,
body: &str,
labels: &[String],
) -> Result<TrackerIssue, compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::GitLab(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::Gitea(t) => t.create_issue(owner, repo, title, body, labels).await,
Self::Jira(t) => t.create_issue(owner, repo, title, body, labels).await,
}
}
pub(crate) async fn find_existing_issue(
&self,
owner: &str,
repo: &str,
fingerprint: &str,
) -> Result<Option<TrackerIssue>, compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::GitLab(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::Gitea(t) => t.find_existing_issue(owner, repo, fingerprint).await,
Self::Jira(t) => t.find_existing_issue(owner, repo, fingerprint).await,
}
}
pub(crate) async fn create_pr_review(
&self,
owner: &str,
repo: &str,
pr_number: u64,
body: &str,
comments: Vec<compliance_core::traits::issue_tracker::ReviewComment>,
) -> Result<(), compliance_core::error::CoreError> {
match self {
Self::GitHub(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::GitLab(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::Gitea(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
Self::Jira(t) => {
t.create_pr_review(owner, repo, pr_number, body, comments)
.await
}
}
}
}

View File

@@ -0,0 +1,3 @@
// Shared test helpers for compliance-agent integration tests.
//
// Add database mocks, fixtures, and test utilities here.

View File

@@ -0,0 +1,4 @@
// Integration tests for the compliance-agent crate.
//
// Add tests that exercise the full pipeline, API handlers,
// and cross-module interactions here.