use std::sync::Arc; use compliance_core::models::{Finding, FindingStatus}; use crate::llm::LlmClient; use crate::pipeline::orchestrator::GraphContext; /// Maximum number of findings to include in a single LLM triage call. const TRIAGE_CHUNK_SIZE: usize = 30; const TRIAGE_SYSTEM_PROMPT: &str = r#"You are a security finding triage expert. Analyze each of the following security findings with its code context and determine the appropriate action. Actions: - "confirm": The finding is a true positive at the reported severity. Keep as-is. - "downgrade": The finding is real but over-reported. Lower severity recommended. - "upgrade": The finding is under-reported. Higher severity recommended. - "dismiss": The finding is a false positive. Should be removed. Consider: - Is the code in a test, example, or generated file? (lower confidence for test code) - Does the surrounding code context confirm or refute the finding? - Is the finding actionable by a developer? - Would a real attacker be able to exploit this? Respond with a JSON array, one entry per finding in the same order they were presented: [{"id": "", "action": "confirm|downgrade|upgrade|dismiss", "confidence": 0-10, "rationale": "brief explanation", "remediation": "optional fix suggestion"}, ...]"#; pub async fn triage_findings( llm: &Arc, findings: &mut Vec, graph_context: Option<&GraphContext>, ) -> usize { let mut passed = 0; // Process findings in chunks to avoid overflowing the LLM context window. for chunk_start in (0..findings.len()).step_by(TRIAGE_CHUNK_SIZE) { let chunk_end = (chunk_start + TRIAGE_CHUNK_SIZE).min(findings.len()); let chunk = &mut findings[chunk_start..chunk_end]; // Build a combined prompt for the entire chunk. let mut user_prompt = String::new(); let mut file_classifications: Vec = Vec::new(); for (i, finding) in chunk.iter().enumerate() { let file_classification = classify_file_path(finding.file_path.as_deref()); user_prompt.push_str(&format!( "\n--- Finding {} (id: {}) ---\nScanner: {}\nRule: {}\nSeverity: {}\nTitle: {}\nDescription: {}\nFile: {}\nLine: {}\nCode: {}\nFile classification: {}", i + 1, finding.fingerprint, finding.scanner, finding.rule_id.as_deref().unwrap_or("N/A"), finding.severity, finding.title, finding.description, finding.file_path.as_deref().unwrap_or("N/A"), finding.line_number.map(|n| n.to_string()).unwrap_or_else(|| "N/A".to_string()), finding.code_snippet.as_deref().unwrap_or("N/A"), file_classification, )); // Enrich with surrounding code context if possible if let Some(context) = read_surrounding_context(finding) { user_prompt.push_str(&format!( "\n\n--- Surrounding Code (50 lines) ---\n{context}" )); } // Enrich with graph context if available if let Some(ctx) = graph_context { if let Some(impact) = ctx .impacts .iter() .find(|im| im.finding_id == finding.fingerprint) { user_prompt.push_str(&format!( "\n\n--- Code Graph Context ---\n\ Blast radius: {} nodes affected\n\ Entry points affected: {}\n\ Direct callers: {}\n\ Communities affected: {}\n\ Call chains: {}", impact.blast_radius, if impact.affected_entry_points.is_empty() { "none".to_string() } else { impact.affected_entry_points.join(", ") }, if impact.direct_callers.is_empty() { "none".to_string() } else { impact.direct_callers.join(", ") }, impact.affected_communities.len(), impact.call_chains.len(), )); } } user_prompt.push('\n'); file_classifications.push(file_classification); } // Send the batch to the LLM. match llm .chat(TRIAGE_SYSTEM_PROMPT, &user_prompt, Some(0.1)) .await { Ok(response) => { let cleaned = response.trim(); let cleaned = if cleaned.starts_with("```") { cleaned .trim_start_matches("```json") .trim_start_matches("```") .trim_end_matches("```") .trim() } else { cleaned }; match serde_json::from_str::>(cleaned) { Ok(results) => { for (idx, finding) in chunk.iter_mut().enumerate() { // Match result by position; fall back to keeping the finding. let Some(result) = results.get(idx) else { finding.status = FindingStatus::Triaged; passed += 1; continue; }; let file_classification = file_classifications .get(idx) .map(|s| s.as_str()) .unwrap_or("unknown"); let adjusted_confidence = adjust_confidence(result.confidence, file_classification); finding.confidence = Some(adjusted_confidence); finding.triage_action = Some(result.action.clone()); finding.triage_rationale = Some(result.rationale.clone()); if let Some(ref remediation) = result.remediation { finding.remediation = Some(remediation.clone()); } match result.action.as_str() { "dismiss" => { finding.status = FindingStatus::FalsePositive; } "downgrade" => { finding.severity = downgrade_severity(&finding.severity); finding.status = FindingStatus::Triaged; passed += 1; } "upgrade" => { finding.severity = upgrade_severity(&finding.severity); finding.status = FindingStatus::Triaged; passed += 1; } _ => { // "confirm" or unknown — keep as-is if adjusted_confidence >= 3.0 { finding.status = FindingStatus::Triaged; passed += 1; } else { finding.status = FindingStatus::FalsePositive; } } } } } Err(_) => { // Batch parse failure — keep all findings in the chunk. tracing::warn!( "Failed to parse batch triage response for chunk starting at {chunk_start}: {cleaned}" ); for finding in chunk.iter_mut() { finding.status = FindingStatus::Triaged; passed += 1; } } } } Err(e) => { // On LLM error, keep all findings in the chunk. tracing::warn!("LLM batch triage failed for chunk starting at {chunk_start}: {e}"); for finding in chunk.iter_mut() { finding.status = FindingStatus::Triaged; passed += 1; } } } } // Remove false positives findings.retain(|f| f.status != FindingStatus::FalsePositive); passed } /// Read ~50 lines of surrounding code from the file at the finding's location fn read_surrounding_context(finding: &Finding) -> Option { let file_path = finding.file_path.as_deref()?; let line = finding.line_number? as usize; // Try to read the file — this works because the repo is cloned locally let content = std::fs::read_to_string(file_path).ok()?; let lines: Vec<&str> = content.lines().collect(); let start = line.saturating_sub(25); let end = (line + 25).min(lines.len()); Some( lines[start..end] .iter() .enumerate() .map(|(i, l)| format!("{:>4} | {}", start + i + 1, l)) .collect::>() .join("\n"), ) } /// Classify a file path to inform triage confidence adjustment fn classify_file_path(path: Option<&str>) -> String { let path = match path { Some(p) => p.to_lowercase(), None => return "unknown".to_string(), }; if path.contains("/test/") || path.contains("/tests/") || path.contains("_test.") || path.contains(".test.") || path.contains(".spec.") || path.contains("/fixtures/") || path.contains("/testdata/") { return "test".to_string(); } if path.contains("/example") || path.contains("/examples/") || path.contains("/demo/") || path.contains("/sample") { return "example".to_string(); } if path.contains("/generated/") || path.contains("/gen/") || path.contains(".generated.") || path.contains(".pb.go") || path.contains("_generated.rs") { return "generated".to_string(); } if path.contains("/vendor/") || path.contains("/node_modules/") || path.contains("/third_party/") { return "vendored".to_string(); } "production".to_string() } /// Adjust confidence based on file classification fn adjust_confidence(raw_confidence: f64, classification: &str) -> f64 { let multiplier = match classification { "test" => 0.5, "example" => 0.6, "generated" => 0.3, "vendored" => 0.4, _ => 1.0, }; raw_confidence * multiplier } fn downgrade_severity( severity: &compliance_core::models::Severity, ) -> compliance_core::models::Severity { use compliance_core::models::Severity; match severity { Severity::Critical => Severity::High, Severity::High => Severity::Medium, Severity::Medium => Severity::Low, Severity::Low => Severity::Info, Severity::Info => Severity::Info, } } fn upgrade_severity( severity: &compliance_core::models::Severity, ) -> compliance_core::models::Severity { use compliance_core::models::Severity; match severity { Severity::Info => Severity::Low, Severity::Low => Severity::Medium, Severity::Medium => Severity::High, Severity::High => Severity::Critical, Severity::Critical => Severity::Critical, } } #[derive(serde::Deserialize)] struct TriageResult { /// Finding fingerprint echoed back by the LLM (optional). #[serde(default)] #[allow(dead_code)] id: String, #[serde(default = "default_action")] action: String, #[serde(default)] confidence: f64, #[serde(default)] rationale: String, remediation: Option, } fn default_action() -> String { "confirm".to_string() } #[cfg(test)] mod tests { use super::*; use compliance_core::models::Severity; // ── classify_file_path ─────────────────────────────────────── #[test] fn classify_none_path() { assert_eq!(classify_file_path(None), "unknown"); } #[test] fn classify_production_path() { assert_eq!(classify_file_path(Some("src/main.rs")), "production"); assert_eq!(classify_file_path(Some("lib/core/engine.py")), "production"); } #[test] fn classify_test_paths() { assert_eq!(classify_file_path(Some("src/test/helper.rs")), "test"); assert_eq!(classify_file_path(Some("src/tests/unit.rs")), "test"); assert_eq!(classify_file_path(Some("foo_test.go")), "test"); assert_eq!(classify_file_path(Some("bar.test.js")), "test"); assert_eq!(classify_file_path(Some("baz.spec.ts")), "test"); assert_eq!( classify_file_path(Some("data/fixtures/sample.json")), "test" ); assert_eq!(classify_file_path(Some("src/testdata/input.txt")), "test"); } #[test] fn classify_example_paths() { assert_eq!( classify_file_path(Some("docs/examples/basic.rs")), "example" ); // /example matches because contains("/example") assert_eq!(classify_file_path(Some("src/example/main.py")), "example"); assert_eq!(classify_file_path(Some("src/demo/run.sh")), "example"); assert_eq!(classify_file_path(Some("src/sample/lib.rs")), "example"); } #[test] fn classify_generated_paths() { assert_eq!( classify_file_path(Some("src/generated/api.rs")), "generated" ); assert_eq!( classify_file_path(Some("proto/gen/service.go")), "generated" ); assert_eq!(classify_file_path(Some("api.generated.ts")), "generated"); assert_eq!(classify_file_path(Some("service.pb.go")), "generated"); assert_eq!(classify_file_path(Some("model_generated.rs")), "generated"); } #[test] fn classify_vendored_paths() { // Implementation checks for /vendor/, /node_modules/, /third_party/ (with slashes) assert_eq!( classify_file_path(Some("src/vendor/lib/foo.go")), "vendored" ); assert_eq!( classify_file_path(Some("src/node_modules/pkg/index.js")), "vendored" ); assert_eq!( classify_file_path(Some("src/third_party/lib.c")), "vendored" ); } #[test] fn classify_is_case_insensitive() { assert_eq!(classify_file_path(Some("src/TEST/Helper.rs")), "test"); assert_eq!(classify_file_path(Some("src/VENDOR/lib.go")), "vendored"); assert_eq!( classify_file_path(Some("src/GENERATED/foo.ts")), "generated" ); } // ── adjust_confidence ──────────────────────────────────────── #[test] fn adjust_confidence_production() { assert_eq!(adjust_confidence(8.0, "production"), 8.0); } #[test] fn adjust_confidence_test() { assert_eq!(adjust_confidence(10.0, "test"), 5.0); } #[test] fn adjust_confidence_example() { assert_eq!(adjust_confidence(10.0, "example"), 6.0); } #[test] fn adjust_confidence_generated() { assert_eq!(adjust_confidence(10.0, "generated"), 3.0); } #[test] fn adjust_confidence_vendored() { assert_eq!(adjust_confidence(10.0, "vendored"), 4.0); } #[test] fn adjust_confidence_unknown_classification() { assert_eq!(adjust_confidence(7.0, "unknown"), 7.0); assert_eq!(adjust_confidence(7.0, "something_else"), 7.0); } #[test] fn adjust_confidence_zero() { assert_eq!(adjust_confidence(0.0, "test"), 0.0); assert_eq!(adjust_confidence(0.0, "production"), 0.0); } // ── downgrade_severity ─────────────────────────────────────── #[test] fn downgrade_severity_all_levels() { assert_eq!(downgrade_severity(&Severity::Critical), Severity::High); assert_eq!(downgrade_severity(&Severity::High), Severity::Medium); assert_eq!(downgrade_severity(&Severity::Medium), Severity::Low); assert_eq!(downgrade_severity(&Severity::Low), Severity::Info); assert_eq!(downgrade_severity(&Severity::Info), Severity::Info); } #[test] fn downgrade_severity_info_is_floor() { // Downgrading Info twice should still be Info let s = downgrade_severity(&Severity::Info); assert_eq!(downgrade_severity(&s), Severity::Info); } // ── upgrade_severity ───────────────────────────────────────── #[test] fn upgrade_severity_all_levels() { assert_eq!(upgrade_severity(&Severity::Info), Severity::Low); assert_eq!(upgrade_severity(&Severity::Low), Severity::Medium); assert_eq!(upgrade_severity(&Severity::Medium), Severity::High); assert_eq!(upgrade_severity(&Severity::High), Severity::Critical); assert_eq!(upgrade_severity(&Severity::Critical), Severity::Critical); } #[test] fn upgrade_severity_critical_is_ceiling() { let s = upgrade_severity(&Severity::Critical); assert_eq!(upgrade_severity(&s), Severity::Critical); } // ── upgrade/downgrade roundtrip ────────────────────────────── #[test] fn upgrade_then_downgrade_is_identity_for_middle_values() { for sev in [Severity::Low, Severity::Medium, Severity::High] { assert_eq!(downgrade_severity(&upgrade_severity(&sev)), sev); } } // ── TriageResult deserialization ───────────────────────────── #[test] fn triage_result_full() { let json = r#"{"action":"dismiss","confidence":8.5,"rationale":"false positive","remediation":"remove code"}"#; let r: TriageResult = serde_json::from_str(json).unwrap(); assert_eq!(r.action, "dismiss"); assert_eq!(r.confidence, 8.5); assert_eq!(r.rationale, "false positive"); assert_eq!(r.remediation.as_deref(), Some("remove code")); } #[test] fn triage_result_defaults() { let json = r#"{}"#; let r: TriageResult = serde_json::from_str(json).unwrap(); assert_eq!(r.action, "confirm"); assert_eq!(r.confidence, 0.0); assert_eq!(r.rationale, ""); assert!(r.remediation.is_none()); } #[test] fn triage_result_partial() { let json = r#"{"action":"downgrade","confidence":6.0}"#; let r: TriageResult = serde_json::from_str(json).unwrap(); assert_eq!(r.action, "downgrade"); assert_eq!(r.confidence, 6.0); assert_eq!(r.rationale, ""); assert!(r.remediation.is_none()); } #[test] fn triage_result_with_markdown_fences() { // Simulate LLM wrapping response in markdown code fences let raw = "```json\n{\"action\":\"upgrade\",\"confidence\":9,\"rationale\":\"critical\",\"remediation\":null}\n```"; let cleaned = raw .trim() .trim_start_matches("```json") .trim_start_matches("```") .trim_end_matches("```") .trim(); let r: TriageResult = serde_json::from_str(cleaned).unwrap(); assert_eq!(r.action, "upgrade"); assert_eq!(r.confidence, 9.0); } }