"""
FastAPI routes for the Multi-Layer Control Architecture.

Pattern Library, Obligation Extraction, Crosswalk Matrix, and Migration endpoints.

Endpoints:
  GET  /v1/canonical/patterns                          — All patterns (with filters)
  GET  /v1/canonical/patterns/{pattern_id}             — Single pattern
  GET  /v1/canonical/patterns/{pattern_id}/controls    — Controls for a pattern
  POST /v1/canonical/obligations/extract               — Extract obligations from text
  GET  /v1/canonical/crosswalk                         — Query crosswalk matrix
  GET  /v1/canonical/crosswalk/stats                   — Coverage statistics
  POST /v1/canonical/migrate/decompose                 — Pass 0a: Obligation extraction
  POST /v1/canonical/migrate/compose-atomic            — Pass 0b: Atomic control composition
  POST /v1/canonical/migrate/batch-submit-0a           — Pass 0a as Anthropic Batch API job
  POST /v1/canonical/migrate/batch-submit-0b           — Pass 0b as Anthropic Batch API job
  GET  /v1/canonical/migrate/batch-status/{batch_id}   — Batch job status
  POST /v1/canonical/migrate/batch-process             — Process results of a finished batch
  POST /v1/canonical/migrate/link-obligations          — Pass 1: Obligation linkage
  POST /v1/canonical/migrate/classify-patterns         — Pass 2: Pattern classification
  POST /v1/canonical/migrate/triage                    — Pass 3: Quality triage
  POST /v1/canonical/migrate/backfill-crosswalk        — Pass 4: Crosswalk backfill
  POST /v1/canonical/migrate/deduplicate               — Pass 5: Deduplication
  GET  /v1/canonical/migrate/status                    — Migration progress
  GET  /v1/canonical/migrate/decomposition-status      — Decomposition progress
"""

import json
import logging
from typing import Optional, List

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import text

from database import SessionLocal

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/v1/canonical", tags=["crosswalk"])


# =============================================================================
# REQUEST / RESPONSE MODELS
# =============================================================================

# Summary view of one control pattern (one element of the pattern list).
class PatternResponse(BaseModel):
    id: str
    name: str
    name_de: str
    domain: str
    category: str
    description: str
    objective_template: str
    severity_default: str
    implementation_effort_default: str = "m"
    tags: list = []
    composable_with: list = []
    open_anchor_refs: list = []
    controls_count: int = 0  # filled from canonical_controls, 0 when unknown


# Envelope for GET /patterns: the (filtered) patterns plus their count.
class PatternListResponse(BaseModel):
    patterns: List[PatternResponse]
    total: int


# Detail view: everything in PatternResponse plus the pattern's templates
# and obligation-matching keywords.
class PatternDetailResponse(PatternResponse):
    rationale_template: str = ""
    requirements_template: list = []
    test_procedure_template: list = []
    evidence_template: list = []
    obligation_match_keywords: list = []


# Body of POST /obligations/extract; only `text` is mandatory.
class ObligationExtractRequest(BaseModel):
    text: str
    regulation_code: Optional[str] = None
    article: Optional[str] = None
    paragraph: Optional[str] = None


# Result of an extraction plus the pattern matched for it (if any).
class ObligationExtractResponse(BaseModel):
    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    method: str = "none"
    confidence: float = 0.0
    regulation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    pattern_confidence: float = 0.0


# One row of the crosswalk_matrix table as exposed by the API.
class CrosswalkRow(BaseModel):
    regulation_code: str = ""
    article: Optional[str] = None
    obligation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    master_control_id: Optional[str] = None
    confidence: float = 0.0
    source: str = "auto"


# One page of crosswalk rows plus the total number of matching rows.
class CrosswalkQueryResponse(BaseModel):
    rows: List[CrosswalkRow]
    total: int


# Aggregate counts over crosswalk_matrix; coverage_by_regulation maps
# regulation_code -> row count.
class CrosswalkStatsResponse(BaseModel):
    total_rows: int = 0
    regulations_covered: int = 0
    obligations_linked: int = 0
    patterns_used: int = 0
    controls_linked: int = 0
    coverage_by_regulation: dict = {}


