feat(multi-layer): complete Multi-Layer Control Architecture (Phases 1-8 + Pass 0)
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Implements the full Multi-Layer Control Architecture for migrating ~25,000 Rich Controls into atomic, deduplicated Master Controls with full traceability. Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance New services: - ObligationExtractor: 3-tier extraction (exact → embedding → LLM) - PatternMatcher: 2-tier matching (keyword + embedding + domain-bonus) - ControlComposer: Pattern + Obligation → Master Control - PipelineAdapter: Pipeline integration + Migration Passes 1-5 - DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls - CrosswalkRoutes: 15 API endpoints under /v1/canonical/ New DB schema: - Migration 060: obligation_extractions, control_patterns, crosswalk_matrix - Migration 061: obligation_candidates, parent_control_uuid tracking Pattern Library: 50 YAML patterns (30 core + 20 IT-security) Go SDK: Pattern loader with YAML validation and indexing Documentation: MkDocs updated with full architecture overview 500 Python tests passing across all components. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -53,6 +53,7 @@ _ROUTER_MODULES = [
|
||||
"wiki_routes",
|
||||
"canonical_control_routes",
|
||||
"control_generator_routes",
|
||||
"crosswalk_routes",
|
||||
"process_task_routes",
|
||||
"evidence_check_routes",
|
||||
]
|
||||
|
||||
623
backend-compliance/compliance/api/crosswalk_routes.py
Normal file
623
backend-compliance/compliance/api/crosswalk_routes.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
FastAPI routes for the Multi-Layer Control Architecture.
|
||||
|
||||
Pattern Library, Obligation Extraction, Crosswalk Matrix, and Migration endpoints.
|
||||
|
||||
Endpoints:
|
||||
GET /v1/canonical/patterns — All patterns (with filters)
|
||||
GET /v1/canonical/patterns/{pattern_id} — Single pattern
|
||||
GET /v1/canonical/patterns/{pattern_id}/controls — Controls for a pattern
|
||||
|
||||
POST /v1/canonical/obligations/extract — Extract obligations from text
|
||||
GET /v1/canonical/crosswalk — Query crosswalk matrix
|
||||
GET /v1/canonical/crosswalk/stats — Coverage statistics
|
||||
|
||||
POST /v1/canonical/migrate/decompose — Pass 0a: Obligation extraction
|
||||
POST /v1/canonical/migrate/compose-atomic — Pass 0b: Atomic control composition
|
||||
POST /v1/canonical/migrate/link-obligations — Pass 1: Obligation linkage
|
||||
POST /v1/canonical/migrate/classify-patterns — Pass 2: Pattern classification
|
||||
POST /v1/canonical/migrate/triage — Pass 3: Quality triage
|
||||
POST /v1/canonical/migrate/backfill-crosswalk — Pass 4: Crosswalk backfill
|
||||
POST /v1/canonical/migrate/deduplicate — Pass 5: Deduplication
|
||||
GET /v1/canonical/migrate/status — Migration progress
|
||||
GET /v1/canonical/migrate/decomposition-status — Decomposition progress
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
|
||||
from database import SessionLocal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/v1/canonical", tags=["crosswalk"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# REQUEST / RESPONSE MODELS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PatternResponse(BaseModel):
    """Summary view of a control pattern as returned by the list endpoint.

    Mutable defaults are safe here: pydantic deep-copies field defaults
    per instance.
    """

    id: str
    name: str
    name_de: str
    domain: str
    category: str
    description: str
    objective_template: str
    severity_default: str
    implementation_effort_default: str = "m"
    tags: list = []
    composable_with: list = []
    open_anchor_refs: list = []
    # Number of canonical controls generated from this pattern (from DB).
    controls_count: int = 0


class PatternListResponse(BaseModel):
    """Envelope for GET /v1/canonical/patterns."""

    patterns: List[PatternResponse]
    total: int


class PatternDetailResponse(PatternResponse):
    """Full pattern view including templates and matching keywords."""

    rationale_template: str = ""
    requirements_template: list = []
    test_procedure_template: list = []
    evidence_template: list = []
    obligation_match_keywords: list = []
|
||||
|
||||
|
||||
class ObligationExtractRequest(BaseModel):
    """Input for POST /v1/canonical/obligations/extract."""

    text: str
    regulation_code: Optional[str] = None
    article: Optional[str] = None
    paragraph: Optional[str] = None


class ObligationExtractResponse(BaseModel):
    """Extraction result plus the best pattern match for the same text."""

    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    # Which extraction tier produced the result ("none" if nothing matched).
    method: str = "none"
    confidence: float = 0.0
    regulation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    pattern_confidence: float = 0.0


class CrosswalkRow(BaseModel):
    """One row of the crosswalk matrix (regulation → obligation → control)."""

    regulation_code: str = ""
    article: Optional[str] = None
    obligation_id: Optional[str] = None
    pattern_id: Optional[str] = None
    master_control_id: Optional[str] = None
    confidence: float = 0.0
    # "auto" for machine-generated rows; other values indicate manual curation.
    source: str = "auto"


class CrosswalkQueryResponse(BaseModel):
    """Envelope for GET /v1/canonical/crosswalk."""

    rows: List[CrosswalkRow]
    total: int
|
||||
|
||||
|
||||
class CrosswalkStatsResponse(BaseModel):
    """Coverage statistics over the crosswalk matrix."""

    total_rows: int = 0
    regulations_covered: int = 0
    obligations_linked: int = 0
    patterns_used: int = 0
    controls_linked: int = 0
    # regulation_code -> number of crosswalk rows for that regulation.
    coverage_by_regulation: dict = {}


class MigrationRequest(BaseModel):
    """Common input for the migration-pass endpoints."""

    limit: int = 0  # 0 = no limit


class MigrationResponse(BaseModel):
    """Common output for the migration-pass endpoints."""

    status: str = "completed"
    stats: dict = {}


class MigrationStatusResponse(BaseModel):
    """Overall migration progress (Passes 1-5)."""

    total_controls: int = 0
    has_obligation: int = 0
    has_pattern: int = 0
    fully_linked: int = 0
    deprecated: int = 0
    coverage_obligation_pct: float = 0.0
    coverage_pattern_pct: float = 0.0
    coverage_full_pct: float = 0.0


class DecompositionStatusResponse(BaseModel):
    """Decomposition progress (Pass 0a/0b)."""

    rich_controls: int = 0
    decomposed_controls: int = 0
    total_candidates: int = 0
    validated: int = 0
    rejected: int = 0
    composed: int = 0
    atomic_controls: int = 0
    decomposition_pct: float = 0.0
    composition_pct: float = 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PATTERN LIBRARY ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/patterns", response_model=PatternListResponse)
