Files
breakpilot-compliance/backend-compliance/compliance/api/crosswalk_routes.py
Benjamin Admin 770f0b5ab0
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 31s
CI/CD / test-python-backend-compliance (push) Successful in 31s
CI/CD / test-python-document-crawler (push) Successful in 21s
CI/CD / test-python-dsms-gateway (push) Successful in 19s
CI/CD / validate-canonical-controls (push) Successful in 10s
CI/CD / Deploy (push) Successful in 2s
fix: adapt batch dedup to NULL pattern_id — group by merge_group_hint
All Pass 0b controls have pattern_id=NULL. Rewritten to:
- Phase 1: Group by merge_group_hint (action:object:trigger), 52k groups
- Phase 2: Cross-group embedding search for semantically similar masters
- Qdrant search uses unfiltered cross-regulation endpoint
- API param changed: pattern_id → hint_filter

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-24 07:24:02 +01:00

857 lines
29 KiB
Python

"""
FastAPI routes for the Multi-Layer Control Architecture.
Pattern Library, Obligation Extraction, Crosswalk Matrix, and Migration endpoints.
Endpoints:
GET /v1/canonical/patterns — All patterns (with filters)
GET /v1/canonical/patterns/{pattern_id} — Single pattern
GET /v1/canonical/patterns/{pattern_id}/controls — Controls for a pattern
POST /v1/canonical/obligations/extract — Extract obligations from text
GET /v1/canonical/crosswalk — Query crosswalk matrix
GET /v1/canonical/crosswalk/stats — Coverage statistics
POST /v1/canonical/migrate/decompose — Pass 0a: Obligation extraction
POST /v1/canonical/migrate/merge-obligations — Merge implementation-level dupes
POST /v1/canonical/migrate/enrich-obligations — Add trigger_type, impl metadata
POST /v1/canonical/migrate/compose-atomic — Pass 0b: Atomic control composition
POST /v1/canonical/migrate/link-obligations — Pass 1: Obligation linkage
POST /v1/canonical/migrate/classify-patterns — Pass 2: Pattern classification
POST /v1/canonical/migrate/triage — Pass 3: Quality triage
POST /v1/canonical/migrate/backfill-crosswalk — Pass 4: Crosswalk backfill
POST /v1/canonical/migrate/deduplicate — Pass 5: Deduplication
GET /v1/canonical/migrate/status — Migration progress
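GET /v1/canonical/migrate/decomposition-status — Decomposition progress
POST /v1/canonical/migrate/batch-submit-0a — Submit Pass 0a as Anthropic Batch API job
POST /v1/canonical/migrate/batch-submit-0b — Submit Pass 0b as Anthropic Batch API job
GET /v1/canonical/migrate/batch-status/{batch_id} — Processing status of a batch job
POST /v1/canonical/migrate/batch-process — Fetch and process results of a completed batch
POST /v1/canonical/migrate/batch-dedup — Batch dedup of Pass 0b controls
GET /v1/canonical/migrate/batch-dedup/status — Batch dedup progress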
"""
import json
import logging
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import text
from database import SessionLocal
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route declared below is registered under the /v1/canonical prefix
# and tagged "crosswalk".
router = APIRouter(prefix="/v1/canonical", tags=["crosswalk"])
# =============================================================================
# REQUEST / RESPONSE MODELS
# =============================================================================
class PatternResponse(BaseModel):
    """Summary view of a single control pattern (one entry of the pattern list)."""

    id: str
    name: str
    name_de: str  # presumably the German display name — confirm against the pattern library
    domain: str
    category: str
    description: str
    objective_template: str
    severity_default: str
    implementation_effort_default: str = "m"
    # NOTE: the list defaults below are literal lists; this relies on pydantic
    # copying field defaults per instance instead of sharing one list object.
    tags: list = []
    composable_with: list = []
    open_anchor_refs: list = []
    # Number of non-deprecated canonical controls that reference this pattern;
    # the route handlers fill it in and fall back to 0 when no count is known.
    controls_count: int = 0
class PatternListResponse(BaseModel):
    """Response body of GET /patterns: the matching patterns and how many there are."""

    patterns: List[PatternResponse]
    # Set by the handler to len(patterns), i.e. the count after filtering.
    total: int
class PatternDetailResponse(PatternResponse):
    """Full view of one control pattern: the summary fields plus its templates."""

    rationale_template: str = ""
    requirements_template: list = []
    test_procedure_template: list = []
    evidence_template: list = []
    obligation_match_keywords: list = []
class ObligationExtractRequest(BaseModel):
    """Request body of POST /obligations/extract."""

    # Source text the obligation is extracted from.
    text: str
    # Optional citation context; the handler passes an empty string to the
    # extractor when regulation_code is omitted.
    regulation_code: Optional[str] = None
    article: Optional[str] = None
    paragraph: Optional[str] = None
class ObligationExtractResponse(BaseModel):
    """Result of an obligation extraction plus the keyword-based pattern match."""

    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    # Extraction method and confidence as reported by the extractor.
    method: str = "none"
    confidence: float = 0.0
    regulation_id: Optional[str] = None
    # Pattern match for the extracted obligation; the handler leaves these at
    # None / 0 when the keyword matcher returns no result.
    pattern_id: Optional[str] = None
    pattern_confidence: float = 0.0
class CrosswalkRow(BaseModel):
    """One row of the crosswalk_matrix table as exposed by GET /crosswalk."""

    # The handler substitutes "" for a NULL regulation_code.
    regulation_code: str = ""
    article: Optional[str] = None
    obligation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    master_control_id: Optional[str] = None
    # The handler substitutes 0 for a NULL confidence and "auto" for a NULL source.
    confidence: float = 0.0
    source: str = "auto"
class CrosswalkQueryResponse(BaseModel):
    """Response body of GET /crosswalk: one page of rows plus the overall match count."""

    rows: List[CrosswalkRow]
    # Count of all rows matching the filters, independent of limit/offset.
    total: int
class CrosswalkStatsResponse(BaseModel):
    """Aggregate coverage figures over the crosswalk_matrix table."""

    total_rows: int = 0
    # Distinct non-empty regulation codes.
    regulations_covered: int = 0
    # Distinct non-NULL obligation / pattern / master-control ids.
    obligations_linked: int = 0
    patterns_used: int = 0
    controls_linked: int = 0
    # Maps regulation_code -> number of crosswalk rows for that regulation.
    coverage_by_regulation: dict = {}
class MigrationRequest(BaseModel):
    """Request body shared by the migration pass endpoints.

    Not every endpoint reads every field; each handler forwards only the
    options its pass accepts.
    """

    limit: int = 0  # 0 = no limit
    batch_size: int = 0  # 0 = auto (5 for Anthropic, 1 for Ollama)
    use_anthropic: bool = False  # Use Anthropic API instead of Ollama
    category_filter: Optional[str] = None  # Comma-separated categories
    source_filter: Optional[str] = None  # Comma-separated source regulations (ILIKE match)
class BatchSubmitRequest(BaseModel):
    """Request body of the batch-submit endpoints (Pass 0a / 0b via the Batch API)."""

    limit: int = 0
    batch_size: int = 5
    # The two filters are forwarded by the Pass 0a submit handler only.
    category_filter: Optional[str] = None
    source_filter: Optional[str] = None
class BatchProcessRequest(BaseModel):
    """Request body of POST /migrate/batch-process."""

    # Identifier of the batch whose results are to be processed.
    batch_id: str
    pass_type: str = "0a"  # "0a" or "0b"
class MigrationResponse(BaseModel):
    """Generic result envelope returned by the migration and batch endpoints."""

    status: str = "completed"
    # Free-form statistics dict as returned by the underlying pass.
    stats: dict = {}
class MigrationStatusResponse(BaseModel):
    """Overall migration progress.

    The handler builds this directly from the dict returned by
    ``MigrationPasses.migration_status()``.
    """

    total_controls: int = 0
    has_obligation: int = 0
    has_pattern: int = 0
    fully_linked: int = 0
    deprecated: int = 0
    coverage_obligation_pct: float = 0.0
    coverage_pattern_pct: float = 0.0
    coverage_full_pct: float = 0.0
class DecompositionStatusResponse(BaseModel):
    """Decomposition progress for Pass 0a / 0b.

    The handler builds this directly from the dict returned by
    ``DecompositionPass.decomposition_status()``.
    """

    rich_controls: int = 0
    decomposed_controls: int = 0
    total_candidates: int = 0
    validated: int = 0
    rejected: int = 0
    composed: int = 0
    atomic_controls: int = 0
    merged: int = 0
    enriched: int = 0
    ready_for_pass0b: int = 0
    decomposition_pct: float = 0.0
    composition_pct: float = 0.0
# =============================================================================
# PATTERN LIBRARY ENDPOINTS
# =============================================================================
@router.get("/patterns", response_model=PatternListResponse)
async def list_patterns(
    domain: Optional[str] = Query(None, description="Filter by domain (e.g. AUTH, CRYP)"),
    category: Optional[str] = Query(None, description="Filter by category"),
    tag: Optional[str] = Query(None, description="Filter by tag"),
):
    """List all control patterns with optional filters.

    Filters are combined with AND. ``domain`` is upper-cased and ``category``
    lower-cased before an exact comparison; ``tag`` is compared
    case-insensitively against each pattern's tags. Every returned pattern
    carries the number of non-deprecated controls that reference it.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    # NOTE(review): these are private PatternMatcher methods called directly.
    matcher._load_patterns()
    matcher._build_keyword_index()

    selected = matcher._patterns
    if domain:
        wanted_domain = domain.upper()
        selected = [pat for pat in selected if pat.domain == wanted_domain]
    if category:
        wanted_category = category.lower()
        selected = [pat for pat in selected if pat.category == wanted_category]
    if tag:
        wanted_tag = tag.lower()
        selected = [
            pat for pat in selected
            if wanted_tag in [existing.lower() for existing in pat.tags]
        ]

    # Count controls per pattern from DB
    control_counts = _get_pattern_control_counts()

    # Attributes copied one-to-one from the pattern object into the response.
    summary_fields = (
        "id",
        "name",
        "name_de",
        "domain",
        "category",
        "description",
        "objective_template",
        "severity_default",
        "implementation_effort_default",
        "tags",
        "composable_with",
        "open_anchor_refs",
    )
    response_patterns = [
        PatternResponse(
            controls_count=control_counts.get(pat.id, 0),
            **{field: getattr(pat, field) for field in summary_fields},
        )
        for pat in selected
    ]
    return PatternListResponse(patterns=response_patterns, total=len(response_patterns))
@router.get("/patterns/{pattern_id}", response_model=PatternDetailResponse)
async def get_pattern(pattern_id: str):
    """Get a single control pattern by ID.

    Returns the pattern's summary fields, its templates and the number of
    non-deprecated controls that reference it.

    Raises:
        HTTPException: 404 when the pattern library has no such pattern.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    matcher._load_patterns()

    found = matcher.get_pattern(pattern_id)
    if not found:
        raise HTTPException(status_code=404, detail=f"Pattern {pattern_id} not found")

    counts_by_pattern = _get_pattern_control_counts()

    # Attributes copied one-to-one from the pattern object into the response.
    detail_fields = (
        "id",
        "name",
        "name_de",
        "domain",
        "category",
        "description",
        "objective_template",
        "rationale_template",
        "requirements_template",
        "test_procedure_template",
        "evidence_template",
        "severity_default",
        "implementation_effort_default",
        "tags",
        "composable_with",
        "open_anchor_refs",
        "obligation_match_keywords",
    )
    payload = {field: getattr(found, field) for field in detail_fields}
    return PatternDetailResponse(
        controls_count=counts_by_pattern.get(found.id, 0),
        **payload,
    )
@router.get("/patterns/{pattern_id}/controls")
async def get_pattern_controls(
    pattern_id: str,
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """Get controls generated from a specific pattern.

    ``pattern_id`` is upper-cased before matching and deprecated controls are
    excluded. The result is ``{"controls": [...], "total": n}`` where
    ``total`` counts every matching control regardless of ``limit``/``offset``.
    """

    def _as_id_list(raw):
        # obligation_ids may arrive as a JSON-encoded string; decode it and
        # fall back to an empty list when it cannot be parsed or is empty.
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                raw = []
        return raw or []

    pattern_key = pattern_id.upper()
    db = SessionLocal()
    try:
        page = db.execute(
            text("""
                SELECT id, control_id, title, objective, severity,
                release_state, category, obligation_ids
                FROM canonical_controls
                WHERE pattern_id = :pattern_id
                AND release_state NOT IN ('deprecated')
                ORDER BY control_id
                LIMIT :limit OFFSET :offset
            """),
            {"pattern_id": pattern_key, "limit": limit, "offset": offset},
        ).fetchall()

        total = db.execute(
            text("""
                SELECT count(*) FROM canonical_controls
                WHERE pattern_id = :pattern_id
                AND release_state NOT IN ('deprecated')
            """),
            {"pattern_id": pattern_key},
        ).fetchone()[0]

        controls = [
            {
                "id": str(record[0]),
                "control_id": record[1],
                "title": record[2],
                "objective": record[3],
                "severity": record[4],
                "release_state": record[5],
                "category": record[6],
                "obligation_ids": _as_id_list(record[7]),
            }
            for record in page
        ]
        return {"controls": controls, "total": total}
    finally:
        db.close()
# =============================================================================
# OBLIGATION EXTRACTION ENDPOINT
# =============================================================================
@router.post("/obligations/extract", response_model=ObligationExtractResponse)
async def extract_obligation(req: ObligationExtractRequest):
    """Extract obligation from text using 3-tier strategy, then match to pattern.

    The pattern match runs the keyword matcher on the extracted obligation
    text, falling back to the obligation title and then to the first 500
    characters of the request text. When nothing matches, ``pattern_id`` is
    null and ``pattern_confidence`` is 0.
    """
    from compliance.services.obligation_extractor import ObligationExtractor
    from compliance.services.pattern_matcher import PatternMatcher

    extractor = ObligationExtractor()
    await extractor.initialize()
    extracted = await extractor.extract(
        chunk_text=req.text,
        regulation_code=req.regulation_code or "",
        article=req.article,
        paragraph=req.paragraph,
    )

    # Also match to pattern
    matcher = PatternMatcher()
    matcher._load_patterns()
    matcher._build_keyword_index()
    match_input = extracted.obligation_text or extracted.obligation_title or req.text[:500]
    match = matcher._tier1_keyword(match_input, extracted.regulation_id)

    if match:
        matched_pattern, matched_confidence = match.pattern_id, match.confidence
    else:
        matched_pattern, matched_confidence = None, 0

    return ObligationExtractResponse(
        obligation_id=extracted.obligation_id,
        obligation_title=extracted.obligation_title,
        obligation_text=extracted.obligation_text,
        method=extracted.method,
        confidence=extracted.confidence,
        regulation_id=extracted.regulation_id,
        pattern_id=matched_pattern,
        pattern_confidence=matched_confidence,
    )
# =============================================================================
# CROSSWALK MATRIX ENDPOINTS
# =============================================================================
@router.get("/crosswalk", response_model=CrosswalkQueryResponse)
async def query_crosswalk(
    regulation_code: Optional[str] = Query(None),
    article: Optional[str] = Query(None),
    obligation_id: Optional[str] = Query(None),
    pattern_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """Query the crosswalk matrix with filters.

    Each supplied filter is an exact match and all filters are ANDed; omitted
    or empty filters are ignored. Rows are ordered by regulation_code, then
    article. ``total`` is the number of rows matching the filters,
    independent of ``limit``/``offset``.
    """
    # (SQL clause, bind-parameter name, value) for every optional filter.
    optional_filters = (
        ("regulation_code = :reg", "reg", regulation_code),
        ("article = :art", "art", article),
        ("obligation_id = :obl", "obl", obligation_id),
        ("pattern_id = :pat", "pat", pattern_id),
    )

    db = SessionLocal()
    try:
        clauses = ["1=1"]
        params = {"limit": limit, "offset": offset}
        for clause, bind_name, value in optional_filters:
            if value:
                clauses.append(clause)
                params[bind_name] = value

        # Only the fixed clause strings above are interpolated into the SQL;
        # user-supplied values always travel as bound parameters.
        where = " AND ".join(clauses)

        page = db.execute(
            text(f"""
                SELECT regulation_code, article, obligation_id,
                pattern_id, master_control_id, confidence, source
                FROM crosswalk_matrix
                WHERE {where}
                ORDER BY regulation_code, article
                LIMIT :limit OFFSET :offset
            """),
            params,
        ).fetchall()

        total = db.execute(
            text(f"SELECT count(*) FROM crosswalk_matrix WHERE {where}"),
            params,
        ).fetchone()[0]

        crosswalk_rows = []
        for record in page:
            crosswalk_rows.append(
                CrosswalkRow(
                    regulation_code=record[0] or "",
                    article=record[1],
                    obligation_id=record[2],
                    pattern_id=record[3],
                    master_control_id=record[4],
                    confidence=float(record[5] or 0),
                    source=record[6] or "auto",
                )
            )
        return CrosswalkQueryResponse(rows=crosswalk_rows, total=total)
    finally:
        db.close()
@router.get("/crosswalk/stats", response_model=CrosswalkStatsResponse)
async def crosswalk_stats():
    """Get crosswalk coverage statistics.

    Reports the total row count of crosswalk_matrix, the number of distinct
    non-empty regulation codes, the number of distinct non-NULL obligation,
    pattern and master-control ids, and the row count per regulation code.
    """
    db = SessionLocal()
    try:
        totals = db.execute(text("""
            SELECT
            count(*) AS total,
            count(DISTINCT regulation_code) FILTER (WHERE regulation_code != '') AS regs,
            count(DISTINCT obligation_id) FILTER (WHERE obligation_id IS NOT NULL) AS obls,
            count(DISTINCT pattern_id) FILTER (WHERE pattern_id IS NOT NULL) AS pats,
            count(DISTINCT master_control_id) FILTER (WHERE master_control_id IS NOT NULL) AS ctrls
            FROM crosswalk_matrix
        """)).fetchone()

        # Coverage by regulation
        per_regulation = db.execute(text("""
            SELECT regulation_code, count(*) AS cnt
            FROM crosswalk_matrix
            WHERE regulation_code != ''
            GROUP BY regulation_code
            ORDER BY cnt DESC
        """)).fetchall()

        coverage = {}
        for entry in per_regulation:
            coverage[entry[0]] = entry[1]

        return CrosswalkStatsResponse(
            total_rows=totals[0],
            regulations_covered=totals[1],
            obligations_linked=totals[2],
            patterns_used=totals[3],
            controls_linked=totals[4],
            coverage_by_regulation=coverage,
        )
    finally:
        db.close()
# =============================================================================
# MIGRATION ENDPOINTS
# =============================================================================
@router.post("/migrate/decompose", response_model=MigrationResponse)
async def migrate_decompose(req: MigrationRequest):
    """Pass 0a: Extract obligation candidates from rich controls.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple controls per API call).

    Returns the statistics dict produced by the pass.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Decomposition pass 0a failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/merge-obligations", response_model=MigrationResponse)
async def migrate_merge_obligations():
    """Merge implementation-level duplicate obligations within each parent.

    Run AFTER Pass 0a, BEFORE Pass 0b. No LLM calls — rule-based.
    Merges obligations that share similar action+object into the more
    abstract survivor, marking the concrete duplicate as 'merged'.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = decomp.run_merge_pass()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Merge pass failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/enrich-obligations", response_model=MigrationResponse)
async def migrate_enrich_obligations():
    """Add trigger_type and is_implementation_specific metadata.

    Run AFTER merge pass, BEFORE Pass 0b. No LLM calls — rule-based.
    Classifies trigger_type (event/periodic/continuous) from obligation text
    and detects implementation-specific obligations (concrete tools/protocols).

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = decomp.enrich_obligations()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Enrich pass failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/compose-atomic", response_model=MigrationResponse)
async def migrate_compose_atomic(req: MigrationRequest):
    """Pass 0b: Compose atomic controls from obligation candidates.

    With use_anthropic=true, uses Anthropic API with prompt caching
    and content batching (multiple obligations per API call).

    Only ``limit``, ``batch_size`` and ``use_anthropic`` of the request are
    forwarded to the pass.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
            use_anthropic=req.use_anthropic,
        )
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Decomposition pass 0b failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/batch-submit-0a", response_model=MigrationResponse)
async def batch_submit_pass0a(req: BatchSubmitRequest):
    """Submit Pass 0a as Anthropic Batch API job (50% cost reduction).

    Returns a batch_id for polling. Results are processed asynchronously
    within 24 hours by Anthropic.

    The response status is taken from the submit result's "status" entry
    ("submitted" when absent); the remaining entries are returned as stats.

    Raises:
        HTTPException: 500, with the error message as detail, when submission fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        result = await decomp.submit_batch_pass0a(
            limit=req.limit,
            batch_size=req.batch_size,
            category_filter=req.category_filter,
            source_filter=req.source_filter,
        )
        # Pop first so "status" is not repeated inside stats.
        status = result.pop("status", "submitted")
        return MigrationResponse(status=status, stats=result)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Batch submit 0a failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/batch-submit-0b", response_model=MigrationResponse)
async def batch_submit_pass0b(req: BatchSubmitRequest):
    """Submit Pass 0b as Anthropic Batch API job (50% cost reduction).

    Only ``limit`` and ``batch_size`` of the request are forwarded. The
    response status is taken from the submit result's "status" entry
    ("submitted" when absent); the remaining entries are returned as stats.

    Raises:
        HTTPException: 500, with the error message as detail, when submission fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        result = await decomp.submit_batch_pass0b(
            limit=req.limit,
            batch_size=req.batch_size,
        )
        # Pop first so "status" is not repeated inside stats.
        status = result.pop("status", "submitted")
        return MigrationResponse(status=status, stats=result)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Batch submit 0b failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/batch-status/{batch_id}")
async def batch_check_status(batch_id: str):
    """Check processing status of an Anthropic batch job.

    Returns whatever ``check_batch_status`` reports for ``batch_id``.

    Raises:
        HTTPException: 500, with the error message as detail, when the lookup fails.
    """
    from compliance.services.decomposition_pass import check_batch_status

    try:
        return await check_batch_status(batch_id)
    except Exception as e:
        # Log the failure like the sibling endpoints do; it was previously
        # only reported to the client and never reached the server log.
        logger.exception("Batch status check failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/migrate/batch-process", response_model=MigrationResponse)
async def batch_process_results(req: BatchProcessRequest):
    """Fetch and process results from a completed Anthropic batch.

    Call this after batch-status shows processing_status='ended'.

    The response status is taken from the result's "status" entry
    ("completed" when absent); the remaining entries are returned as stats.

    Raises:
        HTTPException: 500, with the error message as detail, when processing fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.process_batch_results(
            batch_id=req.batch_id,
            pass_type=req.pass_type,
        )
        # Pop first so "status" is not repeated inside stats.
        status = stats.pop("status", "completed")
        return MigrationResponse(status=status, stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Batch process failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/link-obligations", response_model=MigrationResponse)
async def migrate_link_obligations(req: MigrationRequest):
    """Pass 1: Link controls to obligations via source_citation article.

    Only ``limit`` of the request is forwarded to the pass.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass1_obligation_linkage(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration pass 1 failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/classify-patterns", response_model=MigrationResponse)
async def migrate_classify_patterns(req: MigrationRequest):
    """Pass 2: Classify controls into patterns via keyword matching.

    Only ``limit`` of the request is forwarded to the pass.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass2_pattern_classification(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration pass 2 failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/triage", response_model=MigrationResponse)
async def migrate_triage():
    """Pass 3: Quality triage — categorize by linkage completeness.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass3_quality_triage()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration pass 3 failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/backfill-crosswalk", response_model=MigrationResponse)
async def migrate_backfill_crosswalk():
    """Pass 4: Create crosswalk rows for linked controls.

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass4_crosswalk_backfill()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration pass 4 failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.post("/migrate/deduplicate", response_model=MigrationResponse)
async def migrate_deduplicate():
    """Pass 5: Mark duplicate controls (same obligation + pattern).

    Raises:
        HTTPException: 500, with the error message as detail, when the pass fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass5_deduplication()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration pass 5 failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/status", response_model=MigrationStatusResponse)
async def migration_status():
    """Get overall migration progress.

    The response is built directly from the dict returned by
    ``MigrationPasses.migration_status()``.

    Raises:
        HTTPException: 500, with the error message as detail, when the lookup
            or the response construction fails.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        status = migration.migration_status()
        return MigrationStatusResponse(**status)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Migration status failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
@router.get("/migrate/decomposition-status", response_model=DecompositionStatusResponse)
async def decomposition_status():
    """Get decomposition progress (Pass 0a/0b).

    The response is built directly from the dict returned by
    ``DecompositionPass.decomposition_status()``.

    Raises:
        HTTPException: 500, with the error message as detail, when the lookup
            or the response construction fails.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        status = decomp.decomposition_status()
        return DecompositionStatusResponse(**status)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Decomposition status failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
# =============================================================================
# BATCH DEDUP ENDPOINTS
# =============================================================================
# Module-level runner reference for status polling: set by the
# /migrate/batch-dedup handler while its run is in flight so that
# /migrate/batch-dedup/status can read live progress from it; None otherwise.
_batch_dedup_runner = None
@router.post("/migrate/batch-dedup", response_model=MigrationResponse)
async def migrate_batch_dedup(
    dry_run: bool = Query(False, description="Preview mode — no DB changes"),
    hint_filter: Optional[str] = Query(None, description="Only process hints matching this prefix"),
):
    """Batch dedup: reduce ~85k Pass 0b controls to ~18-25k masters.

    Phase 1: Groups by merge_group_hint, picks best quality master, links rest.
    Phase 2: Cross-group embedding search for semantically similar masters.

    While the run is in flight the runner is published in the module-level
    ``_batch_dedup_runner`` so the status endpoint can report progress.

    Raises:
        HTTPException: 500, with the error message as detail, when the run fails.
    """
    global _batch_dedup_runner
    from compliance.services.batch_dedup_runner import BatchDedupRunner

    db = SessionLocal()
    # Bound before the try so the finally clause can always compare against it,
    # even if constructing the runner raises.
    runner = None
    try:
        runner = BatchDedupRunner(db=db)
        _batch_dedup_runner = runner
        stats = await runner.run(dry_run=dry_run, hint_filter=hint_filter)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Batch dedup failed: %s", e)
        # Chain the original error so the root cause stays attached to the 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clear the shared reference only if it still points at this request's
        # runner. An unconditional reset would wipe the reference of an
        # overlapping run that is still in progress, making the status
        # endpoint report "not running".
        if _batch_dedup_runner is runner:
            _batch_dedup_runner = None
        db.close()
@router.get("/migrate/batch-dedup/status")
async def batch_dedup_status():
    """Get current batch dedup progress (while running).

    While a run is in flight, returns ``running: True`` merged with the
    runner's own status. Otherwise returns ``running: False`` together with
    Pass 0b totals read from the database and the number of pending dedup
    reviews.
    """
    active_runner = _batch_dedup_runner
    if active_runner is not None:
        return {"running": True, **active_runner.get_status()}

    # Not running — show DB stats
    db = SessionLocal()
    try:
        pass0b_counts = db.execute(text("""
            SELECT
            count(*) FILTER (WHERE decomposition_method = 'pass0b') AS total_pass0b,
            count(*) FILTER (WHERE decomposition_method = 'pass0b'
            AND release_state = 'duplicate') AS duplicates,
            count(*) FILTER (WHERE decomposition_method = 'pass0b'
            AND release_state != 'duplicate'
            AND release_state != 'deprecated') AS masters
            FROM canonical_controls
        """)).fetchone()

        pending_row = db.execute(text(
            "SELECT count(*) FROM control_dedup_reviews WHERE review_status = 'pending'"
        )).fetchone()
        pending_reviews = pending_row[0]

        summary = {"running": False}
        summary["total_pass0b"] = pass0b_counts[0]
        summary["duplicates"] = pass0b_counts[1]
        summary["masters"] = pass0b_counts[2]
        summary["pending_reviews"] = pending_reviews
        return summary
    finally:
        db.close()
# =============================================================================
# HELPERS
# =============================================================================
def _get_pattern_control_counts() -> dict[str, int]:
    """Get count of controls per pattern_id from DB.

    Only controls with a non-empty pattern_id that are not deprecated are
    counted.

    Returns:
        Mapping of pattern_id to control count. An empty dict is returned
        when the query fails, so callers report 0 for every pattern instead
        of failing the request.
    """
    db = SessionLocal()
    try:
        result = db.execute(text("""
            SELECT pattern_id, count(*) AS cnt
            FROM canonical_controls
            WHERE pattern_id IS NOT NULL AND pattern_id != ''
            AND release_state NOT IN ('deprecated')
            GROUP BY pattern_id
        """))
        return {row[0]: row[1] for row in result.fetchall()}
    except Exception:
        # Stay best-effort (counts are decorative for the pattern endpoints),
        # but record the failure instead of swallowing it silently.
        logger.exception("Failed to load control counts per pattern")
        return {}
    finally:
        db.close()