# Options shared by the synchronous migration passes.
class MigrationRequest(BaseModel):
    limit: int = 0  # 0 = no limit
    batch_size: int = 0  # 0 = auto (5 for Anthropic, 1 for Ollama)
    use_anthropic: bool = False  # Use Anthropic API instead of Ollama
    category_filter: Optional[str] = None  # Comma-separated categories
    source_filter: Optional[str] = None  # Comma-separated source regulations (ILIKE match)


# Options for submitting a pass as an Anthropic batch job.
class BatchSubmitRequest(BaseModel):
    limit: int = 0
    batch_size: int = 5
    category_filter: Optional[str] = None
    source_filter: Optional[str] = None


# Identifies a batch whose results should be fetched and processed.
class BatchProcessRequest(BaseModel):
    batch_id: str
    pass_type: str = "0a"  # "0a" or "0b"


# Generic result of a migration pass: a status string and a stats dict.
class MigrationResponse(BaseModel):
    status: str = "completed"
    stats: dict = {}


# Progress counters and coverage percentages for the migration.
class MigrationStatusResponse(BaseModel):
    total_controls: int = 0
    has_obligation: int = 0
    has_pattern: int = 0
    fully_linked: int = 0
    deprecated: int = 0
    coverage_obligation_pct: float = 0.0
    coverage_pattern_pct: float = 0.0
    coverage_full_pct: float = 0.0


# Progress counters and percentages for decomposition (Pass 0a/0b).
class DecompositionStatusResponse(BaseModel):
    rich_controls: int = 0
    decomposed_controls: int = 0
    total_candidates: int = 0
    validated: int = 0
    rejected: int = 0
    composed: int = 0
    atomic_controls: int = 0
    decomposition_pct: float = 0.0
    composition_pct: float = 0.0


# =============================================================================
# PATTERN LIBRARY ENDPOINTS
# =============================================================================

@router.get("/patterns", response_model=PatternListResponse)
async def list_patterns(
    domain: Optional[str] = Query(None, description="Filter by domain (e.g. AUTH, CRYP)"),
    category: Optional[str] = Query(None, description="Filter by category"),
    tag: Optional[str] = Query(None, description="Filter by tag"),
):
    """List all control patterns with optional filters."""
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    matcher._load_patterns()
    matcher._build_keyword_index()

    # Each filter narrows the previous selection; falsy filters are skipped.
    selected = matcher._patterns
    if domain:
        wanted_domain = domain.upper()
        selected = [p for p in selected if p.domain == wanted_domain]
    if category:
        wanted_category = category.lower()
        selected = [p for p in selected if p.category == wanted_category]
    if tag:
        wanted_tag = tag.lower()
        selected = [
            p for p in selected
            if wanted_tag in [t.lower() for t in p.tags]
        ]

    # Number of controls per pattern, looked up once for the whole list.
    counts_by_pattern = _get_pattern_control_counts()

    items = [
        PatternResponse(
            id=p.id,
            name=p.name,
            name_de=p.name_de,
            domain=p.domain,
            category=p.category,
            description=p.description,
            objective_template=p.objective_template,
            severity_default=p.severity_default,
            implementation_effort_default=p.implementation_effort_default,
            tags=p.tags,
            composable_with=p.composable_with,
            open_anchor_refs=p.open_anchor_refs,
            controls_count=counts_by_pattern.get(p.id, 0),
        )
        for p in selected
    ]
    return PatternListResponse(patterns=items, total=len(items))