async def list_patterns(
    domain: Optional[str] = Query(None, description="Filter by domain (e.g. AUTH, CRYP)"),
    category: Optional[str] = Query(None, description="Filter by category"),
    tag: Optional[str] = Query(None, description="Filter by tag"),
):
    """List all control patterns with optional filters.

    Filters are case-insensitive; domains are matched upper-cased,
    categories and tags lower-cased.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    matcher._load_patterns()
    matcher._build_keyword_index()

    selected = matcher._patterns

    if domain:
        wanted_domain = domain.upper()
        selected = [p for p in selected if p.domain == wanted_domain]
    if category:
        wanted_category = category.lower()
        selected = [p for p in selected if p.category == wanted_category]
    if tag:
        wanted_tag = tag.lower()
        selected = [p for p in selected if wanted_tag in (t.lower() for t in p.tags)]

    # Per-pattern control counts come from the DB, not the YAML library.
    control_counts = _get_pattern_control_counts()

    response_patterns = [
        PatternResponse(
            id=p.id,
            name=p.name,
            name_de=p.name_de,
            domain=p.domain,
            category=p.category,
            description=p.description,
            objective_template=p.objective_template,
            severity_default=p.severity_default,
            implementation_effort_default=p.implementation_effort_default,
            tags=p.tags,
            composable_with=p.composable_with,
            open_anchor_refs=p.open_anchor_refs,
            controls_count=control_counts.get(p.id, 0),
        )
        for p in selected
    ]

    return PatternListResponse(patterns=response_patterns, total=len(response_patterns))
|
||||
|
||||
|
||||
@router.get("/patterns/{pattern_id}", response_model=PatternDetailResponse)
async def get_pattern(pattern_id: str):
    """Get a single control pattern by ID.

    Raises HTTP 404 when the pattern library has no such ID.
    """
    from compliance.services.pattern_matcher import PatternMatcher

    matcher = PatternMatcher()
    matcher._load_patterns()

    pat = matcher.get_pattern(pattern_id)
    if not pat:
        raise HTTPException(status_code=404, detail=f"Pattern {pattern_id} not found")

    control_counts = _get_pattern_control_counts()

    detail_fields = dict(
        id=pat.id,
        name=pat.name,
        name_de=pat.name_de,
        domain=pat.domain,
        category=pat.category,
        description=pat.description,
        objective_template=pat.objective_template,
        rationale_template=pat.rationale_template,
        requirements_template=pat.requirements_template,
        test_procedure_template=pat.test_procedure_template,
        evidence_template=pat.evidence_template,
        severity_default=pat.severity_default,
        implementation_effort_default=pat.implementation_effort_default,
        tags=pat.tags,
        composable_with=pat.composable_with,
        open_anchor_refs=pat.open_anchor_refs,
        obligation_match_keywords=pat.obligation_match_keywords,
        controls_count=control_counts.get(pat.id, 0),
    )
    return PatternDetailResponse(**detail_fields)
|
||||
|
||||
|
||||
@router.get("/patterns/{pattern_id}/controls")
async def get_pattern_controls(
    pattern_id: str,
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
):
    """Get controls generated from a specific pattern.

    Deprecated controls are excluded; results are paginated and ordered
    by control_id.
    """

    def _coerce_obligation_ids(raw):
        # obligation_ids may be stored as a JSON string or a native list.
        if isinstance(raw, str):
            try:
                return json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                return []
        return raw

    db = SessionLocal()
    try:
        result = db.execute(
            text("""
                SELECT id, control_id, title, objective, severity,
                       release_state, category, obligation_ids
                FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
                ORDER BY control_id
                LIMIT :limit OFFSET :offset
            """),
            {"pattern_id": pattern_id.upper(), "limit": limit, "offset": offset},
        )
        rows = result.fetchall()

        count_result = db.execute(
            text("""
                SELECT count(*) FROM canonical_controls
                WHERE pattern_id = :pattern_id
                  AND release_state NOT IN ('deprecated')
            """),
            {"pattern_id": pattern_id.upper()},
        )
        total = count_result.fetchone()[0]

        controls = [
            {
                "id": str(row[0]),
                "control_id": row[1],
                "title": row[2],
                "objective": row[3],
                "severity": row[4],
                "release_state": row[5],
                "category": row[6],
                "obligation_ids": _coerce_obligation_ids(row[7]) or [],
            }
            for row in rows
        ]

        return {"controls": controls, "total": total}
    finally:
        db.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OBLIGATION EXTRACTION ENDPOINT
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/obligations/extract", response_model=ObligationExtractResponse)
async def extract_obligation(req: ObligationExtractRequest):
    """Extract obligation from text using 3-tier strategy, then match to pattern."""
    from compliance.services.obligation_extractor import ObligationExtractor
    from compliance.services.pattern_matcher import PatternMatcher

    extractor = ObligationExtractor()
    await extractor.initialize()

    obligation = await extractor.extract(
        chunk_text=req.text,
        regulation_code=req.regulation_code or "",
        article=req.article,
        paragraph=req.paragraph,
    )

    # Run pattern matching on the extracted obligation text, falling back
    # to the title and finally the raw request text (truncated).
    pattern_matcher = PatternMatcher()
    pattern_matcher._load_patterns()
    pattern_matcher._build_keyword_index()

    match_input = obligation.obligation_text or obligation.obligation_title or req.text[:500]
    match = pattern_matcher._tier1_keyword(match_input, obligation.regulation_id)

    return ObligationExtractResponse(
        obligation_id=obligation.obligation_id,
        obligation_title=obligation.obligation_title,
        obligation_text=obligation.obligation_text,
        method=obligation.method,
        confidence=obligation.confidence,
        regulation_id=obligation.regulation_id,
        pattern_id=match.pattern_id if match else None,
        pattern_confidence=match.confidence if match else 0,
    )
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CROSSWALK MATRIX ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/crosswalk", response_model=CrosswalkQueryResponse)
async def query_crosswalk(
    regulation_code: Optional[str] = Query(None),
    article: Optional[str] = Query(None),
    obligation_id: Optional[str] = Query(None),
    pattern_id: Optional[str] = Query(None),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
):
    """Query the crosswalk matrix with filters.

    The WHERE clause is assembled from fixed literal fragments only; all
    user-supplied values go through bound parameters.
    """
    db = SessionLocal()
    try:
        params = {"limit": limit, "offset": offset}
        conditions = ["1=1"]

        # (SQL fragment, bind name, value) for each optional filter.
        optional_filters = (
            ("regulation_code = :reg", "reg", regulation_code),
            ("article = :art", "art", article),
            ("obligation_id = :obl", "obl", obligation_id),
            ("pattern_id = :pat", "pat", pattern_id),
        )
        for clause, bind, value in optional_filters:
            if value:
                conditions.append(clause)
                params[bind] = value

        where = " AND ".join(conditions)

        result = db.execute(
            text(f"""
                SELECT regulation_code, article, obligation_id,
                       pattern_id, master_control_id, confidence, source
                FROM crosswalk_matrix
                WHERE {where}
                ORDER BY regulation_code, article
                LIMIT :limit OFFSET :offset
            """),
            params,
        )
        rows = result.fetchall()

        count_result = db.execute(
            text(f"SELECT count(*) FROM crosswalk_matrix WHERE {where}"),
            params,
        )
        total = count_result.fetchone()[0]

        crosswalk_rows = []
        for r in rows:
            crosswalk_rows.append(
                CrosswalkRow(
                    regulation_code=r[0] or "",
                    article=r[1],
                    obligation_id=r[2],
                    pattern_id=r[3],
                    master_control_id=r[4],
                    confidence=float(r[5] or 0),
                    source=r[6] or "auto",
                )
            )

        return CrosswalkQueryResponse(rows=crosswalk_rows, total=total)
    finally:
        db.close()
|
||||
|
||||
|
||||
@router.get("/crosswalk/stats", response_model=CrosswalkStatsResponse)
async def crosswalk_stats():
    """Get crosswalk coverage statistics.

    One aggregate query for the global counters, one GROUP BY for the
    per-regulation breakdown.
    """
    db = SessionLocal()
    try:
        totals = db.execute(text("""
            SELECT
                count(*) AS total,
                count(DISTINCT regulation_code) FILTER (WHERE regulation_code != '') AS regs,
                count(DISTINCT obligation_id) FILTER (WHERE obligation_id IS NOT NULL) AS obls,
                count(DISTINCT pattern_id) FILTER (WHERE pattern_id IS NOT NULL) AS pats,
                count(DISTINCT master_control_id) FILTER (WHERE master_control_id IS NOT NULL) AS ctrls
            FROM crosswalk_matrix
        """)).fetchone()

        per_regulation = db.execute(text("""
            SELECT regulation_code, count(*) AS cnt
            FROM crosswalk_matrix
            WHERE regulation_code != ''
            GROUP BY regulation_code
            ORDER BY cnt DESC
        """)).fetchall()

        return CrosswalkStatsResponse(
            total_rows=totals[0],
            regulations_covered=totals[1],
            obligations_linked=totals[2],
            patterns_used=totals[3],
            controls_linked=totals[4],
            coverage_by_regulation={code: cnt for code, cnt in per_regulation},
        )
    finally:
        db.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MIGRATION ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/migrate/decompose", response_model=MigrationResponse)
async def migrate_decompose(req: MigrationRequest):
    """Pass 0a: Extract obligation candidates from rich controls.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0a(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain.
        logger.exception("Decomposition pass 0a failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/compose-atomic", response_model=MigrationResponse)
async def migrate_compose_atomic(req: MigrationRequest):
    """Pass 0b: Compose atomic controls from obligation candidates.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        stats = await decomp.run_pass0b(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Decomposition pass 0b failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
|
||||
|
||||
|
||||
@router.post("/migrate/link-obligations", response_model=MigrationResponse)
async def migrate_link_obligations(req: MigrationRequest):
    """Pass 1: Link controls to obligations via source_citation article.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass1_obligation_linkage(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain.
        logger.exception("Migration pass 1 failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/classify-patterns", response_model=MigrationResponse)
async def migrate_classify_patterns(req: MigrationRequest):
    """Pass 2: Classify controls into patterns via keyword matching.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        await migration.initialize()
        stats = await migration.run_pass2_pattern_classification(limit=req.limit)
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 2 failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
|
||||
|
||||
|
||||
@router.post("/migrate/triage", response_model=MigrationResponse)
async def migrate_triage():
    """Pass 3: Quality triage — categorize by linkage completeness.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass3_quality_triage()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain.
        logger.exception("Migration pass 3 failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.post("/migrate/backfill-crosswalk", response_model=MigrationResponse)
async def migrate_backfill_crosswalk():
    """Pass 4: Create crosswalk rows for linked controls.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass4_crosswalk_backfill()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        logger.exception("Migration pass 4 failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
|
||||
|
||||
|
||||
@router.post("/migrate/deduplicate", response_model=MigrationResponse)
async def migrate_deduplicate():
    """Pass 5: Mark duplicate controls (same obligation + pattern).

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        stats = migration.run_pass5_deduplication()
        return MigrationResponse(status="completed", stats=stats)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain.
        logger.exception("Migration pass 5 failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
|
||||
|
||||
|
||||
@router.get("/migrate/status", response_model=MigrationStatusResponse)
async def migration_status():
    """Get overall migration progress.

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.pipeline_adapter import MigrationPasses

    db = SessionLocal()
    try:
        migration = MigrationPasses(db=db)
        status = migration.migration_status()
        return MigrationStatusResponse(**status)
    except Exception as e:
        # logger.exception keeps the traceback; `from e` preserves the chain.
        logger.exception("Migration status failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()


@router.get("/migrate/decomposition-status", response_model=DecompositionStatusResponse)
async def decomposition_status():
    """Get decomposition progress (Pass 0a/0b).

    Failures are logged with their traceback and surfaced as HTTP 500.
    """
    from compliance.services.decomposition_pass import DecompositionPass

    db = SessionLocal()
    try:
        decomp = DecompositionPass(db=db)
        status = decomp.decomposition_status()
        return DecompositionStatusResponse(**status)
    except Exception as e:
        logger.exception("Decomposition status failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        db.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HELPERS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_pattern_control_counts() -> dict[str, int]:
    """Get count of controls per pattern_id from DB.

    Returns:
        Mapping of pattern_id -> number of non-deprecated canonical
        controls. Returns an empty dict on any DB error (best effort:
        callers render counts as 0 rather than failing the request).
    """
    db = SessionLocal()
    try:
        result = db.execute(text("""
            SELECT pattern_id, count(*) AS cnt
            FROM canonical_controls
            WHERE pattern_id IS NOT NULL AND pattern_id != ''
              AND release_state NOT IN ('deprecated')
            GROUP BY pattern_id
        """))
        return {row[0]: row[1] for row in result.fetchall()}
    except Exception:
        # Deliberate best-effort fallback, but don't swallow silently:
        # log the traceback so schema/connection problems stay visible.
        logger.warning("Failed to load pattern control counts", exc_info=True)
        return {}
    finally:
        db.close()
|
||||
546
backend-compliance/compliance/services/control_composer.py
Normal file
546
backend-compliance/compliance/services/control_composer.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Control Composer — Pattern + Obligation → Master Control.
|
||||
|
||||
Takes an obligation (from ObligationExtractor) and a matched control pattern
|
||||
(from PatternMatcher), then uses LLM to compose a structured, actionable
|
||||
Master Control. Replaces the old Stage 3 (STRUCTURE/REFORM) with a
|
||||
pattern-guided approach.
|
||||
|
||||
Three composition modes based on license rules:
|
||||
Rule 1: Obligation + Pattern + original text → full control
|
||||
Rule 2: Obligation + Pattern + original text + citation → control
|
||||
Rule 3: Obligation + Pattern (NO original text) → reformulated control
|
||||
|
||||
Fallback: No pattern match → basic generation (tagged needs_pattern_assignment)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 6 of 8).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from compliance.services.obligation_extractor import (
|
||||
ObligationMatch,
|
||||
_llm_ollama,
|
||||
_parse_json,
|
||||
)
|
||||
from compliance.services.pattern_matcher import (
|
||||
ControlPattern,
|
||||
PatternMatchResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
|
||||
|
||||
# Valid values for generated control fields
|
||||
VALID_SEVERITIES = {"low", "medium", "high", "critical"}
|
||||
VALID_EFFORTS = {"s", "m", "l", "xl"}
|
||||
VALID_VERIFICATION = {"code_review", "document", "tool", "hybrid"}
|
||||
|
||||
|
||||
@dataclass
class ComposedControl:
    """A Master Control composed from an obligation + pattern."""

    # Core fields (match canonical_controls schema)
    control_id: str = ""
    title: str = ""
    objective: str = ""
    rationale: str = ""
    scope: dict = field(default_factory=dict)
    requirements: list = field(default_factory=list)
    test_procedure: list = field(default_factory=list)
    evidence: list = field(default_factory=list)
    severity: str = "medium"
    risk_score: float = 5.0
    implementation_effort: str = "m"
    open_anchors: list = field(default_factory=list)
    release_state: str = "draft"
    tags: list = field(default_factory=list)
    # 3-Rule License fields
    license_rule: Optional[int] = None
    source_original_text: Optional[str] = None
    source_citation: Optional[dict] = None
    customer_visible: bool = True
    # Classification
    verification_method: Optional[str] = None
    category: Optional[str] = None
    target_audience: Optional[list] = None
    # Pattern + Obligation linkage
    pattern_id: Optional[str] = None
    obligation_ids: list = field(default_factory=list)
    # Metadata
    generation_metadata: dict = field(default_factory=dict)
    composition_method: str = "pattern_guided"  # pattern_guided | fallback

    # Field names serialized by to_dict(), in output order. Unannotated,
    # so the dataclass machinery ignores it (not a field).
    _EXPORT_FIELDS = (
        "control_id",
        "title",
        "objective",
        "rationale",
        "scope",
        "requirements",
        "test_procedure",
        "evidence",
        "severity",
        "risk_score",
        "implementation_effort",
        "open_anchors",
        "release_state",
        "tags",
        "license_rule",
        "source_original_text",
        "source_citation",
        "customer_visible",
        "verification_method",
        "category",
        "target_audience",
        "pattern_id",
        "obligation_ids",
        "generation_metadata",
        "composition_method",
    )

    def to_dict(self) -> dict:
        """Serialize for DB storage or API response."""
        return {name: getattr(self, name) for name in self._EXPORT_FIELDS}
|
||||
|
||||
|
||||
class ControlComposer:
|
||||
"""Composes Master Controls from obligations + patterns.
|
||||
|
||||
Usage::
|
||||
|
||||
composer = ControlComposer()
|
||||
|
||||
control = await composer.compose(
|
||||
obligation=obligation_match,
|
||||
pattern_result=pattern_match_result,
|
||||
chunk_text="...",
|
||||
license_rule=1,
|
||||
source_citation={...},
|
||||
)
|
||||
"""
|
||||
|
||||
async def compose(
|
||||
self,
|
||||
obligation: ObligationMatch,
|
||||
pattern_result: PatternMatchResult,
|
||||
chunk_text: Optional[str] = None,
|
||||
license_rule: int = 3,
|
||||
source_citation: Optional[dict] = None,
|
||||
regulation_code: Optional[str] = None,
|
||||
) -> ComposedControl:
|
||||
"""Compose a Master Control from obligation + pattern.
|
||||
|
||||
Args:
|
||||
obligation: The extracted obligation (from ObligationExtractor).
|
||||
pattern_result: The matched pattern (from PatternMatcher).
|
||||
chunk_text: Original RAG chunk text (only used for Rules 1-2).
|
||||
license_rule: 1=free, 2=citation, 3=restricted.
|
||||
source_citation: Citation metadata for Rule 2.
|
||||
regulation_code: Source regulation code.
|
||||
|
||||
Returns:
|
||||
ComposedControl ready for storage.
|
||||
"""
|
||||
pattern = pattern_result.pattern if pattern_result else None
|
||||
|
||||
if pattern:
|
||||
control = await self._compose_with_pattern(
|
||||
obligation, pattern, chunk_text, license_rule, source_citation,
|
||||
)
|
||||
else:
|
||||
control = await self._compose_fallback(
|
||||
obligation, chunk_text, license_rule, source_citation,
|
||||
)
|
||||
|
||||
# Set linkage fields
|
||||
control.pattern_id = pattern.id if pattern else None
|
||||
if obligation.obligation_id:
|
||||
control.obligation_ids = [obligation.obligation_id]
|
||||
|
||||
# Set license fields
|
||||
control.license_rule = license_rule
|
||||
if license_rule in (1, 2) and chunk_text:
|
||||
control.source_original_text = chunk_text
|
||||
if license_rule == 2 and source_citation:
|
||||
control.source_citation = source_citation
|
||||
if license_rule == 3:
|
||||
control.customer_visible = False
|
||||
control.source_original_text = None
|
||||
control.source_citation = None
|
||||
|
||||
# Build metadata
|
||||
control.generation_metadata = {
|
||||
"composition_method": control.composition_method,
|
||||
"pattern_id": control.pattern_id,
|
||||
"pattern_confidence": round(pattern_result.confidence, 3) if pattern_result else 0,
|
||||
"pattern_method": pattern_result.method if pattern_result else "none",
|
||||
"obligation_id": obligation.obligation_id,
|
||||
"obligation_method": obligation.method,
|
||||
"obligation_confidence": round(obligation.confidence, 3),
|
||||
"license_rule": license_rule,
|
||||
"regulation_code": regulation_code,
|
||||
}
|
||||
|
||||
# Validate and fix fields
|
||||
_validate_control(control)
|
||||
|
||||
return control
|
||||
|
||||
async def compose_batch(
|
||||
self,
|
||||
items: list[dict],
|
||||
) -> list[ComposedControl]:
|
||||
"""Compose multiple controls.
|
||||
|
||||
Args:
|
||||
items: List of dicts with keys: obligation, pattern_result,
|
||||
chunk_text, license_rule, source_citation, regulation_code.
|
||||
|
||||
Returns:
|
||||
List of ComposedControl instances.
|
||||
"""
|
||||
results = []
|
||||
for item in items:
|
||||
control = await self.compose(
|
||||
obligation=item["obligation"],
|
||||
pattern_result=item.get("pattern_result", PatternMatchResult()),
|
||||
chunk_text=item.get("chunk_text"),
|
||||
license_rule=item.get("license_rule", 3),
|
||||
source_citation=item.get("source_citation"),
|
||||
regulation_code=item.get("regulation_code"),
|
||||
)
|
||||
results.append(control)
|
||||
return results
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Pattern-guided composition
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _compose_with_pattern(
|
||||
self,
|
||||
obligation: ObligationMatch,
|
||||
pattern: ControlPattern,
|
||||
chunk_text: Optional[str],
|
||||
license_rule: int,
|
||||
source_citation: Optional[dict],
|
||||
) -> ComposedControl:
|
||||
"""Use LLM to fill the pattern template with obligation-specific details."""
|
||||
prompt = _build_compose_prompt(obligation, pattern, chunk_text, license_rule)
|
||||
system_prompt = _compose_system_prompt(license_rule)
|
||||
|
||||
llm_result = await _llm_ollama(prompt, system_prompt)
|
||||
if not llm_result:
|
||||
return self._compose_from_template(obligation, pattern)
|
||||
|
||||
parsed = _parse_json(llm_result)
|
||||
if not parsed:
|
||||
return self._compose_from_template(obligation, pattern)
|
||||
|
||||
control = ComposedControl(
|
||||
title=parsed.get("title", pattern.name_de)[:255],
|
||||
objective=parsed.get("objective", pattern.objective_template),
|
||||
rationale=parsed.get("rationale", pattern.rationale_template),
|
||||
requirements=_ensure_list(parsed.get("requirements", pattern.requirements_template)),
|
||||
test_procedure=_ensure_list(parsed.get("test_procedure", pattern.test_procedure_template)),
|
||||
evidence=_ensure_list(parsed.get("evidence", pattern.evidence_template)),
|
||||
severity=parsed.get("severity", pattern.severity_default),
|
||||
implementation_effort=parsed.get("implementation_effort", pattern.implementation_effort_default),
|
||||
category=parsed.get("category", pattern.category),
|
||||
tags=_ensure_list(parsed.get("tags", pattern.tags)),
|
||||
target_audience=_ensure_list(parsed.get("target_audience", [])),
|
||||
verification_method=parsed.get("verification_method"),
|
||||
open_anchors=_anchors_from_pattern(pattern),
|
||||
composition_method="pattern_guided",
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
def _compose_from_template(
|
||||
self,
|
||||
obligation: ObligationMatch,
|
||||
pattern: ControlPattern,
|
||||
) -> ComposedControl:
|
||||
"""Fallback: fill template directly without LLM (when LLM fails)."""
|
||||
obl_title = obligation.obligation_title or ""
|
||||
obl_text = obligation.obligation_text or ""
|
||||
|
||||
title = f"{pattern.name_de}"
|
||||
if obl_title:
|
||||
title = f"{pattern.name_de} — {obl_title}"
|
||||
|
||||
objective = pattern.objective_template
|
||||
if obl_text and len(obl_text) > 20:
|
||||
objective = f"{pattern.objective_template} Bezug: {obl_text[:200]}"
|
||||
|
||||
return ComposedControl(
|
||||
title=title[:255],
|
||||
objective=objective,
|
||||
rationale=pattern.rationale_template,
|
||||
requirements=list(pattern.requirements_template),
|
||||
test_procedure=list(pattern.test_procedure_template),
|
||||
evidence=list(pattern.evidence_template),
|
||||
severity=pattern.severity_default,
|
||||
implementation_effort=pattern.implementation_effort_default,
|
||||
category=pattern.category,
|
||||
tags=list(pattern.tags),
|
||||
open_anchors=_anchors_from_pattern(pattern),
|
||||
composition_method="template_only",
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Fallback (no pattern)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def _compose_fallback(
|
||||
self,
|
||||
obligation: ObligationMatch,
|
||||
chunk_text: Optional[str],
|
||||
license_rule: int,
|
||||
source_citation: Optional[dict],
|
||||
) -> ComposedControl:
|
||||
"""Generate a control without a pattern template (old-style)."""
|
||||
prompt = _build_fallback_prompt(obligation, chunk_text, license_rule)
|
||||
system_prompt = _compose_system_prompt(license_rule)
|
||||
|
||||
llm_result = await _llm_ollama(prompt, system_prompt)
|
||||
parsed = _parse_json(llm_result) if llm_result else {}
|
||||
|
||||
obl_text = obligation.obligation_text or ""
|
||||
|
||||
control = ComposedControl(
|
||||
title=parsed.get("title", obl_text[:100] if obl_text else "Untitled Control")[:255],
|
||||
objective=parsed.get("objective", obl_text[:500]),
|
||||
rationale=parsed.get("rationale", "Aus gesetzlicher Pflicht abgeleitet."),
|
||||
requirements=_ensure_list(parsed.get("requirements", [])),
|
||||
test_procedure=_ensure_list(parsed.get("test_procedure", [])),
|
||||
evidence=_ensure_list(parsed.get("evidence", [])),
|
||||
severity=parsed.get("severity", "medium"),
|
||||
implementation_effort=parsed.get("implementation_effort", "m"),
|
||||
category=parsed.get("category"),
|
||||
tags=_ensure_list(parsed.get("tags", [])),
|
||||
target_audience=_ensure_list(parsed.get("target_audience", [])),
|
||||
verification_method=parsed.get("verification_method"),
|
||||
composition_method="fallback",
|
||||
release_state="needs_review",
|
||||
)
|
||||
|
||||
return control
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compose_system_prompt(license_rule: int) -> str:
|
||||
"""Build the system prompt based on license rule."""
|
||||
if license_rule == 3:
|
||||
return (
|
||||
"Du bist ein Security-Compliance-Experte. Deine Aufgabe ist es, "
|
||||
"eigenstaendige Security Controls zu formulieren. "
|
||||
"Du formulierst IMMER in eigenen Worten. "
|
||||
"KOPIERE KEINE Saetze aus dem Quelltext. "
|
||||
"Verwende eigene Begriffe und Struktur. "
|
||||
"NENNE NICHT die Quelle. Keine proprietaeren Bezeichner. "
|
||||
"Antworte NUR mit validem JSON."
|
||||
)
|
||||
return (
|
||||
"Du bist ein Security-Compliance-Experte. "
|
||||
"Erstelle ein praxisorientiertes, umsetzbares Security Control. "
|
||||
"Antworte NUR mit validem JSON."
|
||||
)
|
||||
|
||||
|
||||
def _build_compose_prompt(
    obligation: ObligationMatch,
    pattern: ControlPattern,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Build the LLM user prompt for pattern-guided composition.

    Args:
        obligation: The matched obligation providing legal context.
        pattern: Control pattern whose templates guide the output.
        chunk_text: Original source text; only included for license rules 1-2.
        license_rule: 1=free, 2=citation, 3=restricted.

    Returns:
        The complete German-language prompt, ending with the JSON schema
        the model must fill (``category`` is pre-set from the pattern).
    """
    obl_section = _obligation_section(obligation)
    pattern_section = _pattern_section(pattern)

    # Rule 3: the source text must never be exposed to the model.
    if license_rule == 3:
        context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    elif chunk_text:
        # Cap the context to keep the prompt within model limits.
        context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"
    else:
        context_section = "KONTEXT: Kein Originaltext verfuegbar."

    return f"""Erstelle ein PRAXISORIENTIERTES Security Control.

{obl_section}

{pattern_section}

{context_section}

AUFGABE:
Fuelle das Muster mit pflicht-spezifischen Details.
Das Ergebnis muss UMSETZBAR sein — keine Gesetzesparaphrase.
Formuliere konkret und handlungsorientiert.

Antworte als JSON:
{{
  "title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
  "objective": "Was soll erreicht werden? (1-3 Saetze)",
  "rationale": "Warum ist das wichtig? (1-2 Saetze)",
  "requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
  "test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
  "evidence": ["Nachweis 1", "Nachweis 2", ...],
  "severity": "low|medium|high|critical",
  "implementation_effort": "s|m|l|xl",
  "category": "{pattern.category}",
  "tags": ["tag1", "tag2"],
  "target_audience": ["unternehmen", "behoerden", "entwickler"],
  "verification_method": "code_review|document|tool|hybrid"
}}"""
|
||||
|
||||
|
||||
def _build_fallback_prompt(
    obligation: ObligationMatch,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Build the LLM prompt for fallback composition (no pattern).

    Same context rules as the pattern-guided prompt: license rule 3
    suppresses the source text entirely, rules 1-2 include up to
    2000 characters of it.

    Returns:
        The complete German-language prompt including the JSON schema.
    """
    obl_section = _obligation_section(obligation)

    # Rule 3: the source text must never be exposed to the model.
    if license_rule == 3:
        context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    elif chunk_text:
        context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"
    else:
        context_section = "KONTEXT: Kein Originaltext verfuegbar."

    return f"""Erstelle ein Security Control aus der folgenden Pflicht.

{obl_section}

{context_section}

AUFGABE:
Formuliere ein umsetzbares Security Control.
Keine Gesetzesparaphrase — konkrete Massnahmen beschreiben.

Antworte als JSON:
{{
  "title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
  "objective": "Was soll erreicht werden? (1-3 Saetze)",
  "rationale": "Warum ist das wichtig? (1-2 Saetze)",
  "requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
  "test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
  "evidence": ["Nachweis 1", "Nachweis 2", ...],
  "severity": "low|medium|high|critical",
  "implementation_effort": "s|m|l|xl",
  "category": "one of: authentication, encryption, data_protection, etc.",
  "tags": ["tag1", "tag2"],
  "target_audience": ["unternehmen"],
  "verification_method": "code_review|document|tool|hybrid"
}}"""
|
||||
|
||||
|
||||
def _obligation_section(obligation: ObligationMatch) -> str:
|
||||
"""Format the obligation for the prompt."""
|
||||
parts = ["PFLICHT (was das Gesetz verlangt):"]
|
||||
if obligation.obligation_title:
|
||||
parts.append(f" Titel: {obligation.obligation_title}")
|
||||
if obligation.obligation_text:
|
||||
parts.append(f" Beschreibung: {obligation.obligation_text[:500]}")
|
||||
if obligation.obligation_id:
|
||||
parts.append(f" ID: {obligation.obligation_id}")
|
||||
if obligation.regulation_id:
|
||||
parts.append(f" Rechtsgrundlage: {obligation.regulation_id}")
|
||||
if not obligation.obligation_text and not obligation.obligation_title:
|
||||
parts.append(" (Keine spezifische Pflicht extrahiert)")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _pattern_section(pattern: ControlPattern) -> str:
|
||||
"""Format the pattern for the prompt."""
|
||||
reqs = "\n ".join(f"- {r}" for r in pattern.requirements_template[:5])
|
||||
tests = "\n ".join(f"- {t}" for t in pattern.test_procedure_template[:3])
|
||||
return f"""MUSTER (wie man es typischerweise umsetzt):
|
||||
Pattern: {pattern.name_de} ({pattern.id})
|
||||
Domain: {pattern.domain}
|
||||
Ziel-Template: {pattern.objective_template}
|
||||
Anforderungs-Template:
|
||||
{reqs}
|
||||
Pruefverfahren-Template:
|
||||
{tests}"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_list(value) -> list:
|
||||
"""Ensure a value is a list of strings."""
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value if v]
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
return []
|
||||
|
||||
|
||||
def _anchors_from_pattern(pattern: ControlPattern) -> list:
|
||||
"""Convert pattern's open_anchor_refs to control anchor format."""
|
||||
anchors = []
|
||||
for ref in pattern.open_anchor_refs:
|
||||
anchors.append({
|
||||
"framework": ref.get("framework", ""),
|
||||
"control_id": ref.get("ref", ""),
|
||||
"title": "",
|
||||
"alignment_score": 0.8,
|
||||
})
|
||||
return anchors
|
||||
|
||||
|
||||
def _validate_control(control: ComposedControl) -> None:
    """Normalize a composed control in place so it satisfies field invariants.

    Invalid enum values fall back to safe defaults, an out-of-range risk
    score is recomputed from severity, an over-long title is truncated,
    and empty narrative/list fields receive minimal placeholder content.
    """
    # Enum-like fields: replace unknown values with safe defaults.
    if control.severity not in VALID_SEVERITIES:
        control.severity = "medium"
    if control.implementation_effort not in VALID_EFFORTS:
        control.implementation_effort = "m"
    if control.verification_method and control.verification_method not in VALID_VERIFICATION:
        # Unknown verification methods are dropped rather than guessed.
        control.verification_method = None

    # Risk score must lie in [0, 10]; otherwise derive it from severity.
    if not (0 <= control.risk_score <= 10):
        control.risk_score = _severity_to_risk(control.severity)

    # Hard cap on title length (DB column is 255 chars).
    if len(control.title) > 255:
        control.title = control.title[:252] + "..."

    # Minimum content: empty fields get placeholders so downstream
    # consumers never see blanks.
    control.objective = control.objective or control.title
    control.rationale = control.rationale or "Aus regulatorischer Anforderung abgeleitet."
    control.requirements = control.requirements or ["Anforderung gemaess Pflichtbeschreibung umsetzen"]
    control.test_procedure = control.test_procedure or ["Umsetzung der Anforderungen pruefen"]
    control.evidence = control.evidence or ["Dokumentation der Umsetzung"]
|
||||
|
||||
|
||||
def _severity_to_risk(severity: str) -> float:
|
||||
"""Map severity to a default risk score."""
|
||||
return {
|
||||
"critical": 9.0,
|
||||
"high": 7.0,
|
||||
"medium": 5.0,
|
||||
"low": 3.0,
|
||||
}.get(severity, 5.0)
|
||||
854
backend-compliance/compliance/services/decomposition_pass.py
Normal file
854
backend-compliance/compliance/services/decomposition_pass.py
Normal file
@@ -0,0 +1,854 @@
|
||||
"""Decomposition Pass — Split Rich Controls into Atomic Controls.
|
||||
|
||||
Pass 0 of the Multi-Layer Control Architecture migration. Runs BEFORE
|
||||
Passes 1-5 (obligation linkage, pattern classification, etc.).
|
||||
|
||||
Two sub-passes:
|
||||
Pass 0a: Obligation Extraction — extract individual normative obligations
|
||||
from a Rich Control using LLM with strict guardrails.
|
||||
Pass 0b: Atomic Control Composition — turn each obligation candidate
|
||||
into a standalone atomic control record.
|
||||
|
||||
Plus a Quality Gate that validates extraction results.
|
||||
|
||||
Guardrails (the 6 rules):
|
||||
1. Only normative statements (müssen, sicherzustellen, verpflichtet, ...)
|
||||
2. One main verb per obligation
|
||||
3. Test obligations separate from operational obligations
|
||||
4. Reporting obligations separate
|
||||
5. Don't split at evidence level
|
||||
6. Parent link always preserved
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Normative signal detection (Rule 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# German/English markers of a binding legal duty (Guardrail Rule 1).
# The "\w+en" tail approximates German infinitives in "ist zu ...en"
# constructions (e.g. "ist zu dokumentieren").
_NORMATIVE_SIGNALS = [
    r"\bmüssen\b", r"\bmuss\b", r"\bhat\s+sicherzustellen\b",
    r"\bhaben\s+sicherzustellen\b", r"\bsind\s+verpflichtet\b",
    r"\bist\s+verpflichtet\b", r"\bist\s+zu\s+\w+en\b",
    r"\bsind\s+zu\s+\w+en\b", r"\bhat\s+zu\s+\w+en\b",
    r"\bhaben\s+zu\s+\w+en\b", r"\bsoll\b", r"\bsollen\b",
    r"\bgewährleisten\b", r"\bsicherstellen\b",
    r"\bshall\b", r"\bmust\b", r"\brequired\b",
    r"\bshould\b", r"\bensure\b",
]
_NORMATIVE_RE = re.compile("|".join(_NORMATIVE_SIGNALS), re.IGNORECASE)

# Markers of justification/recital language — such text is NOT an
# obligation and must not be extracted (Guardrail Rule 1, inverse).
_RATIONALE_SIGNALS = [
    r"\bda\s+", r"\bweil\b", r"\bgrund\b", r"\berwägung",
    r"\bbecause\b", r"\breason\b", r"\brationale\b",
    r"\bkönnen\s+.*\s+verursachen\b", r"\bführt\s+zu\b",
]
_RATIONALE_RE = re.compile("|".join(_RATIONALE_SIGNALS), re.IGNORECASE)

# Markers of test/audit duties — flagged so they are kept separate from
# operational duties (Guardrail Rule 3).
_TEST_SIGNALS = [
    r"\btesten\b", r"\btest\b", r"\bprüfung\b", r"\bprüfen\b",
    r"\bgetestet\b", r"\bwirksamkeit\b", r"\baudit\b",
    r"\bregelmäßig\b.*\b(prüf|test|kontroll)",
    r"\beffectiveness\b", r"\bverif",
]
_TEST_RE = re.compile("|".join(_TEST_SIGNALS), re.IGNORECASE)

# Markers of reporting/notification duties — also kept separate
# (Guardrail Rule 4).
_REPORTING_SIGNALS = [
    r"\bmelden\b", r"\bmeldung\b", r"\bunterricht",
    r"\binformieren\b", r"\bbenachricht", r"\bnotif",
    r"\breport\b", r"\bbehörd",
]
_REPORTING_RE = re.compile("|".join(_REPORTING_SIGNALS), re.IGNORECASE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class ObligationCandidate:
    """A single normative obligation extracted from a Rich Control.

    Produced by Pass 0a and persisted to the ``obligation_candidates``
    table; Pass 0b later composes validated candidates into atomic
    controls.
    """

    candidate_id: str = ""  # human-readable id, e.g. "OC-<control_id>-01"
    parent_control_uuid: str = ""  # link back to the source control (Rule 6)
    obligation_text: str = ""  # the normative statement itself
    action: str = ""  # the single main verb/action (Rule 2)
    object_: str = ""  # subject matter; trailing underscore avoids builtin clash
    condition: Optional[str] = None  # trigger/precondition, if any
    normative_strength: str = "must"  # modality reported by the extractor
    is_test_obligation: bool = False  # Rule 3: test duties flagged separately
    is_reporting_obligation: bool = False  # Rule 4: reporting duties flagged separately
    extraction_confidence: float = 0.0  # derived from the quality flags
    quality_flags: dict = field(default_factory=dict)  # output of quality_gate()
    release_state: str = "extracted"  # lifecycle: extracted -> validated | rejected

    def to_dict(self) -> dict:
        """Serialize for persistence/APIs; ``object_`` is exposed as ``object``."""
        return {
            "candidate_id": self.candidate_id,
            "parent_control_uuid": self.parent_control_uuid,
            "obligation_text": self.obligation_text,
            "action": self.action,
            "object": self.object_,
            "condition": self.condition,
            "normative_strength": self.normative_strength,
            "is_test_obligation": self.is_test_obligation,
            "is_reporting_obligation": self.is_reporting_obligation,
            "extraction_confidence": self.extraction_confidence,
            "quality_flags": self.quality_flags,
            "release_state": self.release_state,
        }
|
||||
|
||||
|
||||
@dataclass
class AtomicControlCandidate:
    """An atomic control composed from a single ObligationCandidate.

    Produced by Pass 0b; keeps both the parent Rich Control link (Rule 6)
    and the originating obligation candidate id for full traceability.
    """

    candidate_id: str = ""  # id of this atomic control candidate
    parent_control_uuid: str = ""  # source Rich Control (Rule 6)
    obligation_candidate_id: str = ""  # the obligation this was composed from
    title: str = ""
    objective: str = ""
    requirements: list = field(default_factory=list)
    test_procedure: list = field(default_factory=list)
    evidence: list = field(default_factory=list)
    severity: str = "medium"
    category: str = ""
    domain: str = ""
    source_regulation: str = ""  # regulation code of the legal source
    source_article: str = ""  # article/section within that regulation

    def to_dict(self) -> dict:
        """Serialize the candidate for persistence/APIs.

        NOTE(review): source_regulation and source_article are not included
        in this dict — confirm that is intentional before relying on it.
        """
        return {
            "candidate_id": self.candidate_id,
            "parent_control_uuid": self.parent_control_uuid,
            "obligation_candidate_id": self.obligation_candidate_id,
            "title": self.title,
            "objective": self.objective,
            "requirements": self.requirements,
            "test_procedure": self.test_procedure,
            "evidence": self.evidence,
            "severity": self.severity,
            "category": self.category,
            "domain": self.domain,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quality Gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def quality_gate(candidate: ObligationCandidate) -> dict:
    """Run the six decomposition guardrail checks on *candidate*.

    Returns a dict of boolean quality flags:
        has_normative_signal: text contains normative language (Rule 1)
        single_action: only one main action, heuristic (Rule 2)
        not_rationale: not dominated by justification language
        not_evidence_only: not a bare evidence requirement (Rule 5)
        min_length: at least 20 meaningful characters
        has_parent_link: parent control reference present (Rule 6)
    """
    txt = candidate.obligation_text

    # Rule 2 heuristic: a conjunction followed by a second obligation verb
    # suggests multiple actions were packed into one candidate.
    multi_verb_re = re.compile(
        r"\b(und|sowie|als auch)\b.*\b(müssen|sicherstellen|implementieren"
        r"|dokumentieren|melden|testen|prüfen|überwachen|gewährleisten)\b",
        re.IGNORECASE,
    )
    # Rule 5 heuristic: evidence fragments typically open with a noun
    # such as "Nachweis" or "Dokumentation".
    evidence_only_re = re.compile(
        r"^(Nachweis|Dokumentation|Screenshot|Protokoll|Bericht|Zertifikat)",
        re.IGNORECASE,
    )

    return {
        "has_normative_signal": bool(_NORMATIVE_RE.search(txt)),
        "single_action": not multi_verb_re.search(txt),
        # More normative than rationale hits means it reads like a duty.
        "not_rationale": len(_NORMATIVE_RE.findall(txt)) >= len(_RATIONALE_RE.findall(txt)),
        "not_evidence_only": evidence_only_re.match(txt.strip()) is None,
        "min_length": len(txt.strip()) >= 20,
        "has_parent_link": bool(candidate.parent_control_uuid),
    }
|
||||
|
||||
|
||||
def passes_quality_gate(flags: dict) -> bool:
    """Return True when every critical guardrail flag is set.

    Non-critical flags (``single_action``, ``not_rationale``) do not block
    a candidate; missing keys count as failures.
    """
    for key in ("has_normative_signal", "not_evidence_only", "min_length", "has_parent_link"):
        if not flags.get(key, False):
            return False
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM Prompts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_PASS0A_SYSTEM_PROMPT = """\
|
||||
Du bist ein Rechts-Compliance-Experte. Du zerlegst Compliance-Controls \
|
||||
in einzelne atomare Pflichten.
|
||||
|
||||
REGELN (STRIKT EINHALTEN):
|
||||
1. Nur normative Aussagen extrahieren — erkennbar an: müssen, haben \
|
||||
sicherzustellen, sind verpflichtet, ist zu dokumentieren, ist zu melden, \
|
||||
ist zu testen, shall, must, required.
|
||||
2. Jede Pflicht hat genau EIN Hauptverb / eine Handlung.
|
||||
3. Testpflichten SEPARAT von operativen Pflichten (is_test_obligation=true).
|
||||
4. Meldepflichten SEPARAT (is_reporting_obligation=true).
|
||||
5. NICHT auf Evidence-Ebene zerlegen (z.B. "DR-Plan vorhanden" ist KEIN \
|
||||
eigenes Control, sondern Evidence).
|
||||
6. Begründungen, Erläuterungen und Erwägungsgründe sind KEINE Pflichten \
|
||||
— NICHT extrahieren.
|
||||
|
||||
Antworte NUR mit einem JSON-Array. Keine Erklärungen."""
|
||||
|
||||
|
||||
def _build_pass0a_prompt(
    title: str, objective: str, requirements: str,
    test_procedure: str, source_ref: str
) -> str:
    """Build the Pass 0a user prompt asking the LLM to list obligations.

    All arguments are pre-formatted strings describing one Rich Control.
    The expected response is a JSON array of obligation objects matching
    the schema embedded in the prompt.
    """
    return f"""\
Analysiere das folgende Control und extrahiere alle einzelnen normativen \
Pflichten als JSON-Array.

CONTROL:
Titel: {title}
Ziel: {objective}
Anforderungen: {requirements}
Prüfverfahren: {test_procedure}
Quellreferenz: {source_ref}

Antworte als JSON-Array:
[
  {{
    "obligation_text": "Kurze, präzise Formulierung der Pflicht",
    "action": "Hauptverb/Handlung",
    "object": "Gegenstand der Pflicht",
    "condition": "Auslöser/Bedingung oder null",
    "normative_strength": "must",
    "is_test_obligation": false,
    "is_reporting_obligation": false
  }}
]"""
|
||||
|
||||
|
||||
_PASS0B_SYSTEM_PROMPT = """\
|
||||
Du bist ein Security-Compliance-Experte. Du erstellst aus einer einzelnen \
|
||||
normativen Pflicht ein praxisorientiertes, atomares Security Control.
|
||||
|
||||
Das Control muss UMSETZBAR sein — keine Gesetzesparaphrase.
|
||||
Antworte NUR als JSON. Keine Erklärungen."""
|
||||
|
||||
|
||||
def _build_pass0b_prompt(
    obligation_text: str, action: str, object_: str,
    parent_title: str, parent_category: str, source_ref: str,
) -> str:
    """Build the Pass 0b user prompt that turns one obligation into an
    atomic control.

    The parent-control fields provide context only; the expected response
    is a single JSON object matching the schema embedded in the prompt.
    """
    return f"""\
Erstelle aus der folgenden Pflicht ein atomares Control.

PFLICHT: {obligation_text}
HANDLUNG: {action}
GEGENSTAND: {object_}

KONTEXT (Ursprungs-Control):
Titel: {parent_title}
Kategorie: {parent_category}
Quellreferenz: {source_ref}

Antworte als JSON:
{{
  "title": "Kurzer Titel (max 80 Zeichen, deutsch)",
  "objective": "Was muss erreicht werden? (1-2 Sätze)",
  "requirements": ["Konkrete Anforderung 1", "Anforderung 2"],
  "test_procedure": ["Prüfschritt 1", "Prüfschritt 2"],
  "evidence": ["Nachweis 1", "Nachweis 2"],
  "severity": "critical|high|medium|low",
  "category": "security|privacy|governance|operations|finance|reporting"
}}"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_json_array(text: str) -> list[dict]:
|
||||
"""Extract a JSON array from LLM response text."""
|
||||
# Try direct parse
|
||||
try:
|
||||
result = json.loads(text)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
if isinstance(result, dict):
|
||||
return [result]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting JSON array block
|
||||
match = re.search(r"\[[\s\S]*\]", text)
|
||||
if match:
|
||||
try:
|
||||
result = json.loads(match.group())
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _parse_json_object(text: str) -> dict:
|
||||
"""Extract a JSON object from LLM response text."""
|
||||
try:
|
||||
result = json.loads(text)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{[\s\S]*\}", text)
|
||||
if match:
|
||||
try:
|
||||
result = json.loads(match.group())
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _ensure_list(val) -> list:
|
||||
"""Ensure value is a list."""
|
||||
if isinstance(val, list):
|
||||
return val
|
||||
if isinstance(val, str):
|
||||
return [val] if val else []
|
||||
return []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decomposition Pass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DecompositionPass:
|
||||
"""Pass 0: Decompose Rich Controls into atomic candidates.
|
||||
|
||||
Usage::
|
||||
|
||||
decomp = DecompositionPass(db=session)
|
||||
stats_0a = await decomp.run_pass0a(limit=100)
|
||||
stats_0b = await decomp.run_pass0b(limit=100)
|
||||
"""
|
||||
|
||||
    def __init__(self, db: Session):
        """Store the SQLAlchemy session used for all reads and writes.

        The caller owns the session lifecycle; the passes commit on it
        but never close it.
        """
        self.db = db
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 0a: Obligation Extraction
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
    async def run_pass0a(self, limit: int = 0) -> dict:
        """Extract obligation candidates from rich controls (Pass 0a).

        Processes controls that have NOT been decomposed yet
        (no rows in obligation_candidates for that control).

        Args:
            limit: Maximum number of controls to process; 0 means all.

        Returns:
            Stats dict: controls_processed, obligations_extracted,
            obligations_validated, obligations_rejected,
            controls_skipped_empty, errors.
        """
        # Local import avoids a circular dependency with the extractor module.
        from compliance.services.obligation_extractor import _llm_ollama

        # Find rich controls not yet decomposed: active, not themselves
        # products of a decomposition (no parent), and with no extracted
        # obligation candidates yet.
        query = """
            SELECT cc.id, cc.control_id, cc.title, cc.objective,
                   cc.requirements, cc.test_procedure,
                   cc.source_citation, cc.category
            FROM canonical_controls cc
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.parent_control_uuid IS NULL
              AND NOT EXISTS (
                  SELECT 1 FROM obligation_candidates oc
                  WHERE oc.parent_control_uuid = cc.id
              )
            ORDER BY cc.created_at
        """
        # NOTE(review): limit is interpolated into the SQL rather than bound;
        # safe only while callers pass an int — confirm that invariant holds.
        if limit > 0:
            query += f" LIMIT {limit}"

        rows = self.db.execute(text(query)).fetchall()

        stats = {
            "controls_processed": 0,
            "obligations_extracted": 0,
            "obligations_validated": 0,
            "obligations_rejected": 0,
            "controls_skipped_empty": 0,
            "errors": 0,
        }

        for row in rows:
            # Positional columns follow the SELECT order above.
            control_uuid = str(row[0])
            control_id = row[1] or ""
            title = row[2] or ""
            objective = row[3] or ""
            requirements = row[4] or ""
            test_procedure = row[5] or ""
            source_citation = row[6] or ""
            category = row[7] or ""

            # Format requirements/test_procedure if JSON
            req_str = _format_field(requirements)
            test_str = _format_field(test_procedure)
            source_str = _format_citation(source_citation)

            # Nothing to extract from — skip entirely empty controls.
            if not title and not objective and not req_str:
                stats["controls_skipped_empty"] += 1
                continue

            try:
                prompt = _build_pass0a_prompt(
                    title=title,
                    objective=objective,
                    requirements=req_str,
                    test_procedure=test_str,
                    source_ref=source_str,
                )

                llm_response = await _llm_ollama(
                    prompt=prompt,
                    system_prompt=_PASS0A_SYSTEM_PROMPT,
                )

                raw_obligations = _parse_json_array(llm_response)

                if not raw_obligations:
                    # Fallback: treat the whole control as one obligation
                    raw_obligations = [{
                        "obligation_text": objective or title,
                        "action": "sicherstellen",
                        "object": title,
                        "condition": None,
                        "normative_strength": "must",
                        "is_test_obligation": False,
                        "is_reporting_obligation": False,
                    }]

                for idx, raw in enumerate(raw_obligations):
                    # Candidate ids are 1-based and zero-padded (OC-<id>-01).
                    cand = ObligationCandidate(
                        candidate_id=f"OC-{control_id}-{idx + 1:02d}",
                        parent_control_uuid=control_uuid,
                        obligation_text=raw.get("obligation_text", ""),
                        action=raw.get("action", ""),
                        object_=raw.get("object", ""),
                        condition=raw.get("condition"),
                        normative_strength=raw.get("normative_strength", "must"),
                        is_test_obligation=bool(raw.get("is_test_obligation", False)),
                        is_reporting_obligation=bool(raw.get("is_reporting_obligation", False)),
                    )

                    # Auto-detect test/reporting if LLM missed it
                    # (regex backstop for guardrail Rules 3 and 4).
                    if not cand.is_test_obligation and _TEST_RE.search(cand.obligation_text):
                        cand.is_test_obligation = True
                    if not cand.is_reporting_obligation and _REPORTING_RE.search(cand.obligation_text):
                        cand.is_reporting_obligation = True

                    # Quality gate decides validated vs rejected; rejected
                    # candidates are still persisted for audit.
                    flags = quality_gate(cand)
                    cand.quality_flags = flags
                    cand.extraction_confidence = _compute_extraction_confidence(flags)

                    if passes_quality_gate(flags):
                        cand.release_state = "validated"
                        stats["obligations_validated"] += 1
                    else:
                        cand.release_state = "rejected"
                        stats["obligations_rejected"] += 1

                    # Write to DB (committed once at the end of the pass).
                    self._write_obligation_candidate(cand)
                    stats["obligations_extracted"] += 1

                stats["controls_processed"] += 1

            except Exception as e:
                # Per-control isolation: one bad control/LLM response must
                # not abort the whole migration pass.
                logger.error("Pass 0a failed for %s: %s", control_id, e)
                stats["errors"] += 1

        self.db.commit()
        logger.info("Pass 0a: %s", stats)
        return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 0b: Atomic Control Composition
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def run_pass0b(self, limit: int = 0) -> dict:
    """Compose atomic controls from validated obligation candidates.

    Processes obligation_candidates with release_state='validated'
    that don't have a corresponding atomic control yet.

    For each candidate: build a prompt, ask the local LLM for an atomic
    control, fall back to a deterministic template when the LLM fails,
    write the control to canonical_controls, and mark the candidate as
    'composed'. A single commit covers the whole pass.

    :param limit: maximum number of candidates to process (0 = no limit).
    :returns: stats dict with candidates_processed / controls_created /
        llm_failures / errors counters.
    """
    from compliance.services.obligation_extractor import _llm_ollama

    # The NOT EXISTS clause is a heuristic duplicate guard: skip
    # candidates whose parent already has a pass0b control whose title
    # contains the first 20 chars of this candidate's action.
    query = """
        SELECT oc.id, oc.candidate_id, oc.parent_control_uuid,
               oc.obligation_text, oc.action, oc.object,
               oc.is_test_obligation, oc.is_reporting_obligation,
               cc.title AS parent_title,
               cc.category AS parent_category,
               cc.source_citation AS parent_citation,
               cc.severity AS parent_severity,
               cc.control_id AS parent_control_id
        FROM obligation_candidates oc
        JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid
        WHERE oc.release_state = 'validated'
          AND NOT EXISTS (
              SELECT 1 FROM canonical_controls ac
              WHERE ac.parent_control_uuid = oc.parent_control_uuid
                AND ac.decomposition_method = 'pass0b'
                AND ac.title LIKE '%' || LEFT(oc.action, 20) || '%'
          )
    """
    # limit is an int parameter, so the f-string interpolation here is
    # not an injection vector, but a bound parameter would be cleaner.
    if limit > 0:
        query += f" LIMIT {limit}"

    rows = self.db.execute(text(query)).fetchall()

    stats = {
        "candidates_processed": 0,
        "controls_created": 0,
        "llm_failures": 0,
        "errors": 0,
    }

    for row in rows:
        # Positional unpacking mirrors the SELECT column order above.
        oc_id = str(row[0])
        candidate_id = row[1] or ""
        parent_uuid = str(row[2])
        obligation_text = row[3] or ""
        action = row[4] or ""
        object_ = row[5] or ""
        is_test = row[6]
        is_reporting = row[7]
        parent_title = row[8] or ""
        parent_category = row[9] or ""
        parent_citation = row[10] or ""
        parent_severity = row[11] or "medium"
        parent_control_id = row[12] or ""

        source_str = _format_citation(parent_citation)

        try:
            prompt = _build_pass0b_prompt(
                obligation_text=obligation_text,
                action=action,
                object_=object_,
                parent_title=parent_title,
                parent_category=parent_category,
                source_ref=source_str,
            )

            llm_response = await _llm_ollama(
                prompt=prompt,
                system_prompt=_PASS0B_SYSTEM_PROMPT,
            )

            parsed = _parse_json_object(llm_response)

            # A missing/empty title means the LLM response is unusable.
            if not parsed or not parsed.get("title"):
                # Template fallback — no LLM needed
                atomic = _template_fallback(
                    obligation_text=obligation_text,
                    action=action,
                    object_=object_,
                    parent_title=parent_title,
                    parent_severity=parent_severity,
                    parent_category=parent_category,
                    is_test=is_test,
                    is_reporting=is_reporting,
                )
                stats["llm_failures"] += 1
            else:
                # Truncation limits ([:200], [:2000]) keep oversized LLM
                # output within column-sized bounds.
                atomic = AtomicControlCandidate(
                    title=parsed.get("title", "")[:200],
                    objective=parsed.get("objective", "")[:2000],
                    requirements=_ensure_list(parsed.get("requirements", [])),
                    test_procedure=_ensure_list(parsed.get("test_procedure", [])),
                    evidence=_ensure_list(parsed.get("evidence", [])),
                    severity=_normalize_severity(parsed.get("severity", parent_severity)),
                    category=parsed.get("category", parent_category),
                )

            atomic.parent_control_uuid = parent_uuid
            atomic.obligation_candidate_id = candidate_id

            # Generate control_id from parent, e.g. "<PARENT>-A01"
            seq = self._next_atomic_seq(parent_control_id)
            atomic.candidate_id = f"{parent_control_id}-A{seq:02d}"

            # Write to canonical_controls
            self._write_atomic_control(atomic, parent_uuid, candidate_id)

            # Mark obligation candidate as composed
            self.db.execute(
                text("""
                    UPDATE obligation_candidates
                    SET release_state = 'composed'
                    WHERE id = CAST(:oc_id AS uuid)
                """),
                {"oc_id": oc_id},
            )

            stats["controls_created"] += 1
            stats["candidates_processed"] += 1

        except Exception as e:
            # Per-candidate isolation: one failure must not abort the
            # whole pass; errors are counted and logged.
            logger.error("Pass 0b failed for %s: %s", candidate_id, e)
            stats["errors"] += 1

    self.db.commit()
    logger.info("Pass 0b: %s", stats)
    return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Decomposition Status
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def decomposition_status(self) -> dict:
    """Return decomposition progress.

    Aggregates all counters in a single DB round-trip via scalar
    subqueries. "Rich" controls are roots (no parent, not deprecated);
    "atomic" controls are any rows with a parent_control_uuid.
    Percentage denominators are clamped with max(..., 1) so an empty
    database cannot cause a ZeroDivisionError.
    """
    row = self.db.execute(text("""
        SELECT
            (SELECT count(*) FROM canonical_controls
             WHERE parent_control_uuid IS NULL
               AND release_state NOT IN ('deprecated')) AS rich_controls,
            (SELECT count(DISTINCT parent_control_uuid) FROM obligation_candidates) AS decomposed_controls,
            (SELECT count(*) FROM obligation_candidates) AS total_candidates,
            (SELECT count(*) FROM obligation_candidates WHERE release_state = 'validated') AS validated,
            (SELECT count(*) FROM obligation_candidates WHERE release_state = 'rejected') AS rejected,
            (SELECT count(*) FROM obligation_candidates WHERE release_state = 'composed') AS composed,
            (SELECT count(*) FROM canonical_controls WHERE parent_control_uuid IS NOT NULL) AS atomic_controls
    """)).fetchone()

    return {
        "rich_controls": row[0],
        "decomposed_controls": row[1],
        "total_candidates": row[2],
        "validated": row[3],
        "rejected": row[4],
        "composed": row[5],
        "atomic_controls": row[6],
        # Share of rich controls that have at least one candidate.
        "decomposition_pct": round(row[1] / max(row[0], 1) * 100, 1),
        # Share of validated candidates already composed into controls.
        "composition_pct": round(row[5] / max(row[3], 1) * 100, 1),
    }
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# DB Writers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _write_obligation_candidate(self, cand: ObligationCandidate) -> None:
    """Insert an obligation candidate into the DB.

    quality_flags is JSON-serialized before binding. The surrounding
    pass owns the transaction — this method does not commit.

    :param cand: fully populated candidate (quality flags, confidence
        and release_state already set by the caller).
    """
    self.db.execute(
        text("""
            INSERT INTO obligation_candidates (
                parent_control_uuid, candidate_id,
                obligation_text, action, object, condition,
                normative_strength, is_test_obligation,
                is_reporting_obligation, extraction_confidence,
                quality_flags, release_state
            ) VALUES (
                CAST(:parent_uuid AS uuid), :candidate_id,
                :obligation_text, :action, :object, :condition,
                :normative_strength, :is_test, :is_reporting,
                :confidence, :quality_flags, :release_state
            )
        """),
        {
            "parent_uuid": cand.parent_control_uuid,
            "candidate_id": cand.candidate_id,
            "obligation_text": cand.obligation_text,
            "action": cand.action,
            # Attribute is object_ because `object` shadows the builtin.
            "object": cand.object_,
            "condition": cand.condition,
            "normative_strength": cand.normative_strength,
            "is_test": cand.is_test_obligation,
            "is_reporting": cand.is_reporting_obligation,
            "confidence": cand.extraction_confidence,
            "quality_flags": json.dumps(cand.quality_flags),
            "release_state": cand.release_state,
        },
    )
|
||||
|
||||
def _write_atomic_control(
    self, atomic: AtomicControlCandidate,
    parent_uuid: str, candidate_id: str,
) -> None:
    """Insert an atomic control into canonical_controls.

    New atomic controls always start as release_state 'draft' with
    decomposition_method 'pass0b' (both hard-coded in the INSERT);
    generation_metadata records which obligation candidate the control
    was derived from for traceability. No commit here — the pass owns
    the transaction.

    :param atomic: the composed control; its candidate_id becomes the
        canonical control_id.
    :param parent_uuid: UUID of the rich parent control.
    :param candidate_id: source obligation candidate ID (provenance).
    """
    self.db.execute(
        text("""
            INSERT INTO canonical_controls (
                control_id, title, objective, requirements,
                test_procedure, evidence, severity, category,
                release_state, parent_control_uuid,
                decomposition_method,
                generation_metadata
            ) VALUES (
                :control_id, :title, :objective,
                :requirements, :test_procedure, :evidence,
                :severity, :category, 'draft',
                CAST(:parent_uuid AS uuid), 'pass0b',
                :gen_meta
            )
        """),
        {
            "control_id": atomic.candidate_id,
            "title": atomic.title,
            "objective": atomic.objective,
            # List-valued fields are stored as JSON text.
            "requirements": json.dumps(atomic.requirements),
            "test_procedure": json.dumps(atomic.test_procedure),
            "evidence": json.dumps(atomic.evidence),
            "severity": atomic.severity,
            "category": atomic.category,
            "parent_uuid": parent_uuid,
            "gen_meta": json.dumps({
                "decomposition_source": candidate_id,
                "decomposition_method": "pass0b",
            }),
        },
    )
|
||||
|
||||
def _next_atomic_seq(self, parent_control_id: str) -> int:
    """Return the next 1-based sequence number for a parent's atomic controls.

    Counts existing children of the parent (resolved by its control_id)
    and returns count + 1.
    """
    stmt = text("""
        SELECT count(*) FROM canonical_controls
        WHERE parent_control_uuid = (
            SELECT id FROM canonical_controls
            WHERE control_id = :parent_id
            LIMIT 1
        )
    """)
    row = self.db.execute(stmt, {"parent_id": parent_control_id}).fetchone()
    existing = row[0] if row else 0
    return existing + 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _format_field(value) -> str:
|
||||
"""Format a requirements/test_procedure field for the LLM prompt."""
|
||||
if not value:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
if isinstance(parsed, list):
|
||||
return "\n".join(f"- {item}" for item in parsed)
|
||||
return value
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
return "\n".join(f"- {item}" for item in value)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _format_citation(citation) -> str:
|
||||
"""Format source_citation for display."""
|
||||
if not citation:
|
||||
return ""
|
||||
if isinstance(citation, str):
|
||||
try:
|
||||
c = json.loads(citation)
|
||||
if isinstance(c, dict):
|
||||
parts = []
|
||||
if c.get("source"):
|
||||
parts.append(c["source"])
|
||||
if c.get("article"):
|
||||
parts.append(c["article"])
|
||||
if c.get("paragraph"):
|
||||
parts.append(c["paragraph"])
|
||||
return " ".join(parts) if parts else citation
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return citation
|
||||
return str(citation)
|
||||
|
||||
|
||||
def _compute_extraction_confidence(flags: dict) -> float:
|
||||
"""Compute confidence score from quality flags."""
|
||||
score = 0.0
|
||||
weights = {
|
||||
"has_normative_signal": 0.30,
|
||||
"single_action": 0.20,
|
||||
"not_rationale": 0.20,
|
||||
"not_evidence_only": 0.15,
|
||||
"min_length": 0.10,
|
||||
"has_parent_link": 0.05,
|
||||
}
|
||||
for flag, weight in weights.items():
|
||||
if flags.get(flag, False):
|
||||
score += weight
|
||||
return round(score, 2)
|
||||
|
||||
|
||||
def _normalize_severity(val: str) -> str:
|
||||
"""Normalize severity value."""
|
||||
val = (val or "medium").lower().strip()
|
||||
if val in ("critical", "high", "medium", "low"):
|
||||
return val
|
||||
return "medium"
|
||||
|
||||
|
||||
def _template_fallback(
    obligation_text: str, action: str, object_: str,
    parent_title: str, parent_severity: str, parent_category: str,
    is_test: bool, is_reporting: bool,
) -> AtomicControlCandidate:
    """Build an atomic control candidate from fixed templates.

    Used when the LLM produced no usable output. The template branch is
    chosen by the candidate's obligation kind (test / reporting /
    generic); German strings match the corpus language.
    """
    subject = object_ or action
    if is_test:
        heading = f"Test: {object_[:60]}" if object_ else f"Test: {action[:60]}"
        procedure = [f"Prüfung der {subject}"]
        proofs = ["Testprotokoll", "Prüfbericht"]
    elif is_reporting:
        heading = f"Meldepflicht: {object_[:60]}" if object_ else f"Meldung: {action[:60]}"
        procedure = ["Prüfung des Meldeprozesses", "Stichprobe gemeldeter Vorfälle"]
        proofs = ["Meldeprozess-Dokumentation", "Meldeformulare"]
    else:
        heading = f"{action.capitalize()}: {object_[:60]}" if object_ else parent_title[:80]
        procedure = [f"Prüfung der {action}"]
        proofs = ["Dokumentation", "Konfigurationsnachweis"]

    return AtomicControlCandidate(
        title=heading[:200],
        objective=obligation_text[:2000],
        requirements=[obligation_text] if obligation_text else [],
        test_procedure=procedure,
        evidence=proofs,
        severity=_normalize_severity(parent_severity),
        category=parent_category,
    )
|
||||
562
backend-compliance/compliance/services/obligation_extractor.py
Normal file
562
backend-compliance/compliance/services/obligation_extractor.py
Normal file
@@ -0,0 +1,562 @@
|
||||
"""Obligation Extractor — 3-Tier Chunk-to-Obligation Linking.
|
||||
|
||||
Maps RAG chunks to obligations from the v2 obligation framework using
|
||||
three tiers (fastest first):
|
||||
|
||||
Tier 1: EXACT MATCH — regulation_code + article → obligation_id (~40%)
|
||||
Tier 2: EMBEDDING — chunk text vs. obligation descriptions (~30%)
|
||||
Tier 3: LLM EXTRACT — local Ollama extracts obligation text (~25%)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 4 of 8).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Service endpoints and model settings, overridable via environment.
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
# LLM request timeout in seconds (used as the httpx client timeout).
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))

# Embedding similarity thresholds for Tier 2
EMBEDDING_MATCH_THRESHOLD = 0.80  # accept as a match at/above this cosine score
# NOTE(review): EMBEDDING_CANDIDATE_THRESHOLD is not referenced by the
# visible code in this module — confirm whether it is used elsewhere.
EMBEDDING_CANDIDATE_THRESHOLD = 0.60
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regulation code mapping: RAG chunk codes → obligation file regulation IDs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Keys are lowercase regulation codes as they appear in RAG chunk
# metadata (CELEX-style identifiers like "eu_2016_679" plus common
# aliases); values are the regulation IDs used by the v2 obligation
# framework files. All lookups go through _normalize_regulation().
_REGULATION_CODE_TO_ID = {
    # DSGVO
    "eu_2016_679": "dsgvo",
    "dsgvo": "dsgvo",
    "gdpr": "dsgvo",
    # AI Act
    "eu_2024_1689": "ai_act",
    "ai_act": "ai_act",
    "aiact": "ai_act",
    # NIS2
    "eu_2022_2555": "nis2",
    "nis2": "nis2",
    "bsig": "nis2",
    # BDSG
    "bdsg": "bdsg",
    # TTDSG
    "ttdsg": "ttdsg",
    # DSA
    "eu_2022_2065": "dsa",
    "dsa": "dsa",
    # Data Act
    "eu_2023_2854": "data_act",
    "data_act": "data_act",
    # EU Machinery
    "eu_2023_1230": "eu_machinery",
    "eu_machinery": "eu_machinery",
    # DORA
    "eu_2022_2554": "dora",
    "dora": "dora",
}
|
||||
|
||||
|
||||
@dataclass
class ObligationMatch:
    """Result of obligation extraction."""

    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    method: str = "none"  # exact_match | embedding_match | llm_extracted | inferred
    confidence: float = 0.0
    regulation_id: Optional[str] = None  # e.g. "dsgvo"

    def to_dict(self) -> dict:
        """Serialize this match to a plain (JSON-friendly) dict."""
        keys = (
            "obligation_id",
            "obligation_title",
            "obligation_text",
            "method",
            "confidence",
            "regulation_id",
        )
        return {key: getattr(self, key) for key in keys}
|
||||
|
||||
|
||||
@dataclass
class _ObligationEntry:
    """Internal representation of a loaded obligation."""

    id: str  # obligation ID from the v2 JSON file (e.g. "DSGVO-OBL-001")
    title: str
    description: str
    regulation_id: str  # canonical regulation ID (e.g. "dsgvo")
    articles: list[str] = field(default_factory=list)  # normalized: ["art. 30", "§ 38"]
    # Description embedding; empty list when not (yet) computed.
    embedding: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class ObligationExtractor:
    """3-Tier obligation extraction from RAG chunks.

    Usage::

        extractor = ObligationExtractor()
        await extractor.initialize()  # loads obligations + embeddings

        match = await extractor.extract(
            chunk_text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            paragraph="Abs. 1",
        )
    """

    def __init__(self):
        # Lookup key format is "<regulation_id>/<normalized article>".
        self._article_lookup: dict[str, list[str]] = {}  # "dsgvo/art. 30" → ["DSGVO-OBL-001"]
        self._obligations: dict[str, _ObligationEntry] = {}  # id → entry
        # Parallel lists: index i in _obligation_embeddings corresponds
        # to _obligation_ids[i].
        self._obligation_embeddings: list[list[float]] = []
        self._obligation_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load all obligations from v2 JSON files and compute embeddings.

        Idempotent: repeated calls after a successful run are no-ops.
        """
        if self._initialized:
            return

        self._load_obligations()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "ObligationExtractor initialized: %d obligations, %d article lookups, %d embeddings",
            len(self._obligations),
            len(self._article_lookup),
            sum(1 for e in self._obligation_embeddings if e),
        )

    async def extract(
        self,
        chunk_text: str,
        regulation_code: str,
        article: Optional[str] = None,
        paragraph: Optional[str] = None,
    ) -> ObligationMatch:
        """Extract obligation from a chunk using 3-tier strategy.

        Lazily initializes on first use. Tiers run fastest-first; the
        first tier that yields a match wins. Tier 3 always returns a
        match object (possibly with confidence 0.0).

        NOTE(review): ``paragraph`` is currently unused — Tier 1 keys on
        regulation + article only. Confirm whether that is intentional.
        """
        if not self._initialized:
            await self.initialize()

        reg_id = _normalize_regulation(regulation_code)

        # Tier 1: Exact match via article lookup
        if article:
            match = self._tier1_exact(reg_id, article)
            if match:
                return match

        # Tier 2: Embedding similarity
        match = await self._tier2_embedding(chunk_text, reg_id)
        if match:
            return match

        # Tier 3: LLM extraction
        match = await self._tier3_llm(chunk_text, regulation_code, article)
        return match

    # -----------------------------------------------------------------------
    # Tier 1: Exact Match
    # -----------------------------------------------------------------------

    def _tier1_exact(self, reg_id: Optional[str], article: str) -> Optional[ObligationMatch]:
        """Look up obligation by regulation + article.

        Returns None when the regulation is unknown or no obligation is
        registered under the normalized article key.
        """
        if not reg_id:
            return None

        norm_article = _normalize_article(article)
        key = f"{reg_id}/{norm_article}"

        obl_ids = self._article_lookup.get(key)
        if not obl_ids:
            return None

        # Take the first match (highest priority)
        obl_id = obl_ids[0]
        entry = self._obligations.get(obl_id)
        if not entry:
            return None

        # Exact lookups get full confidence.
        return ObligationMatch(
            obligation_id=entry.id,
            obligation_title=entry.title,
            obligation_text=entry.description,
            method="exact_match",
            confidence=1.0,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, chunk_text: str, reg_id: Optional[str]
    ) -> Optional[ObligationMatch]:
        """Find nearest obligation by embedding similarity.

        A +0.05 same-regulation bonus biases the argmax toward the
        chunk's own regulation, but the bonus is subtracted again before
        the acceptance-threshold check so the cutoff is comparable
        across regulations.
        """
        if not self._obligation_embeddings:
            return None

        # Chunk text is truncated to bound embedding-service input size.
        chunk_embedding = await _get_embedding(chunk_text[:2000])
        if not chunk_embedding:
            return None

        best_idx = -1
        best_score = 0.0

        for i, obl_emb in enumerate(self._obligation_embeddings):
            # Skip obligations whose embedding failed (stored as []).
            if not obl_emb:
                continue
            # Prefer same-regulation matches
            obl_id = self._obligation_ids[i]
            entry = self._obligations.get(obl_id)
            score = _cosine_sim(chunk_embedding, obl_emb)

            # Domain bonus: +0.05 if same regulation
            if entry and reg_id and entry.regulation_id == reg_id:
                score += 0.05

            if score > best_score:
                best_score = score
                best_idx = i

        if best_idx < 0:
            return None

        # Remove domain bonus for threshold comparison
        raw_score = best_score
        obl_id = self._obligation_ids[best_idx]
        entry = self._obligations.get(obl_id)
        if entry and reg_id and entry.regulation_id == reg_id:
            raw_score -= 0.05

        if raw_score >= EMBEDDING_MATCH_THRESHOLD:
            return ObligationMatch(
                obligation_id=entry.id if entry else obl_id,
                obligation_title=entry.title if entry else None,
                obligation_text=entry.description if entry else None,
                method="embedding_match",
                confidence=round(min(raw_score, 1.0), 3),
                regulation_id=entry.regulation_id if entry else reg_id,
            )

        return None

    # -----------------------------------------------------------------------
    # Tier 3: LLM Extraction
    # -----------------------------------------------------------------------

    async def _tier3_llm(
        self, chunk_text: str, regulation_code: str, article: Optional[str]
    ) -> ObligationMatch:
        """Use local LLM to extract the obligation from the chunk.

        Unlike tiers 1/2, this never returns None: on LLM failure an
        empty match with confidence 0.0 is returned; on success a
        free-text obligation (no obligation_id) with confidence 0.60.
        """
        prompt = f"""Analysiere den folgenden Gesetzestext und extrahiere die zentrale rechtliche Pflicht.

Text:
{chunk_text[:3000]}

Quelle: {regulation_code} {article or ''}

Antworte NUR als JSON:
{{
  "obligation_text": "Die zentrale Pflicht in einem Satz",
  "actor": "Wer muss handeln (z.B. Verantwortlicher, Auftragsverarbeiter)",
  "action": "Was muss getan werden",
  "normative_strength": "muss|soll|kann"
}}"""

        system_prompt = (
            "Du bist ein Rechtsexperte fuer EU-Datenschutz- und Digitalrecht. "
            "Extrahiere die zentrale rechtliche Pflicht aus Gesetzestexten. "
            "Antworte ausschliesslich als JSON."
        )

        result_text = await _llm_ollama(prompt, system_prompt)
        if not result_text:
            return ObligationMatch(
                method="llm_extracted",
                confidence=0.0,
                regulation_id=_normalize_regulation(regulation_code),
            )

        parsed = _parse_json(result_text)
        # Fall back to the raw response head if the model ignored the
        # JSON instruction.
        obligation_text = parsed.get("obligation_text", result_text[:500])

        return ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text=obligation_text,
            method="llm_extracted",
            confidence=0.60,
            regulation_id=_normalize_regulation(regulation_code),
        )

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_obligations(self) -> None:
        """Load all obligation files from v2 framework.

        A missing directory, manifest or regulation file degrades
        gracefully (warning logged; Tier 1 simply has fewer or no
        lookups).
        """
        v2_dir = _find_obligations_dir()
        if not v2_dir:
            logger.warning("Obligations v2 directory not found — Tier 1 disabled")
            return

        manifest_path = v2_dir / "_manifest.json"
        if not manifest_path.exists():
            logger.warning("Manifest not found at %s", manifest_path)
            return

        with open(manifest_path) as f:
            manifest = json.load(f)

        for reg_info in manifest.get("regulations", []):
            reg_id = reg_info["id"]
            reg_file = v2_dir / reg_info["file"]
            if not reg_file.exists():
                logger.warning("Regulation file not found: %s", reg_file)
                continue

            with open(reg_file) as f:
                data = json.load(f)

            for obl in data.get("obligations", []):
                obl_id = obl["id"]
                entry = _ObligationEntry(
                    id=obl_id,
                    title=obl.get("title", ""),
                    description=obl.get("description", ""),
                    regulation_id=reg_id,
                )

                # Build article lookup from legal_basis
                for basis in obl.get("legal_basis", []):
                    article_raw = basis.get("article", "")
                    if article_raw:
                        norm_art = _normalize_article(article_raw)
                        key = f"{reg_id}/{norm_art}"
                        if key not in self._article_lookup:
                            self._article_lookup[key] = []
                        self._article_lookup[key].append(obl_id)
                        entry.articles.append(norm_art)

                self._obligations[obl_id] = entry

        logger.info(
            "Loaded %d obligations from %d regulations",
            len(self._obligations),
            len(manifest.get("regulations", [])),
        )

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all obligation descriptions.

        Fills the parallel lists _obligation_ids / _obligation_embeddings.
        Failed embeddings come back as [] and are skipped by Tier 2.
        """
        if not self._obligations:
            return

        self._obligation_ids = list(self._obligations.keys())
        texts = [
            f"{self._obligations[oid].title}: {self._obligations[oid].description}"
            for oid in self._obligation_ids
        ]

        logger.info("Computing embeddings for %d obligations...", len(texts))
        self._obligation_embeddings = await _get_embeddings_batch(texts)
        valid = sum(1 for e in self._obligation_embeddings if e)
        logger.info("Got %d/%d valid embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Stats
    # -----------------------------------------------------------------------

    def stats(self) -> dict:
        """Return initialization statistics."""
        return {
            "total_obligations": len(self._obligations),
            "article_lookups": len(self._article_lookup),
            "embeddings_valid": sum(1 for e in self._obligation_embeddings if e),
            "regulations": list(
                {e.regulation_id for e in self._obligations.values()}
            ),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helpers (reusable by other modules)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_regulation(regulation_code: str) -> Optional[str]:
    """Map a RAG regulation_code to obligation framework regulation ID.

    Falls back to prefix matching so CELEX-style codes with suffixes
    still resolve to their regulation family. Returns None for unknown
    or empty codes.
    """
    if not regulation_code:
        return None
    code = regulation_code.lower().strip()

    # Direct lookup
    direct = _REGULATION_CODE_TO_ID.get(code)
    if direct is not None:
        return direct

    # Prefix matching for families
    prefix_families = (
        ("eu_2016_679", "dsgvo"),
        ("eu_2024_1689", "ai_act"),
        ("eu_2022_2555", "nis2"),
        ("eu_2022_2065", "dsa"),
        ("eu_2023_2854", "data_act"),
        ("eu_2023_1230", "eu_machinery"),
        ("eu_2022_2554", "dora"),
    )
    for prefix, reg_id in prefix_families:
        if code.startswith(prefix):
            return reg_id

    return None
|
||||
|
||||
|
||||
def _normalize_article(article: str) -> str:
|
||||
"""Normalize article references for consistent lookup.
|
||||
|
||||
Examples:
|
||||
"Art. 30" → "art. 30"
|
||||
"§ 38 BDSG" → "§ 38"
|
||||
"Article 10" → "art. 10"
|
||||
"Art. 30 Abs. 1" → "art. 30"
|
||||
"Artikel 35" → "art. 35"
|
||||
"""
|
||||
if not article:
|
||||
return ""
|
||||
s = article.strip()
|
||||
|
||||
# Remove trailing law name: "§ 38 BDSG" → "§ 38"
|
||||
s = re.sub(r"\s+(DSGVO|BDSG|TTDSG|DSA|NIS2|DORA|AI.?Act)\s*$", "", s, flags=re.IGNORECASE)
|
||||
|
||||
# Remove paragraph references: "Art. 30 Abs. 1" → "Art. 30"
|
||||
s = re.sub(r"\s+(Abs|Absatz|para|paragraph|lit|Satz)\.?\s+.*$", "", s, flags=re.IGNORECASE)
|
||||
|
||||
# Normalize "Article" / "Artikel" → "Art."
|
||||
s = re.sub(r"^(Article|Artikel)\s+", "Art. ", s, flags=re.IGNORECASE)
|
||||
|
||||
return s.lower().strip()
|
||||
|
||||
|
||||
def _cosine_sim(a: list[float], b: list[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors."""
|
||||
if not a or not b or len(a) != len(b):
|
||||
return 0.0
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
def _find_obligations_dir() -> Optional[Path]:
|
||||
"""Locate the obligations v2 directory."""
|
||||
candidates = [
|
||||
Path(__file__).resolve().parent.parent.parent.parent
|
||||
/ "ai-compliance-sdk" / "policies" / "obligations" / "v2",
|
||||
Path("/app/ai-compliance-sdk/policies/obligations/v2"),
|
||||
Path("ai-compliance-sdk/policies/obligations/v2"),
|
||||
]
|
||||
for p in candidates:
|
||||
if p.is_dir() and (p / "_manifest.json").exists():
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
async def _get_embedding(text: str) -> list[float]:
    """Get embedding vector for a single text.

    Best-effort: returns [] on any transport/HTTP error so callers can
    treat "no embedding" uniformly. Failures are now logged (previously
    swallowed silently), mirroring _get_embeddings_batch.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            return embeddings[0] if embeddings else []
    except Exception as e:
        # Keep the best-effort contract, but make outages visible.
        logger.warning("Embedding request failed: %s", e)
        return []
|
||||
|
||||
|
||||
async def _get_embeddings_batch(
    texts: list[str], batch_size: int = 32
) -> list[list[float]]:
    """Get embeddings for multiple texts in batches.

    Always returns exactly ``len(texts)`` entries: a failed batch, or a
    service response with fewer/more vectors than requested, is padded
    or truncated with [] placeholders. This keeps the result aligned
    with the input — callers index it in parallel with their own ID
    lists, so a shifted position would silently attach the wrong
    embedding to an obligation.
    """
    all_embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                # Alignment guard: never let a malformed response shift
                # the positions of subsequent batches.
                if len(embeddings) < len(batch):
                    embeddings = list(embeddings) + [[] for _ in range(len(batch) - len(embeddings))]
                all_embeddings.extend(embeddings[: len(batch)])
        except Exception as e:
            logger.warning("Batch embedding failed for %d texts: %s", len(batch), e)
            all_embeddings.extend([[] for _ in batch])
    return all_embeddings
|
||||
|
||||
|
||||
async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call local Ollama for LLM extraction.

    Returns the assistant message content, or "" on any failure
    (non-200 response or transport error). Callers treat "" as
    "extraction unavailable".

    :param prompt: user message.
    :param system_prompt: optional system message prepended to the chat.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "model": OLLAMA_MODEL,
        "messages": messages,
        "stream": False,
        "options": {"num_predict": 512},  # cap generated tokens
        # NOTE(review): presumably disables the model's thinking mode —
        # confirm against the Ollama chat API for the configured model.
        "think": False,
    }

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=payload)
            if resp.status_code != 200:
                logger.error(
                    "Ollama chat failed %d: %s", resp.status_code, resp.text[:300]
                )
                return ""
            data = resp.json()
            return data.get("message", {}).get("content", "")
    except Exception as e:
        logger.warning("Ollama call failed: %s", e)
        return ""
|
||||
|
||||
|
||||
def _parse_json(text: str) -> dict:
|
||||
"""Extract JSON from LLM response text."""
|
||||
# Try direct parse
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try extracting JSON block
|
||||
match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {}
|
||||
532
backend-compliance/compliance/services/pattern_matcher.py
Normal file
532
backend-compliance/compliance/services/pattern_matcher.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""Pattern Matcher — Obligation-to-Control-Pattern Linking.
|
||||
|
||||
Maps obligations (from the ObligationExtractor) to control patterns
|
||||
using two tiers:
|
||||
|
||||
Tier 1: KEYWORD MATCH — obligation_match_keywords from patterns (~70%)
|
||||
Tier 2: EMBEDDING — cosine similarity with domain bonus (~25%)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 5 of 8).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from compliance.services.obligation_extractor import (
|
||||
_cosine_sim,
|
||||
_get_embedding,
|
||||
_get_embeddings_batch,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tier 1: minimum number of keyword hits required to accept a keyword match.
KEYWORD_MATCH_MIN_HITS = 2
# Tier 2: minimum cosine similarity to accept an embedding match.
EMBEDDING_PATTERN_THRESHOLD = 0.75
# Flat confidence bonus added when the source regulation has affinity with
# the pattern's domain (see _REGULATION_DOMAIN_AFFINITY below).
DOMAIN_BONUS = 0.10

# Map regulation IDs to pattern domain codes that are likely relevant.
# Used by PatternMatcher._domain_matches to decide whether DOMAIN_BONUS
# applies; regulations not listed here simply get no bonus.
_REGULATION_DOMAIN_AFFINITY = {
    "dsgvo": ["DATA", "COMP", "GOV"],
    "bdsg": ["DATA", "COMP"],
    "ttdsg": ["DATA"],
    "ai_act": ["AI", "COMP", "DATA"],
    "nis2": ["SEC", "INC", "NET", "LOG", "CRYP"],
    "dsa": ["DATA", "COMP"],
    "data_act": ["DATA", "COMP"],
    "eu_machinery": ["SEC", "COMP"],
    "dora": ["SEC", "INC", "FIN", "COMP"],
}
|
||||
|
||||
|
||||
@dataclass
class ControlPattern:
    """Python representation of a control pattern from YAML.

    Mirrors one entry of a pattern-library YAML file (see
    PatternMatcher._load_patterns for the field mapping).
    """

    id: str  # pattern identifier, e.g. "CP-COMP-001"
    name: str  # English display name
    name_de: str  # German display name
    domain: str  # domain code, e.g. "DATA", "SEC" (see affinity map)
    category: str
    description: str
    objective_template: str  # also used as the embedding text (with name_de)
    rationale_template: str
    requirements_template: list[str] = field(default_factory=list)
    test_procedure_template: list[str] = field(default_factory=list)
    evidence_template: list[str] = field(default_factory=list)
    severity_default: str = "medium"
    implementation_effort_default: str = "m"
    # Keywords matched (case-insensitive substring) against obligation text.
    obligation_match_keywords: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    # IDs of patterns this one can be composed with.
    composable_with: list[str] = field(default_factory=list)
    open_anchor_refs: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PatternMatchResult:
    """Result of pattern matching.

    A default-constructed instance (method="none", confidence=0.0) represents
    "no match found".
    """

    pattern: Optional[ControlPattern] = None
    pattern_id: Optional[str] = None
    method: str = "none"  # keyword | embedding | combined | none
    confidence: float = 0.0
    keyword_hits: int = 0
    total_keywords: int = 0
    embedding_score: float = 0.0
    domain_bonus_applied: bool = False
    composable_patterns: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary (scores rounded to 3 places)."""
        summary: dict = {}
        summary["pattern_id"] = self.pattern_id
        summary["method"] = self.method
        summary["confidence"] = round(self.confidence, 3)
        summary["keyword_hits"] = self.keyword_hits
        summary["total_keywords"] = self.total_keywords
        summary["embedding_score"] = round(self.embedding_score, 3)
        summary["domain_bonus_applied"] = self.domain_bonus_applied
        summary["composable_patterns"] = self.composable_patterns
        return summary
|
||||
|
||||
|
||||
class PatternMatcher:
    """Links obligations to control patterns using keyword + embedding matching.

    Tier 1 counts case-insensitive keyword hits against each pattern's
    ``obligation_match_keywords``; Tier 2 compares an embedding of the
    obligation text against precomputed embeddings of each pattern's
    ``name_de`` + ``objective_template``. Both tiers add a flat DOMAIN_BONUS
    when the source regulation has affinity with the pattern's domain.

    Usage::

        matcher = PatternMatcher()
        await matcher.initialize()

        result = await matcher.match(
            obligation_text="Fuehrung eines Verarbeitungsverzeichnisses...",
            regulation_id="dsgvo",
        )
        print(result.pattern_id)  # e.g. "CP-COMP-001"
        print(result.confidence)  # e.g. 0.85
    """

    def __init__(self):
        # All loaded patterns, in sorted-YAML-file order.
        self._patterns: list[ControlPattern] = []
        # Lookup: pattern id → pattern.
        self._by_id: dict[str, ControlPattern] = {}
        # Lookup: domain code → patterns in that domain.
        self._by_domain: dict[str, list[ControlPattern]] = {}
        self._keyword_index: dict[str, list[str]] = {}  # keyword → [pattern_ids]
        # Parallel lists: _pattern_embeddings[i] belongs to _pattern_ids[i];
        # an empty embedding means the batch call failed for that pattern.
        self._pattern_embeddings: list[list[float]] = []
        self._pattern_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load patterns from YAML and compute embeddings. Idempotent."""
        if self._initialized:
            return

        self._load_patterns()
        self._build_keyword_index()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "PatternMatcher initialized: %d patterns, %d keywords, %d embeddings",
            len(self._patterns),
            len(self._keyword_index),
            sum(1 for e in self._pattern_embeddings if e),
        )

    async def match(
        self,
        obligation_text: str,
        regulation_id: Optional[str] = None,
        top_n: int = 1,
    ) -> PatternMatchResult:
        """Match obligation text to the best control pattern.

        Args:
            obligation_text: The obligation description to match against.
            regulation_id: Source regulation (for domain bonus).
            top_n: Number of top results to consider for composability.
                NOTE(review): currently unused in this method's body —
                confirm intent or remove.

        Returns:
            PatternMatchResult with the best match; an empty (method="none")
            result when no text is given or no patterns are loaded.
        """
        if not self._initialized:
            await self.initialize()

        if not obligation_text or not self._patterns:
            return PatternMatchResult()

        # Tier 1: Keyword matching
        keyword_result = self._tier1_keyword(obligation_text, regulation_id)

        # Tier 2: Embedding matching
        embedding_result = await self._tier2_embedding(obligation_text, regulation_id)

        # Combine scores: prefer keyword match, boost with embedding if available
        best = self._combine_results(keyword_result, embedding_result)

        # Attach composable patterns (only those actually present in the library)
        if best.pattern:
            best.composable_patterns = [
                pid for pid in best.pattern.composable_with
                if pid in self._by_id
            ]

        return best

    async def match_top_n(
        self,
        obligation_text: str,
        regulation_id: Optional[str] = None,
        n: int = 3,
    ) -> list[PatternMatchResult]:
        """Return top-N pattern matches sorted by confidence descending.

        NOTE(review): unlike match(), this path applies no
        KEYWORD_MATCH_MIN_HITS filter, so single-keyword-hit matches can
        appear in the results — confirm whether that is intended.
        """
        if not self._initialized:
            await self.initialize()

        if not obligation_text or not self._patterns:
            return []

        keyword_scores = self._keyword_scores(obligation_text, regulation_id)
        embedding_scores = await self._embedding_scores(obligation_text, regulation_id)

        # Merge: consider every pattern that scored in at least one tier.
        all_pattern_ids = set(keyword_scores.keys()) | set(embedding_scores.keys())
        results: list[PatternMatchResult] = []

        for pid in all_pattern_ids:
            pattern = self._by_id.get(pid)
            if not pattern:
                continue

            kw_score = keyword_scores.get(pid, (0, 0, 0.0))  # (hits, total, score)
            emb_score = embedding_scores.get(pid, (0.0, False))  # (score, bonus_applied)

            kw_hits, kw_total, kw_confidence = kw_score
            emb_confidence, bonus_applied = emb_score

            # Combined confidence: max of keyword and embedding, with a flat
            # +0.05 agreement boost when both tiers scored.
            if kw_confidence > 0 and emb_confidence > 0:
                combined = max(kw_confidence, emb_confidence) + 0.05
                method = "combined"
            elif kw_confidence > 0:
                combined = kw_confidence
                method = "keyword"
            else:
                combined = emb_confidence
                method = "embedding"

            results.append(PatternMatchResult(
                pattern=pattern,
                pattern_id=pid,
                method=method,
                confidence=min(combined, 1.0),
                keyword_hits=kw_hits,
                total_keywords=kw_total,
                embedding_score=emb_confidence,
                domain_bonus_applied=bonus_applied,
                composable_patterns=[
                    p for p in pattern.composable_with if p in self._by_id
                ],
            ))

        # Sort by confidence descending
        results.sort(key=lambda r: r.confidence, reverse=True)
        return results[:n]

    # -----------------------------------------------------------------------
    # Tier 1: Keyword Match
    # -----------------------------------------------------------------------

    def _tier1_keyword(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> Optional[PatternMatchResult]:
        """Match by counting keyword hits in the obligation text.

        Returns None when nothing scores or the best pattern has fewer than
        KEYWORD_MATCH_MIN_HITS hits.
        """
        scores = self._keyword_scores(obligation_text, regulation_id)
        if not scores:
            return None

        # Best pattern by confidence (index 2 of the score tuple).
        best_pid = max(scores, key=lambda pid: scores[pid][2])
        hits, total, confidence = scores[best_pid]

        if hits < KEYWORD_MATCH_MIN_HITS:
            return None

        pattern = self._by_id.get(best_pid)
        if not pattern:
            return None

        # Check domain bonus
        bonus_applied = False
        if regulation_id and self._domain_matches(pattern.domain, regulation_id):
            confidence = min(confidence + DOMAIN_BONUS, 1.0)
            bonus_applied = True

        return PatternMatchResult(
            pattern=pattern,
            pattern_id=best_pid,
            method="keyword",
            confidence=confidence,
            keyword_hits=hits,
            total_keywords=total,
            domain_bonus_applied=bonus_applied,
        )

    def _keyword_scores(
        self, text: str, regulation_id: Optional[str]
    ) -> dict[str, tuple[int, int, float]]:
        """Compute keyword match scores for all patterns.

        Confidence is the fraction of a pattern's keywords found as
        case-insensitive substrings of the text. ``regulation_id`` is
        accepted but not used here (the domain bonus is applied by callers).

        Returns dict: pattern_id → (hits, total_keywords, confidence).
        """
        text_lower = text.lower()
        hits_by_pattern: dict[str, int] = {}

        for keyword, pattern_ids in self._keyword_index.items():
            if keyword in text_lower:
                for pid in pattern_ids:
                    hits_by_pattern[pid] = hits_by_pattern.get(pid, 0) + 1

        result: dict[str, tuple[int, int, float]] = {}
        for pid, hits in hits_by_pattern.items():
            pattern = self._by_id.get(pid)
            if not pattern:
                continue
            total = len(pattern.obligation_match_keywords)
            confidence = hits / total if total > 0 else 0.0
            result[pid] = (hits, total, confidence)

        return result

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> Optional[PatternMatchResult]:
        """Match by embedding similarity against pattern objective_templates.

        Returns None when embeddings are unavailable or the best score
        (including any domain bonus) is below EMBEDDING_PATTERN_THRESHOLD.
        """
        scores = await self._embedding_scores(obligation_text, regulation_id)
        if not scores:
            return None

        best_pid = max(scores, key=lambda pid: scores[pid][0])
        emb_score, bonus_applied = scores[best_pid]

        if emb_score < EMBEDDING_PATTERN_THRESHOLD:
            return None

        pattern = self._by_id.get(best_pid)
        if not pattern:
            return None

        return PatternMatchResult(
            pattern=pattern,
            pattern_id=best_pid,
            method="embedding",
            confidence=min(emb_score, 1.0),
            embedding_score=emb_score,
            domain_bonus_applied=bonus_applied,
        )

    async def _embedding_scores(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> dict[str, tuple[float, bool]]:
        """Compute embedding similarity scores for all patterns.

        The obligation text is truncated to 2000 chars before embedding.
        Scores include the domain bonus (uncapped here; callers clamp to 1.0).

        Returns dict: pattern_id → (score, domain_bonus_applied); empty when
        no pattern embeddings exist or embedding the text fails.
        """
        if not self._pattern_embeddings:
            return {}

        chunk_embedding = await _get_embedding(obligation_text[:2000])
        if not chunk_embedding:
            return {}

        result: dict[str, tuple[float, bool]] = {}
        for i, pat_emb in enumerate(self._pattern_embeddings):
            if not pat_emb:
                # Embedding failed for this pattern during initialization.
                continue
            pid = self._pattern_ids[i]
            pattern = self._by_id.get(pid)
            if not pattern:
                continue

            score = _cosine_sim(chunk_embedding, pat_emb)

            # Domain bonus
            bonus_applied = False
            if regulation_id and self._domain_matches(pattern.domain, regulation_id):
                score += DOMAIN_BONUS
                bonus_applied = True

            result[pid] = (score, bonus_applied)

        return result

    # -----------------------------------------------------------------------
    # Score combination
    # -----------------------------------------------------------------------

    def _combine_results(
        self,
        keyword_result: Optional[PatternMatchResult],
        embedding_result: Optional[PatternMatchResult],
    ) -> PatternMatchResult:
        """Combine keyword and embedding results into the best match.

        When both tiers agree on the pattern, confidence gets a +0.05
        agreement boost (capped at 1.0); when they disagree, the
        higher-confidence result wins (keyword wins ties).
        """
        if not keyword_result and not embedding_result:
            return PatternMatchResult()

        if not keyword_result:
            return embedding_result
        if not embedding_result:
            return keyword_result

        # Both matched — check if they agree
        if keyword_result.pattern_id == embedding_result.pattern_id:
            # Same pattern: boost confidence
            combined_confidence = min(
                max(keyword_result.confidence, embedding_result.confidence) + 0.05,
                1.0,
            )
            return PatternMatchResult(
                pattern=keyword_result.pattern,
                pattern_id=keyword_result.pattern_id,
                method="combined",
                confidence=combined_confidence,
                keyword_hits=keyword_result.keyword_hits,
                total_keywords=keyword_result.total_keywords,
                embedding_score=embedding_result.embedding_score,
                domain_bonus_applied=(
                    keyword_result.domain_bonus_applied
                    or embedding_result.domain_bonus_applied
                ),
            )

        # Different patterns: pick the one with higher confidence
        if keyword_result.confidence >= embedding_result.confidence:
            return keyword_result
        return embedding_result

    # -----------------------------------------------------------------------
    # Domain affinity
    # -----------------------------------------------------------------------

    @staticmethod
    def _domain_matches(pattern_domain: str, regulation_id: str) -> bool:
        """Check if a pattern's domain has affinity with a regulation."""
        affine_domains = _REGULATION_DOMAIN_AFFINITY.get(regulation_id, [])
        return pattern_domain in affine_domains

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_patterns(self) -> None:
        """Load control patterns from YAML files.

        Reads every ``*.yaml`` in the patterns directory (skipping files
        whose name starts with "_"); a file that fails to parse is logged
        and skipped without aborting the load.
        """
        patterns_dir = _find_patterns_dir()
        if not patterns_dir:
            logger.warning("Control patterns directory not found")
            return

        for yaml_file in sorted(patterns_dir.glob("*.yaml")):
            if yaml_file.name.startswith("_"):
                continue
            try:
                with open(yaml_file) as f:
                    data = yaml.safe_load(f)
                if not data or "patterns" not in data:
                    continue
                for p in data["patterns"]:
                    pattern = ControlPattern(
                        id=p["id"],
                        name=p["name"],
                        name_de=p["name_de"],
                        domain=p["domain"],
                        category=p["category"],
                        description=p["description"],
                        objective_template=p["objective_template"],
                        rationale_template=p["rationale_template"],
                        requirements_template=p.get("requirements_template", []),
                        test_procedure_template=p.get("test_procedure_template", []),
                        evidence_template=p.get("evidence_template", []),
                        severity_default=p.get("severity_default", "medium"),
                        implementation_effort_default=p.get("implementation_effort_default", "m"),
                        obligation_match_keywords=p.get("obligation_match_keywords", []),
                        tags=p.get("tags", []),
                        composable_with=p.get("composable_with", []),
                        open_anchor_refs=p.get("open_anchor_refs", []),
                    )
                    self._patterns.append(pattern)
                    self._by_id[pattern.id] = pattern
                    domain_list = self._by_domain.setdefault(pattern.domain, [])
                    domain_list.append(pattern)
            except Exception as e:
                logger.error("Failed to load %s: %s", yaml_file.name, e)

        logger.info("Loaded %d patterns from %s", len(self._patterns), patterns_dir)

    def _build_keyword_index(self) -> None:
        """Build reverse index: keyword (lowercased) → [pattern_ids]."""
        for pattern in self._patterns:
            for kw in pattern.obligation_match_keywords:
                lower_kw = kw.lower()
                if lower_kw not in self._keyword_index:
                    self._keyword_index[lower_kw] = []
                self._keyword_index[lower_kw].append(pattern.id)

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all pattern objective templates.

        Embedding text is "name_de: objective_template"; failed embeddings
        come back as empty lists and are skipped at match time.
        """
        if not self._patterns:
            return

        self._pattern_ids = [p.id for p in self._patterns]
        texts = [
            f"{p.name_de}: {p.objective_template}"
            for p in self._patterns
        ]

        logger.info("Computing embeddings for %d patterns...", len(texts))
        self._pattern_embeddings = await _get_embeddings_batch(texts)
        valid = sum(1 for e in self._pattern_embeddings if e)
        logger.info("Got %d/%d valid pattern embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Public helpers
    # -----------------------------------------------------------------------

    def get_pattern(self, pattern_id: str) -> Optional[ControlPattern]:
        """Get a pattern by its ID.

        NOTE(review): the lookup uppercases the argument, which assumes all
        YAML pattern IDs are stored uppercase — confirm against the library.
        """
        return self._by_id.get(pattern_id.upper())

    def get_patterns_by_domain(self, domain: str) -> list[ControlPattern]:
        """Get all patterns for a domain (case-insensitive domain code)."""
        return self._by_domain.get(domain.upper(), [])

    def stats(self) -> dict:
        """Return matcher statistics (counts and initialization state)."""
        return {
            "total_patterns": len(self._patterns),
            "domains": list(self._by_domain.keys()),
            "keywords": len(self._keyword_index),
            "embeddings_valid": sum(1 for e in self._pattern_embeddings if e),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
def _find_patterns_dir() -> Optional[Path]:
|
||||
"""Locate the control_patterns directory."""
|
||||
candidates = [
|
||||
Path(__file__).resolve().parent.parent.parent.parent
|
||||
/ "ai-compliance-sdk" / "policies" / "control_patterns",
|
||||
Path("/app/ai-compliance-sdk/policies/control_patterns"),
|
||||
Path("ai-compliance-sdk/policies/control_patterns"),
|
||||
]
|
||||
for p in candidates:
|
||||
if p.is_dir():
|
||||
return p
|
||||
return None
|
||||
670
backend-compliance/compliance/services/pipeline_adapter.py
Normal file
670
backend-compliance/compliance/services/pipeline_adapter.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""Pipeline Adapter — New 10-Stage Pipeline Integration.
|
||||
|
||||
Bridges the existing 7-stage control_generator pipeline with the new
|
||||
multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer).
|
||||
|
||||
New pipeline flow:
|
||||
chunk → license_classify
|
||||
→ obligation_extract (Stage 4 — NEW)
|
||||
→ pattern_match (Stage 5 — NEW)
|
||||
→ control_compose (Stage 6 — replaces old Stage 3)
|
||||
→ harmonize → anchor → store + crosswalk → mark processed
|
||||
|
||||
Can be used in two modes:
|
||||
1. INLINE: Called from _process_batch() to enrich the pipeline
|
||||
2. STANDALONE: Process chunks directly through new stages
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 7 of 8).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from compliance.services.control_composer import ComposedControl, ControlComposer
|
||||
from compliance.services.obligation_extractor import ObligationExtractor, ObligationMatch
|
||||
from compliance.services.pattern_matcher import PatternMatcher, PatternMatchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PipelineChunk:
    """Input chunk for the new pipeline stages."""

    text: str  # raw chunk text
    collection: str = ""
    regulation_code: str = ""
    article: Optional[str] = None
    paragraph: Optional[str] = None
    license_rule: int = 3
    license_info: dict = field(default_factory=dict)
    source_citation: Optional[dict] = None
    chunk_hash: str = ""  # lazily filled by compute_hash()

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of the text, caching it on first use."""
        if self.chunk_hash:
            return self.chunk_hash
        self.chunk_hash = hashlib.sha256(self.text.encode()).hexdigest()
        return self.chunk_hash
|
||||
|
||||
|
||||
@dataclass
class PipelineResult:
    """Result of processing a chunk through the new pipeline."""

    chunk: PipelineChunk
    obligation: ObligationMatch = field(default_factory=ObligationMatch)
    pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult)
    control: Optional[ComposedControl] = None
    crosswalk_written: bool = False
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize the result; nested parts delegate to their own to_dict."""
        def _as_dict(part):
            # Truthiness check mirrors the stage outputs: a falsy part
            # serializes to None.
            return part.to_dict() if part else None

        return {
            "chunk_hash": self.chunk.chunk_hash,
            "obligation": _as_dict(self.obligation),
            "pattern": _as_dict(self.pattern_result),
            "control": _as_dict(self.control),
            "crosswalk_written": self.crosswalk_written,
            "error": self.error,
        }
|
||||
|
||||
|
||||
class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Runs the new Stages 4-6 (obligation extract → pattern match → control
    compose) over chunks, and — once a composed control has been stored —
    writes the traceability rows (obligation_extractions, crosswalk_matrix)
    via :meth:`write_crosswalk`.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()

        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        # db is optional: processing works without it, but write_crosswalk
        # requires a session and returns False otherwise.
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize all sub-components. Idempotent."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Process a single chunk through the new 3-stage pipeline.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Exceptions from any stage are caught, logged, and recorded in
        ``result.error`` instead of being raised.
        """
        if not self._initialized:
            await self.initialize()

        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)

        try:
            # Stage 4: Obligation Extract
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )

            # Stage 5: Pattern Match — match against the extracted obligation
            # text, falling back to its title, then the raw chunk prefix.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )

            # Stage 6: Control Compose — the raw chunk text is only forwarded
            # for license rules 1 and 2 (presumably those that permit quoting
            # the source text — TODO confirm against the license classifier).
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )

        except Exception as e:
            logger.error("Pipeline processing failed: %s", e)
            result.error = str(e)

        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process multiple chunks through the pipeline, sequentially."""
        results = []
        for chunk in chunks:
            result = await self.process_chunk(chunk)
            results.append(result)
        return results

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extraction + crosswalk_matrix rows for a processed chunk.

        Called AFTER the control is stored in canonical_controls. Performs
        three writes in one transaction (extraction row, crosswalk row,
        canonical_controls linkage update); rolls back and returns False on
        any failure. Returns False immediately when there is no db session
        or no composed control.
        """
        if not self.db or not result.control:
            return False

        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result

        try:
            # 1. Write obligation_extraction row (obligation text truncated
            # to 2000 chars).
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation.confidence if obligation else 0,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "pattern_match_score": pattern.confidence if pattern else 0,
                    "control_uuid": control_uuid,
                },
            )

            # 2. Write crosswalk_matrix row — crosswalk confidence is the
            # minimum of extraction and pattern confidence (weakest link).
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "master_control_id": result.control.control_id,
                    "master_control_uuid": control_uuid,
                    "confidence": min(
                        obligation.confidence if obligation else 0,
                        pattern.confidence if pattern else 0,
                    ),
                    "source": "auto",
                },
            )

            # 3. Update canonical_controls with pattern_id + obligation_ids
            # (COALESCE keeps any value already present on the row).
            if result.control.pattern_id or result.control.obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.control.pattern_id,
                        "obligation_ids": json.dumps(result.control.obligation_ids),
                        "control_uuid": control_uuid,
                    },
                )

            self.db.commit()
            result.crosswalk_written = True
            return True

        except Exception as e:
            logger.error("Failed to write crosswalk: %s", e)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration Passes — Backfill existing 4,800+ controls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MigrationPasses:
