Files
breakpilot-compliance/backend-compliance/compliance/api/crosswalk_routes.py
Benjamin Admin d22c47c9eb
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 35s
CI/CD / test-python-backend-compliance (push) Successful in 34s
CI/CD / test-python-document-crawler (push) Successful in 22s
CI/CD / test-python-dsms-gateway (push) Successful in 19s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
feat(pipeline): Anthropic Batch API, source/regulation filter, cost optimization
- Add Anthropic API support to decomposition Pass 0a/0b (prompt caching, content batching)
- Add Anthropic Batch API (50% cost reduction, async 24h processing)
- Add source_filter (ILIKE on source_citation) for regulation-based filtering
- Add category_filter to Pass 0a for selective decomposition
- Add regulation_filter to control_generator for RAG scan phase filtering
  (prefix match on regulation_code — enables CE + Code Review focus)
- New API endpoints: batch-submit-0a, batch-submit-0b, batch-status, batch-process
- 83 new tests (all passing)

Cost reduction: $2,525 → ~$600-700 with all optimizations combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 13:22:01 +01:00

739 lines
25 KiB
Python

"""
FastAPI routes for the Multi-Layer Control Architecture.
Pattern Library, Obligation Extraction, Crosswalk Matrix, and Migration endpoints.
Endpoints:
GET /v1/canonical/patterns — All patterns (with filters)
GET /v1/canonical/patterns/{pattern_id} — Single pattern
GET /v1/canonical/patterns/{pattern_id}/controls — Controls for a pattern
POST /v1/canonical/obligations/extract — Extract obligations from text
GET /v1/canonical/crosswalk — Query crosswalk matrix
GET /v1/canonical/crosswalk/stats — Coverage statistics
POST /v1/canonical/migrate/decompose — Pass 0a: Obligation extraction
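POST /v1/canonical/migrate/compose-atomic — Pass 0b: Atomic control composition
POST /v1/canonical/migrate/batch-submit-0a — Pass 0a as Anthropic Batch API job
POST /v1/canonical/migrate/batch-submit-0b — Pass 0b as Anthropic Batch API job
GET /v1/canonical/migrate/batch-status/{batch_id} — Anthropic batch job status
POST /v1/canonical/migrate/batch-process — Process results of a completed batch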
POST /v1/canonical/migrate/link-obligations — Pass 1: Obligation linkage
POST /v1/canonical/migrate/classify-patterns — Pass 2: Pattern classification
POST /v1/canonical/migrate/triage — Pass 3: Quality triage
POST /v1/canonical/migrate/backfill-crosswalk — Pass 4: Crosswalk backfill
POST /v1/canonical/migrate/deduplicate — Pass 5: Deduplication
GET /v1/canonical/migrate/status — Migration progress
GET /v1/canonical/migrate/decomposition-status — Decomposition progress
"""
import json
import logging
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import text
from database import SessionLocal
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route below is registered on this router, so all paths are served
# under the /v1/canonical prefix and grouped under the "crosswalk" tag.
router = APIRouter(prefix="/v1/canonical", tags=["crosswalk"])
# =============================================================================
# REQUEST / RESPONSE MODELS
# =============================================================================
class PatternResponse(BaseModel):
    """Summary view of a control pattern, as returned by the pattern list."""

    # Identity and display names.
    id: str
    name: str
    name_de: str

    # Classification used by the list filters.
    domain: str
    category: str

    # Descriptive content and pattern-level defaults.
    description: str
    objective_template: str
    severity_default: str
    implementation_effort_default: str = "m"

    # Free-form list attributes copied from the pattern definition.
    tags: list = []
    composable_with: list = []
    open_anchor_refs: list = []

    # Number of non-deprecated canonical controls that reference this pattern;
    # the route handlers fill it from the database and use 0 when none exist.
    controls_count: int = 0
class PatternListResponse(BaseModel):
    """Envelope for the pattern list endpoint."""

    # Patterns that survived the requested filters.
    patterns: List[PatternResponse]
    # Number of entries in ``patterns``.
    total: int
class PatternDetailResponse(PatternResponse):
    """Full view of a single pattern: the summary fields plus its templates."""

    # Template material carried by the pattern definition.
    rationale_template: str = ""
    requirements_template: list = []
    test_procedure_template: list = []
    evidence_template: list = []

    # Keywords taken from the pattern's ``obligation_match_keywords`` attribute.
    obligation_match_keywords: list = []
class ObligationExtractRequest(BaseModel):
    """Input for the obligation extraction endpoint."""

    # Text handed to the extractor as ``chunk_text``.
    text: str
    # Optional source location; a missing regulation_code is sent to the
    # extractor as an empty string.
    regulation_code: Optional[str] = None
    article: Optional[str] = None
    paragraph: Optional[str] = None
class ObligationExtractResponse(BaseModel):
    """Result of obligation extraction plus the keyword pattern match."""

    # Fields copied from the extractor's result object.
    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    method: str = "none"
    confidence: float = 0.0
    regulation_id: Optional[str] = None

    # Outcome of the keyword pattern match; None / 0.0 when nothing matched.
    pattern_id: Optional[str] = None
    pattern_confidence: float = 0.0
class CrosswalkRow(BaseModel):
    """One row of the crosswalk_matrix table as exposed by the API."""

    # Fields mirror the columns selected by the crosswalk query endpoint.
    regulation_code: str = ""
    article: Optional[str] = None
    obligation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    master_control_id: Optional[str] = None
    confidence: float = 0.0
    source: str = "auto"
class CrosswalkQueryResponse(BaseModel):
    """One page of crosswalk rows together with the overall match count."""

    # Rows of the requested page (bounded by limit/offset).
    rows: List[CrosswalkRow]
    # Count of all rows matching the filters, independent of paging.
    total: int
class CrosswalkStatsResponse(BaseModel):
    """Aggregate coverage figures computed over crosswalk_matrix."""

    # Total row count followed by distinct counts per linked entity.
    total_rows: int = 0
    regulations_covered: int = 0
    obligations_linked: int = 0
    patterns_used: int = 0
    controls_linked: int = 0

    # Maps regulation_code to the number of crosswalk rows for it.
    coverage_by_regulation: dict = {}
class MigrationRequest(BaseModel):
    """Options accepted by the synchronous migration pass endpoints."""

    # Maximum number of items to process; 0 means no limit.
    limit: int = 0
    # Items per LLM call; 0 means auto (5 for Anthropic, 1 for Ollama).
    batch_size: int = 0
    # Use the Anthropic API instead of Ollama.
    use_anthropic: bool = False
    # Comma-separated categories.
    category_filter: Optional[str] = None
    # Comma-separated source regulations (ILIKE match).
    source_filter: Optional[str] = None
class BatchSubmitRequest(BaseModel):
    """Options accepted by the batch-submit endpoints."""

    # Forwarded unchanged to the decomposition pass's submit call.
    limit: int = 0
    batch_size: int = 5
    # Only the Pass 0a submit endpoint forwards the two filters.
    category_filter: Optional[str] = None
    source_filter: Optional[str] = None
class BatchProcessRequest(BaseModel):
    """Identifies a submitted batch whose results should be processed."""

    # Identifier of the batch to fetch results for.
    batch_id: str
    # Which pass the batch belongs to: "0a" or "0b".
    pass_type: str = "0a"
class MigrationResponse(BaseModel):
    """Common response shape of the migration and batch endpoints."""

    # Outcome label set by the handler (e.g. "completed" or "submitted").
    status: str = "completed"
    # Statistics dictionary returned by the executed pass.
    stats: dict = {}
class MigrationStatusResponse(BaseModel):
    """Overall migration progress, populated from the migration status dict."""

    # Absolute counts.
    total_controls: int = 0
    has_obligation: int = 0
    has_pattern: int = 0
    fully_linked: int = 0
    deprecated: int = 0

    # Coverage figures (names suggest percentages of total_controls — the
    # computation lives in the pipeline adapter, not in this file).
    coverage_obligation_pct: float = 0.0
    coverage_pattern_pct: float = 0.0
    coverage_full_pct: float = 0.0
class DecompositionStatusResponse(BaseModel):
    """Pass 0a/0b progress, populated from the decomposition status dict."""

    # Absolute counts reported by the decomposition pass.
    rich_controls: int = 0
    decomposed_controls: int = 0
    total_candidates: int = 0
    validated: int = 0
    rejected: int = 0
    composed: int = 0
    atomic_controls: int = 0

    # Progress figures for decomposition (0a) and composition (0b).
    decomposition_pct: float = 0.0
    composition_pct: float = 0.0
# =============================================================================
# PATTERN LIBRARY ENDPOINTS
# =============================================================================
@router.get("/patterns", response_model=PatternListResponse)
async def list_patterns(
    domain: Optional[str] = Query(None, description="Filter by domain (e.g. AUTH, CRYP)"),
    category: Optional[str] = Query(None, description="Filter by category"),
    tag: Optional[str] = Query(None, description="Filter by tag"),
):
    """Return control patterns, optionally narrowed by domain, category and tag.

    The domain is compared upper-cased, the category lower-cased, and the tag
    case-insensitively against each pattern's tags. Every returned pattern
    carries the number of non-deprecated controls that reference it.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    library = PatternMatcher()
    library._load_patterns()
    library._build_keyword_index()

    selected = library._patterns
    if domain:
        wanted_domain = domain.upper()
        selected = [pat for pat in selected if pat.domain == wanted_domain]
    if category:
        wanted_category = category.lower()
        selected = [pat for pat in selected if pat.category == wanted_category]
    if tag:
        wanted_tag = tag.lower()
        selected = [
            pat for pat in selected
            if wanted_tag in [candidate.lower() for candidate in pat.tags]
        ]

    # Per-pattern control counts come from the database.
    counts_by_pattern = _get_pattern_control_counts()

    items = [
        PatternResponse(
            id=pat.id,
            name=pat.name,
            name_de=pat.name_de,
            domain=pat.domain,
            category=pat.category,
            description=pat.description,
            objective_template=pat.objective_template,
            severity_default=pat.severity_default,
            implementation_effort_default=pat.implementation_effort_default,
            tags=pat.tags,
            composable_with=pat.composable_with,
            open_anchor_refs=pat.open_anchor_refs,
            controls_count=counts_by_pattern.get(pat.id, 0),
        )
        for pat in selected
    ]
    return PatternListResponse(patterns=items, total=len(items))
@router.get("/patterns/{pattern_id}", response_model=PatternDetailResponse)
async def get_pattern(pattern_id: str):
    """Return the full definition of one control pattern.

    Responds with 404 when the pattern library has no entry for ``pattern_id``.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    library = PatternMatcher()
    library._load_patterns()

    found = library.get_pattern(pattern_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Pattern {pattern_id} not found")

    # Look up how many non-deprecated controls reference this pattern.
    counts_by_pattern = _get_pattern_control_counts()
    linked_controls = counts_by_pattern.get(found.id, 0)

    return PatternDetailResponse(
        id=found.id,
        name=found.name,
        name_de=found.name_de,
        domain=found.domain,
        category=found.category,
        description=found.description,
        objective_template=found.objective_template,
        rationale_template=found.rationale_template,
        requirements_template=found.requirements_template,
        test_procedure_template=found.test_procedure_template,
        evidence_template=found.evidence_template,
        severity_default=found.severity_default,
        implementation_effort_default=found.implementation_effort_default,
        tags=found.tags,
        composable_with=found.composable_with,
        open_anchor_refs=found.open_anchor_refs,
        obligation_match_keywords=found.obligation_match_keywords,
        controls_count=linked_controls,
    )
@router.get("/patterns/{pattern_id}/controls")
async def get_pattern_controls(
    pattern_id: str,
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """Return one page of non-deprecated controls generated from a pattern.

    The pattern id is upper-cased before it is matched. The response holds the
    page under ``controls`` and the overall match count under ``total``.
    """

    def _as_id_list(raw):
        # obligation_ids may arrive as a JSON-encoded string; decode it and
        # fall back to an empty list when it cannot be parsed.
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                raw = []
        return raw or []

    normalized_id = pattern_id.upper()
    db = SessionLocal()
    try:
        page = db.execute(
            text("""
                SELECT id, control_id, title, objective, severity,
                       release_state, category, obligation_ids
                FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
                ORDER BY control_id
                LIMIT :limit OFFSET :offset
            """),
            {"pattern_id": normalized_id, "limit": limit, "offset": offset},
        ).fetchall()

        total = db.execute(
            text("""
                SELECT count(*) FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
            """),
            {"pattern_id": normalized_id},
        ).fetchone()[0]

        controls = [
            {
                "id": str(record[0]),
                "control_id": record[1],
                "title": record[2],
                "objective": record[3],
                "severity": record[4],
                "release_state": record[5],
                "category": record[6],
                "obligation_ids": _as_id_list(record[7]),
            }
            for record in page
        ]
        return {"controls": controls, "total": total}
    finally:
        db.close()
# =============================================================================
# OBLIGATION EXTRACTION ENDPOINT
# =============================================================================
@router.post("/obligations/extract", response_model=ObligationExtractResponse)
async def extract_obligation(req: ObligationExtractRequest):
    """Extract an obligation from the submitted text and match it to a pattern.

    The extractor result is returned together with the pattern id and
    confidence from a keyword match; both stay empty when no pattern matches.
    """
    from compliance.services.obligation_extractor import ObligationExtractor
    from compliance.services.pattern_matcher import PatternMatcher

    extractor = ObligationExtractor()
    await extractor.initialize()
    found = await extractor.extract(
        chunk_text=req.text,
        regulation_code=req.regulation_code or "",
        article=req.article,
        paragraph=req.paragraph,
    )

    # Keyword-match the best available text: the obligation text, then its
    # title, then the first 500 characters of the request text.
    library = PatternMatcher()
    library._load_patterns()
    library._build_keyword_index()
    match_input = found.obligation_text or found.obligation_title or req.text[:500]
    match = library._tier1_keyword(match_input, found.regulation_id)

    if match:
        matched_pattern_id = match.pattern_id
        matched_confidence = match.confidence
    else:
        matched_pattern_id = None
        matched_confidence = 0

    return ObligationExtractResponse(
        obligation_id=found.obligation_id,
        obligation_title=found.obligation_title,
        obligation_text=found.obligation_text,
        method=found.method,
        confidence=found.confidence,
        regulation_id=found.regulation_id,
        pattern_id=matched_pattern_id,
        pattern_confidence=matched_confidence,
    )
# =============================================================================
# CROSSWALK MATRIX ENDPOINTS
# =============================================================================
@router.get("/crosswalk", response_model=CrosswalkQueryResponse)
async def query_crosswalk(
    regulation_code: Optional[str] = Query(None),
    article: Optional[str] = Query(None),
    obligation_id: Optional[str] = Query(None),
    pattern_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """Query the crosswalk matrix with optional equality filters.

    Each supplied filter is ANDed as an exact match on its column. The
    response holds the requested page under ``rows`` and the number of all
    matching rows under ``total``.
    """
    # (column, bind name, value) per optional filter. Column names are fixed
    # literals, never user input; values always travel as bound parameters.
    candidate_filters = (
        ("regulation_code", "reg", regulation_code),
        ("article", "art", article),
        ("obligation_id", "obl", obligation_id),
        ("pattern_id", "pat", pattern_id),
    )
    conditions = ["1=1"]
    filter_params = {}
    for column, bind, value in candidate_filters:
        if value:
            conditions.append(f"{column} = :{bind}")
            filter_params[bind] = value
    where = " AND ".join(conditions)

    db = SessionLocal()
    try:
        # (regulation_code, article) alone is not unique, so tied rows could
        # be returned in a different order per request and pages could overlap
        # or skip rows. The extra sort keys narrow the ties; rows identical in
        # all five columns can still swap places.
        result = db.execute(
            text(f"""
                SELECT regulation_code, article, obligation_id,
                       pattern_id, master_control_id, confidence, source
                FROM crosswalk_matrix
                WHERE {where}
                ORDER BY regulation_code, article, obligation_id,
                         pattern_id, master_control_id
                LIMIT :limit OFFSET :offset
            """),
            {**filter_params, "limit": limit, "offset": offset},
        )
        rows = result.fetchall()

        # The count statement has no paging placeholders, so it only receives
        # the filter parameters.
        count_result = db.execute(
            text(f"SELECT count(*) FROM crosswalk_matrix WHERE {where}"),
            filter_params,
        )
        total = count_result.fetchone()[0]

        crosswalk_rows = [
            CrosswalkRow(
                regulation_code=r[0] or "",
                article=r[1],
                obligation_id=r[2],
                pattern_id=r[3],
                master_control_id=r[4],
                confidence=float(r[5] or 0),
                source=r[6] or "auto",
            )
            for r in rows
        ]
        return CrosswalkQueryResponse(rows=crosswalk_rows, total=total)
    finally:
        db.close()
@router.get("/crosswalk/stats", response_model=CrosswalkStatsResponse)
async def crosswalk_stats():
    """Return crosswalk coverage statistics.

    Reports the total row count, distinct counts of regulations, obligations,
    patterns and master controls, and the row count per regulation code.
    """
    db = SessionLocal()
    try:
        totals = db.execute(text("""
            SELECT
                count(*) AS total,
                count(DISTINCT regulation_code) FILTER (WHERE regulation_code != '') AS regs,
                count(DISTINCT obligation_id) FILTER (WHERE obligation_id IS NOT NULL) AS obls,
                count(DISTINCT pattern_id) FILTER (WHERE pattern_id IS NOT NULL) AS pats,
                count(DISTINCT master_control_id) FILTER (WHERE master_control_id IS NOT NULL) AS ctrls
            FROM crosswalk_matrix
        """)).fetchone()

        # Row count per regulation, largest first.
        per_regulation = db.execute(text("""
            SELECT regulation_code, count(*) AS cnt
            FROM crosswalk_matrix
            WHERE regulation_code != ''
            GROUP BY regulation_code
            ORDER BY cnt DESC
        """)).fetchall()
        coverage = {entry[0]: entry[1] for entry in per_regulation}

        return CrosswalkStatsResponse(
            total_rows=totals[0],
            regulations_covered=totals[1],
            obligations_linked=totals[2],
            patterns_used=totals[3],
            controls_linked=totals[4],
            coverage_by_regulation=coverage,
        )
    finally:
        db.close()
# =============================================================================
# MIGRATION ENDPOINTS
# =============================================================================
@router.post("/migrate/decompose", response_model=MigrationResponse)
async def migrate_decompose(req: MigrationRequest):
    """Pass 0a: Extract obligation candidates from rich controls.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple controls per API call).

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Decomposition pass 0a failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/compose-atomic", response_model=MigrationResponse)
async def migrate_compose_atomic(req: MigrationRequest):
    """Pass 0b: Compose atomic controls from obligation candidates.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple obligations per API call).

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        # Only limit, batch_size and use_anthropic are forwarded; the request's
        # category/source filters are not used by this pass.
        stats = await decomp.run_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Decomposition pass 0b failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/batch-submit-0a", response_model=MigrationResponse)
async def batch_submit_pass0a(req: BatchSubmitRequest):
    """Submit Pass 0a as Anthropic Batch API job (50% cost reduction).

    Returns a batch_id for polling. Results are processed asynchronously
    within 24 hours by Anthropic.

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        result = await decomp.submit_batch_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        # "status" is lifted out of the result dict; the rest becomes stats.
        return MigrationResponse(status=result.pop("status", "submitted"), stats=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Batch submit 0a failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/batch-submit-0b", response_model=MigrationResponse)
async def batch_submit_pass0b(req: BatchSubmitRequest):
    """Submit Pass 0b as Anthropic Batch API job (50% cost reduction).

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        # Only limit and batch_size are forwarded; the request's filters are
        # not used by this pass.
        result = await decomp.submit_batch_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
        )
        # "status" is lifted out of the result dict; the rest becomes stats.
        return MigrationResponse(status=result.pop("status", "submitted"), stats=result)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Batch submit 0b failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/batch-status/{batch_id}")
async def batch_check_status(batch_id: str):
    """Check processing status of an Anthropic batch job.

    Returns whatever the status helper reports for ``batch_id``. Any failure
    is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import check_batch_status

    try:
        return await check_batch_status(batch_id)
    except Exception as e:
        # Log like the sibling handlers do; previously this failure left no
        # trace in the server log at all.
        logger.exception("Batch status check failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/migrate/batch-process", response_model=MigrationResponse)
async def batch_process_results(req: BatchProcessRequest):
    """Fetch and process results from a completed Anthropic batch.

    Call this after batch-status shows processing_status='ended'.

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.process_batch_results(
            batch_id=req.batch_id,
            pass_type=req.pass_type,
        )
        # "status" is lifted out of the stats dict; the rest is returned as stats.
        return MigrationResponse(status=stats.pop("status", "completed"), stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Batch process failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/link-obligations", response_model=MigrationResponse)
async def migrate_link_obligations(req: MigrationRequest):
    """Pass 1: Link controls to obligations via source_citation article.

    Only ``req.limit`` is used. Any failure is logged and reported as HTTP 500
    carrying the error text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass1_obligation_linkage(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration pass 1 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/classify-patterns", response_model=MigrationResponse)
async def migrate_classify_patterns(req: MigrationRequest):
    """Pass 2: Classify controls into patterns via keyword matching.

    Only ``req.limit`` is used. Any failure is logged and reported as HTTP 500
    carrying the error text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass2_pattern_classification(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration pass 2 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/triage", response_model=MigrationResponse)
async def migrate_triage():
    """Pass 3: Quality triage — categorize by linkage completeness.

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass3_quality_triage()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration pass 3 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/backfill-crosswalk", response_model=MigrationResponse)
async def migrate_backfill_crosswalk():
    """Pass 4: Create crosswalk rows for linked controls.

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass4_crosswalk_backfill()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration pass 4 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/deduplicate", response_model=MigrationResponse)
async def migrate_deduplicate():
    """Pass 5: Mark duplicate controls (same obligation + pattern).

    Any failure is logged and reported as HTTP 500 carrying the error text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass5_deduplication()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration pass 5 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/status", response_model=MigrationStatusResponse)
async def migration_status():
    """Get overall migration progress.

    The status dict from the migration passes is expanded into the response
    model. Any failure is logged and reported as HTTP 500 carrying the error
    text.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        status = migration.migration_status()
        return MigrationStatusResponse(**status)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Migration status failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/decomposition-status", response_model=DecompositionStatusResponse)
async def decomposition_status():
    """Get decomposition progress (Pass 0a/0b).

    The status dict from the decomposition pass is expanded into the response
    model. Any failure is logged and reported as HTTP 500 carrying the error
    text.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        status = decomp.decomposition_status()
        return DecompositionStatusResponse(**status)
    except Exception as e:
        # logger.exception records the traceback that logger.error dropped;
        # "from e" keeps the original error as the cause of the HTTP error.
        logger.exception("Decomposition status failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
# =============================================================================
# HELPERS
# =============================================================================
def _get_pattern_control_counts() -> dict[str, int]:
    """Get count of controls per pattern_id from DB.

    Only controls with a non-empty pattern_id whose release_state is not
    'deprecated' are counted.

    Returns:
        Mapping of pattern_id to control count. An empty dict is returned when
        the query fails, so callers fall back to a count of 0 per pattern.
    """
    db = SessionLocal()
    try:
        result = db.execute(text("""
            SELECT pattern_id, count(*) AS cnt
            FROM canonical_controls
            WHERE pattern_id IS NOT NULL AND pattern_id != ''
              AND release_state NOT IN ('deprecated')
            GROUP BY pattern_id
        """))
        return {row[0]: row[1] for row in result.fetchall()}
    except Exception:
        # Counts are best-effort, so the endpoints keep working without them,
        # but the failure is logged instead of being silently discarded.
        logger.warning("Could not load pattern control counts", exc_info=True)
        return {}
    finally:
        db.close()