@router.get("/patterns/{pattern_id}", response_model=PatternDetailResponse)
async def get_pattern(pattern_id: str):
    """Get a single control pattern by ID."""
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    matcher._load_patterns()

    found = matcher.get_pattern(pattern_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Pattern {pattern_id} not found")

    counts_by_pattern = _get_pattern_control_counts()
    linked_controls = counts_by_pattern.get(found.id, 0)

    return PatternDetailResponse(
        id=found.id,
        name=found.name,
        name_de=found.name_de,
        domain=found.domain,
        category=found.category,
        description=found.description,
        objective_template=found.objective_template,
        rationale_template=found.rationale_template,
        requirements_template=found.requirements_template,
        test_procedure_template=found.test_procedure_template,
        evidence_template=found.evidence_template,
        severity_default=found.severity_default,
        implementation_effort_default=found.implementation_effort_default,
        tags=found.tags,
        composable_with=found.composable_with,
        open_anchor_refs=found.open_anchor_refs,
        obligation_match_keywords=found.obligation_match_keywords,
        controls_count=linked_controls,
    )


@router.get("/patterns/{pattern_id}/controls")
async def get_pattern_controls(
    pattern_id: str,
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """Get controls generated from a specific pattern."""
    # The pattern id is upper-cased before it is bound into both queries.
    normalized_id = pattern_id.upper()

    db = SessionLocal()
    try:
        page = db.execute(
            text("""
                SELECT id, control_id, title, objective, severity,
                       release_state, category, obligation_ids
                FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
                ORDER BY control_id
                LIMIT :limit OFFSET :offset
            """),
            {"pattern_id": normalized_id, "limit": limit, "offset": offset},
        ).fetchall()

        # Total over the same filter, ignoring pagination.
        total = db.execute(
            text("""
                SELECT count(*)
                FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
            """),
            {"pattern_id": normalized_id},
        ).fetchone()[0]

        controls = []
        for (row_id, control_id, title, objective, severity,
             release_state, category, raw_obligations) in page:
            obligation_ids = raw_obligations
            # A string value is treated as JSON text; unparsable text
            # degrades to an empty list.
            if isinstance(obligation_ids, str):
                try:
                    obligation_ids = json.loads(obligation_ids)
                except (json.JSONDecodeError, TypeError):
                    obligation_ids = []

            controls.append({
                "id": str(row_id),
                "control_id": control_id,
                "title": title,
                "objective": objective,
                "severity": severity,
                "release_state": release_state,
                "category": category,
                "obligation_ids": obligation_ids or [],
            })

        return {"controls": controls, "total": total}
    finally:
        db.close()


# =============================================================================
# OBLIGATION EXTRACTION ENDPOINT
# =============================================================================

@router.post("/obligations/extract", response_model=ObligationExtractResponse)
async def extract_obligation(req: ObligationExtractRequest):
    """Extract obligation from text using 3-tier strategy, then match to pattern."""
    from compliance.services.obligation_extractor import ObligationExtractor
    from compliance.services.pattern_matcher import PatternMatcher

    extractor = ObligationExtractor()
    await extractor.initialize()

    obligation = await extractor.extract(
        chunk_text=req.text,
        regulation_code=req.regulation_code or "",
        article=req.article,
        paragraph=req.paragraph,
    )

    # Also match to pattern
    matcher = PatternMatcher()
    matcher._load_patterns()
    matcher._build_keyword_index()

    # Prefer the extracted obligation text, then its title, and only then
    # the first 500 characters of the raw request text.
    match_input = (
        obligation.obligation_text
        or obligation.obligation_title
        or req.text[:500]
    )
    match = matcher._tier1_keyword(match_input, obligation.regulation_id)

    if match:
        matched_pattern_id = match.pattern_id
        matched_confidence = match.confidence
    else:
        matched_pattern_id = None
        matched_confidence = 0

    return ObligationExtractResponse(
        obligation_id=obligation.obligation_id,
        obligation_title=obligation.obligation_title,
        obligation_text=obligation.obligation_text,
        method=obligation.method,
        confidence=obligation.confidence,
        regulation_id=obligation.regulation_id,
        pattern_id=matched_pattern_id,
        pattern_confidence=matched_confidence,
    )


# =============================================================================
# CROSSWALK MATRIX ENDPOINTS
# =============================================================================

@router.get("/crosswalk", response_model=CrosswalkQueryResponse)
async def query_crosswalk(
    regulation_code: Optional[str] = Query(None),
    article: Optional[str] = Query(None),
    obligation_id: Optional[str] = Query(None),
    pattern_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """Query the crosswalk matrix with filters."""
    # Build the WHERE clause from the filters that were actually supplied.
    # Only the hard-coded column/bind names below are interpolated into the
    # SQL text; every user-provided value travels as a bound parameter.
    conditions = ["1=1"]
    filter_params: dict = {}
    for column, bind, value in (
        ("regulation_code", "reg", regulation_code),
        ("article", "art", article),
        ("obligation_id", "obl", obligation_id),
        ("pattern_id", "pat", pattern_id),
    ):
        if value:
            conditions.append(f"{column} = :{bind}")
            filter_params[bind] = value
    where = " AND ".join(conditions)

    db = SessionLocal()
    try:
        # Paging binds are only added for the page query; the COUNT query
        # below has no :limit / :offset placeholders and gets the filter
        # binds alone.
        rows = db.execute(
            text(f"""
                SELECT regulation_code, article, obligation_id,
                       pattern_id, master_control_id, confidence, source
                FROM crosswalk_matrix
                WHERE {where}
                ORDER BY regulation_code, article
                LIMIT :limit OFFSET :offset
            """),
            {**filter_params, "limit": limit, "offset": offset},
        ).fetchall()

        total = db.execute(
            text(f"SELECT count(*) FROM crosswalk_matrix WHERE {where}"),
            filter_params,
        ).fetchone()[0]

        crosswalk_rows = [
            CrosswalkRow(
                regulation_code=r[0] or "",
                article=r[1],
                obligation_id=r[2],
                pattern_id=r[3],
                master_control_id=r[4],
                confidence=float(r[5] or 0),
                source=r[6] or "auto",
            )
            for r in rows
        ]
        return CrosswalkQueryResponse(rows=crosswalk_rows, total=total)
    finally:
        db.close()


@router.get("/crosswalk/stats", response_model=CrosswalkStatsResponse)
async def crosswalk_stats():
    """Get crosswalk coverage statistics."""
    db = SessionLocal()
    try:
        # One aggregate row: total rows plus distinct counts per dimension.
        row = db.execute(text("""
            SELECT
                count(*) AS total,
                count(DISTINCT regulation_code)
                    FILTER (WHERE regulation_code != '') AS regs,
                count(DISTINCT obligation_id)
                    FILTER (WHERE obligation_id IS NOT NULL) AS obls,
                count(DISTINCT pattern_id)
                    FILTER (WHERE pattern_id IS NOT NULL) AS pats,
                count(DISTINCT master_control_id)
                    FILTER (WHERE master_control_id IS NOT NULL) AS ctrls
            FROM crosswalk_matrix
        """)).fetchone()

        # Coverage by regulation
        reg_rows = db.execute(text("""
            SELECT regulation_code, count(*) AS cnt
            FROM crosswalk_matrix
            WHERE regulation_code != ''
            GROUP BY regulation_code
            ORDER BY cnt DESC
        """)).fetchall()
        coverage = {r[0]: r[1] for r in reg_rows}

        return CrosswalkStatsResponse(
            total_rows=row[0],
            regulations_covered=row[1],
            obligations_linked=row[2],
            patterns_used=row[3],
            controls_linked=row[4],
            coverage_by_regulation=coverage,
        )
    finally:
        db.close()


# =============================================================================
# MIGRATION ENDPOINTS
# =============================================================================
#
# Every handler below follows the same shape: open a session, run one pass,
# translate any failure into HTTP 500 (logged with its traceback and chained
# to the original exception), and always close the session.

@router.post("/migrate/decompose", response_model=MigrationResponse)
async def migrate_decompose(req: MigrationRequest):
    """Pass 0a: Extract obligation candidates from rich controls.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple controls per API call).
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Decomposition pass 0a failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/compose-atomic", response_model=MigrationResponse)
async def migrate_compose_atomic(req: MigrationRequest):
    """Pass 0b: Compose atomic controls from obligation candidates.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple obligations per API call).
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Decomposition pass 0b failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/batch-submit-0a", response_model=MigrationResponse)
async def batch_submit_pass0a(req: BatchSubmitRequest):
    """Submit Pass 0a as Anthropic Batch API job (50% cost reduction).

    Returns a batch_id for polling. Results are processed asynchronously
    within 24 hours by Anthropic.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        result = await decomp.submit_batch_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        # "status" is lifted out of the result; the remainder becomes stats.
        status = result.pop("status", "submitted")
        return MigrationResponse(status=status, stats=result)
    except Exception as e:
        logger.exception("Batch submit 0a failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/batch-submit-0b", response_model=MigrationResponse)
async def batch_submit_pass0b(req: BatchSubmitRequest):
    """Submit Pass 0b as Anthropic Batch API job (50% cost reduction)."""
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        result = await decomp.submit_batch_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
        )
        # "status" is lifted out of the result; the remainder becomes stats.
        status = result.pop("status", "submitted")
        return MigrationResponse(status=status, stats=result)
    except Exception as e:
        logger.exception("Batch submit 0b failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.get("/migrate/batch-status/{batch_id}")
async def batch_check_status(batch_id: str):
    """Check processing status of an Anthropic batch job."""
    from compliance.services.decomposition_pass import check_batch_status

    try:
        return await check_batch_status(batch_id)
    except Exception as e:
        # Logged like every other migration handler so a failed status check
        # leaves a trace on the server, not just a 500 for the client.
        logger.exception("Batch status check failed for %s: %s", batch_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/migrate/batch-process", response_model=MigrationResponse)
async def batch_process_results(req: BatchProcessRequest):
    """Fetch and process results from a completed Anthropic batch.

    Call this after batch-status shows processing_status='ended'.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.process_batch_results(
            batch_id=req.batch_id,
            pass_type=req.pass_type,
        )
        # "status" is lifted out of the stats; the remainder is returned as-is.
        status = stats.pop("status", "completed")
        return MigrationResponse(status=status, stats=stats)
    except Exception as e:
        logger.exception("Batch process failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/link-obligations", response_model=MigrationResponse)
async def migrate_link_obligations(req: MigrationRequest):
    """Pass 1: Link controls to obligations via source_citation article."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass1_obligation_linkage(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 1 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/classify-patterns", response_model=MigrationResponse)
async def migrate_classify_patterns(req: MigrationRequest):
    """Pass 2: Classify controls into patterns via keyword matching."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass2_pattern_classification(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 2 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/triage", response_model=MigrationResponse)
async def migrate_triage():
    """Pass 3: Quality triage — categorize by linkage completeness."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass3_quality_triage()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 3 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/backfill-crosswalk", response_model=MigrationResponse)
async def migrate_backfill_crosswalk():
    """Pass 4: Create crosswalk rows for linked controls."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass4_crosswalk_backfill()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 4 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/deduplicate", response_model=MigrationResponse)