|
||||
"""Non-destructive migration passes for existing controls.
|
||||
|
||||
Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
|
||||
Pass 2: Pattern Classification (keyword-based matching)
|
||||
Pass 3: Quality Triage (categorize by linkage completeness)
|
||||
Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
|
||||
Pass 5: Deduplication (mark duplicate controls)
|
||||
|
||||
Usage::
|
||||
|
||||
migration = MigrationPasses(db)
|
||||
await migration.initialize()
|
||||
|
||||
result = await migration.run_pass1_obligation_linkage(limit=100)
|
||||
result = await migration.run_pass2_pattern_classification(limit=100)
|
||||
result = migration.run_pass3_quality_triage()
|
||||
result = migration.run_pass4_crosswalk_backfill()
|
||||
result = migration.run_pass5_deduplication()
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._extractor = ObligationExtractor()
|
||||
self._matcher = PatternMatcher()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize extractors (loads obligations + patterns)."""
|
||||
if self._initialized:
|
||||
return
|
||||
self._extractor._load_obligations()
|
||||
self._matcher._load_patterns()
|
||||
self._matcher._build_keyword_index()
|
||||
self._initialized = True
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 1: Obligation Linkage (deterministic)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
|
||||
"""Link existing controls to obligations via source_citation article.
|
||||
|
||||
For each control with source_citation → extract regulation + article
|
||||
→ look up in obligation framework → set obligation_ids.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
query = """
|
||||
SELECT id, control_id, source_citation, generation_metadata
|
||||
FROM canonical_controls
|
||||
WHERE release_state NOT IN ('deprecated')
|
||||
AND (obligation_ids IS NULL OR obligation_ids = '[]')
|
||||
"""
|
||||
if limit > 0:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
rows = self.db.execute(text(query)).fetchall()
|
||||
|
||||
stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}
|
||||
|
||||
for row in rows:
|
||||
control_uuid = str(row[0])
|
||||
control_id = row[1]
|
||||
citation = row[2]
|
||||
metadata = row[3]
|
||||
|
||||
# Extract regulation + article from citation or metadata
|
||||
reg_code, article = _extract_regulation_article(citation, metadata)
|
||||
if not reg_code:
|
||||
stats["no_citation"] += 1
|
||||
continue
|
||||
|
||||
# Tier 1: Exact match
|
||||
match = self._extractor._tier1_exact(reg_code, article or "")
|
||||
if match and match.obligation_id:
|
||||
self.db.execute(
|
||||
text("""
|
||||
UPDATE canonical_controls
|
||||
SET obligation_ids = :obl_ids
|
||||
WHERE id = CAST(:uuid AS uuid)
|
||||
"""),
|
||||
{
|
||||
"obl_ids": json.dumps([match.obligation_id]),
|
||||
"uuid": control_uuid,
|
||||
},
|
||||
)
|
||||
stats["linked"] += 1
|
||||
else:
|
||||
stats["no_match"] += 1
|
||||
|
||||
self.db.commit()
|
||||
logger.info("Pass 1: %s", stats)
|
||||
return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 2: Pattern Classification (keyword-based)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def run_pass2_pattern_classification(self, limit: int = 0) -> dict:
    """Classify existing controls into patterns via keyword matching.

    For each non-deprecated control without a pattern_id, keyword-match
    title+objective against the pattern library and assign the best match
    when it has at least two keyword hits.

    Args:
        limit: Maximum number of controls to process (0 = no limit).

    Returns:
        Stats dict with "total", "classified", "no_match" counts.
    """
    if not self._initialized:
        await self.initialize()

    query = """
        SELECT id, control_id, title, objective
        FROM canonical_controls
        WHERE release_state NOT IN ('deprecated')
          AND (pattern_id IS NULL OR pattern_id = '')
    """
    params: dict = {}
    if limit > 0:
        # Bind the limit as a parameter instead of interpolating it into
        # the SQL string.
        query += " LIMIT :limit"
        params["limit"] = limit

    rows = self.db.execute(text(query), params).fetchall()

    stats = {"total": len(rows), "classified": 0, "no_match": 0}

    for row in rows:
        control_uuid = str(row[0])
        title = row[2] or ""
        objective = row[3] or ""

        # Keyword match over the concatenated title + objective text.
        match_text = f"{title} {objective}"
        result = self._matcher._tier1_keyword(match_text, None)

        # Require >= 2 keyword hits to avoid weak single-hit classifications.
        if result and result.pattern_id and result.keyword_hits >= 2:
            self.db.execute(
                text("""
                    UPDATE canonical_controls
                    SET pattern_id = :pattern_id
                    WHERE id = CAST(:uuid AS uuid)
                """),
                {
                    "pattern_id": result.pattern_id,
                    "uuid": control_uuid,
                },
            )
            stats["classified"] += 1
        else:
            stats["no_match"] += 1

    self.db.commit()
    logger.info("Pass 2: %s", stats)
    return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 3: Quality Triage
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass3_quality_triage(self) -> dict:
    """Categorize controls by linkage completeness.

    Stamps generation_metadata.triage_status on every non-deprecated control:
      - "review":           obligation_ids AND pattern_id present
      - "needs_obligation": pattern_id present, obligation_ids missing
      - "needs_pattern":    obligation_ids present, pattern_id missing
      - "legacy_unlinked":  neither present

    Returns:
        Mapping of triage status -> number of rows updated.
    """
    triage_sql = {
        "review": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"review"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
              AND pattern_id IS NOT NULL AND pattern_id != ''
        """,
        "needs_obligation": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"needs_obligation"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
              AND pattern_id IS NOT NULL AND pattern_id != ''
        """,
        "needs_pattern": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"needs_pattern"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
              AND (pattern_id IS NULL OR pattern_id = '')
        """,
        "legacy_unlinked": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"legacy_unlinked"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
              AND (pattern_id IS NULL OR pattern_id = '')
        """,
    }

    # One UPDATE per category; rowcount is the number of controls stamped.
    counts = {
        status: self.db.execute(text(statement)).rowcount
        for status, statement in triage_sql.items()
    }

    self.db.commit()
    logger.info("Pass 3: %s", counts)
    return counts
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 4: Crosswalk Backfill
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass4_crosswalk_backfill(self) -> dict:
    """Create crosswalk_matrix rows for controls with obligation + pattern.

    Idempotent: the NOT EXISTS guard skips (control, obligation) pairs that
    already have a crosswalk row, so re-running the pass is safe.

    Returns:
        Dict with "rows_inserted" count.
    """
    insert_result = self.db.execute(text("""
        INSERT INTO crosswalk_matrix (
            regulation_code, obligation_id, pattern_id,
            master_control_id, master_control_uuid,
            confidence, source
        )
        SELECT
            COALESCE(
                (generation_metadata::jsonb->>'source_regulation'),
                ''
            ) AS regulation_code,
            obl.value::text AS obligation_id,
            cc.pattern_id,
            cc.control_id,
            cc.id,
            0.80,
            'migrated'
        FROM canonical_controls cc,
             jsonb_array_elements_text(
                 COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
             ) AS obl(value)
        WHERE cc.release_state NOT IN ('deprecated')
          AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
          AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
          AND NOT EXISTS (
              SELECT 1 FROM crosswalk_matrix cw
              WHERE cw.master_control_uuid = cc.id
                AND cw.obligation_id = obl.value::text
          )
    """))

    inserted = insert_result.rowcount
    self.db.commit()
    logger.info("Pass 4: %d crosswalk rows inserted", inserted)
    return {"rows_inserted": inserted}
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 5: Deduplication
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass5_deduplication(self) -> dict:
    """Mark duplicate controls (same obligation + same pattern).

    Groups controls by (obligation_id, pattern_id). Within each group the
    first member (highest evidence_confidence, newest on ties) is kept and
    every other member is marked deprecated.

    Returns:
        Stats dict with "groups_found" and "controls_deprecated" counts.
    """
    duplicate_groups = self.db.execute(text("""
        SELECT cc.pattern_id,
               obl.value::text AS obligation_id,
               array_agg(cc.id ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC) AS ids,
               count(*) AS cnt
        FROM canonical_controls cc,
             jsonb_array_elements_text(
                 COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
             ) AS obl(value)
        WHERE cc.release_state NOT IN ('deprecated')
          AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
        GROUP BY cc.pattern_id, obl.value::text
        HAVING count(*) > 1
    """)).fetchall()

    stats = {"groups_found": len(duplicate_groups), "controls_deprecated": 0}

    # Prepared once; executed per duplicate below.
    deprecate_stmt = text("""
        UPDATE canonical_controls
        SET release_state = 'deprecated',
            generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
            )
        WHERE id = CAST(:uuid AS uuid)
          AND release_state != 'deprecated'
    """)

    for group in duplicate_groups:
        member_ids = group[2]  # UUID array, ordered best-first
        if len(member_ids) <= 1:
            continue

        # First id is the keeper; everything after it is a duplicate.
        for duplicate_id in member_ids[1:]:
            self.db.execute(deprecate_stmt, {"uuid": str(duplicate_id)})
            stats["controls_deprecated"] += 1

    self.db.commit()
    logger.info("Pass 5: %s", stats)
    return stats
|
||||
|
||||
def migration_status(self) -> dict:
    """Return overall migration progress.

    Counts controls by linkage completeness and derives coverage
    percentages (one decimal place).
    """
    row = self.db.execute(text("""
        SELECT
            count(*) AS total,
            count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
            count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
            count(*) FILTER (
                WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                AND pattern_id IS NOT NULL AND pattern_id != ''
            ) AS fully_linked,
            count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
        FROM canonical_controls
    """)).fetchone()

    total, has_obligation, has_pattern, fully_linked, deprecated = row
    denominator = max(total, 1)

    def pct(count: int) -> float:
        # Percentage of all controls; denominator clamped to avoid /0.
        return round(count / denominator * 100, 1)

    return {
        "total_controls": total,
        "has_obligation": has_obligation,
        "has_pattern": has_pattern,
        "fully_linked": fully_linked,
        "deprecated": deprecated,
        "coverage_obligation_pct": pct(has_obligation),
        "coverage_pattern_pct": pct(has_pattern),
        "coverage_full_pct": pct(fully_linked),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_regulation_article(
|
||||
citation: Optional[str], metadata: Optional[str]
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Extract regulation_code and article from control's citation/metadata."""
|
||||
from compliance.services.obligation_extractor import _normalize_regulation
|
||||
|
||||
reg_code = None
|
||||
article = None
|
||||
|
||||
# Try citation first (JSON string or dict)
|
||||
if citation:
|
||||
try:
|
||||
c = json.loads(citation) if isinstance(citation, str) else citation
|
||||
if isinstance(c, dict):
|
||||
article = c.get("article") or c.get("source_article")
|
||||
# Try to get regulation from source field
|
||||
source = c.get("source", "")
|
||||
if source:
|
||||
reg_code = _normalize_regulation(source)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Try metadata
|
||||
if metadata and not reg_code:
|
||||
try:
|
||||
m = json.loads(metadata) if isinstance(metadata, str) else metadata
|
||||
if isinstance(m, dict):
|
||||
src_reg = m.get("source_regulation", "")
|
||||
if src_reg:
|
||||
reg_code = _normalize_regulation(src_reg)
|
||||
if not article:
|
||||
article = m.get("source_article")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
return reg_code, article
|
||||
Reference in New Issue
Block a user