async def migrate_deduplicate():
    """Pass 5: Mark duplicate controls (same obligation + pattern)."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass5_deduplication()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 5 failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.get("/migrate/status", response_model=MigrationStatusResponse)
async def migration_status():
    """Get overall migration progress."""
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        status = migration.migration_status()
        # The status mapping is expanded straight into the response model.
        return MigrationStatusResponse(**status)
    except Exception as e:
        # logger.exception keeps the traceback; "from e" keeps the cause chain.
        logger.exception("Migration status failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.get("/migrate/decomposition-status", response_model=DecompositionStatusResponse)
async def decomposition_status():
    """Get decomposition progress (Pass 0a/0b)."""
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        status = decomp.decomposition_status()
        # The status mapping is expanded straight into the response model.
        return DecompositionStatusResponse(**status)
    except Exception as e:
        # logger.exception keeps the traceback; "from e" keeps the cause chain.
        logger.exception("Decomposition status failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


# =============================================================================
# HELPERS
# =============================================================================

def _get_pattern_control_counts() -> dict[str, int]:
    """Get count of controls per pattern_id from DB.

    Counts rows of ``canonical_controls`` that have a non-empty
    ``pattern_id`` and are not in release state ``deprecated``.

    Returns:
        Mapping of pattern_id to number of controls. The lookup is
        best-effort: if the query fails for any reason the failure is
        logged and an empty dict is returned instead of raising.
    """
    db = SessionLocal()
    try:
        result = db.execute(text("""
            SELECT pattern_id, count(*) AS cnt
            FROM canonical_controls
            WHERE pattern_id IS NOT NULL
              AND pattern_id != ''
              AND release_state NOT IN ('deprecated')
            GROUP BY pattern_id
        """))
        return {row[0]: row[1] for row in result.fetchall()}
    except Exception:
        # Deliberately non-fatal, but no longer silent: without this log a
        # database failure is indistinguishable from "no controls exist".
        logger.warning(
            "Could not load control counts per pattern; returning empty counts",
            exc_info=True,
        )
        return {}
    finally:
        db.close()