From e3ab428b91348215f212787055b351aa3fe673b8 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Thu, 9 Apr 2026 14:40:47 +0200 Subject: [PATCH] feat: control-pipeline Service aus Compliance-Repo migriert Control-Pipeline (Pass 0a/0b, BatchDedup, Generator) als eigenstaendiger Service in Core, damit Compliance-Repo unabhaengig refakturiert werden kann. Schreibt weiterhin ins compliance-Schema der shared PostgreSQL. Co-Authored-By: Claude Opus 4.6 (1M context) --- control-pipeline/Dockerfile | 19 + control-pipeline/api/__init__.py | 8 + .../api/canonical_control_routes.py | 2012 +++++++++ .../api/control_generator_routes.py | 1100 +++++ control-pipeline/config.py | 67 + control-pipeline/data/__init__.py | 0 .../data/source_type_classification.py | 205 + control-pipeline/db/__init__.py | 0 control-pipeline/db/session.py | 37 + control-pipeline/main.py | 88 + control-pipeline/requirements.txt | 22 + control-pipeline/services/__init__.py | 0 control-pipeline/services/anchor_finder.py | 187 + .../services/batch_dedup_runner.py | 618 +++ .../services/citation_backfill.py | 438 ++ control-pipeline/services/control_composer.py | 546 +++ control-pipeline/services/control_dedup.py | 745 ++++ .../services/control_generator.py | 2249 ++++++++++ .../services/control_status_machine.py | 154 + .../services/decomposition_pass.py | 3877 +++++++++++++++++ .../services/framework_decomposition.py | 714 +++ control-pipeline/services/license_gate.py | 116 + control-pipeline/services/llm_provider.py | 624 +++ .../services/normative_patterns.py | 59 + .../services/obligation_extractor.py | 563 +++ control-pipeline/services/pattern_matcher.py | 532 +++ control-pipeline/services/pipeline_adapter.py | 670 +++ control-pipeline/services/rag_client.py | 213 + control-pipeline/services/reranker.py | 85 + .../services/similarity_detector.py | 223 + control-pipeline/services/v1_enrichment.py | 331 ++ control-pipeline/tests/__init__.py | 0 docker-compose.yml | 45 + nginx/conf.d/default.conf | 27 + 
34 files changed, 16574 insertions(+) create mode 100644 control-pipeline/Dockerfile create mode 100644 control-pipeline/api/__init__.py create mode 100644 control-pipeline/api/canonical_control_routes.py create mode 100644 control-pipeline/api/control_generator_routes.py create mode 100644 control-pipeline/config.py create mode 100644 control-pipeline/data/__init__.py create mode 100644 control-pipeline/data/source_type_classification.py create mode 100644 control-pipeline/db/__init__.py create mode 100644 control-pipeline/db/session.py create mode 100644 control-pipeline/main.py create mode 100644 control-pipeline/requirements.txt create mode 100644 control-pipeline/services/__init__.py create mode 100644 control-pipeline/services/anchor_finder.py create mode 100644 control-pipeline/services/batch_dedup_runner.py create mode 100644 control-pipeline/services/citation_backfill.py create mode 100644 control-pipeline/services/control_composer.py create mode 100644 control-pipeline/services/control_dedup.py create mode 100644 control-pipeline/services/control_generator.py create mode 100644 control-pipeline/services/control_status_machine.py create mode 100644 control-pipeline/services/decomposition_pass.py create mode 100644 control-pipeline/services/framework_decomposition.py create mode 100644 control-pipeline/services/license_gate.py create mode 100644 control-pipeline/services/llm_provider.py create mode 100644 control-pipeline/services/normative_patterns.py create mode 100644 control-pipeline/services/obligation_extractor.py create mode 100644 control-pipeline/services/pattern_matcher.py create mode 100644 control-pipeline/services/pipeline_adapter.py create mode 100644 control-pipeline/services/rag_client.py create mode 100644 control-pipeline/services/reranker.py create mode 100644 control-pipeline/services/similarity_detector.py create mode 100644 control-pipeline/services/v1_enrichment.py create mode 100644 control-pipeline/tests/__init__.py diff --git 
a/control-pipeline/Dockerfile b/control-pipeline/Dockerfile new file mode 100644 index 0000000..ee16e97 --- /dev/null +++ b/control-pipeline/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY . . + +EXPOSE 8098 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ + CMD curl -f http://127.0.0.1:8098/health || exit 1 + +CMD ["python", "main.py"] diff --git a/control-pipeline/api/__init__.py b/control-pipeline/api/__init__.py new file mode 100644 index 0000000..decc0ad --- /dev/null +++ b/control-pipeline/api/__init__.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from api.control_generator_routes import router as generator_router +from api.canonical_control_routes import router as canonical_router + +router = APIRouter() +router.include_router(generator_router) +router.include_router(canonical_router) diff --git a/control-pipeline/api/canonical_control_routes.py b/control-pipeline/api/canonical_control_routes.py new file mode 100644 index 0000000..44ec0e5 --- /dev/null +++ b/control-pipeline/api/canonical_control_routes.py @@ -0,0 +1,2012 @@ +""" +FastAPI routes for the Canonical Control Library. + +Independently authored security controls anchored in open-source frameworks +(OWASP, NIST, ENISA). No proprietary nomenclature. 
+ +Endpoints: + GET /v1/canonical/frameworks — All frameworks + GET /v1/canonical/frameworks/{framework_id} — Framework details + GET /v1/canonical/frameworks/{framework_id}/controls — Controls of a framework + GET /v1/canonical/controls — All controls (filterable) + GET /v1/canonical/controls/{control_id} — Single control + GET /v1/canonical/controls/{control_id}/traceability — Traceability chain + GET /v1/canonical/controls/{control_id}/similar — Find similar controls + POST /v1/canonical/controls — Create a control + PUT /v1/canonical/controls/{control_id} — Update a control + DELETE /v1/canonical/controls/{control_id} — Delete a control + GET /v1/canonical/categories — Category list + GET /v1/canonical/sources — Source registry + GET /v1/canonical/licenses — License matrix + POST /v1/canonical/controls/{control_id}/similarity-check — Too-close check +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Optional + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy import text + +from db.session import SessionLocal +from services.license_gate import get_license_matrix, get_source_permissions +from services.similarity_detector import check_similarity + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/v1/canonical", tags=["canonical-controls"]) + + +# ============================================================================= +# RESPONSE MODELS +# ============================================================================= + +class FrameworkResponse(BaseModel): + id: str + framework_id: str + name: str + version: str + description: Optional[str] = None + owner: Optional[str] = None + policy_version: Optional[str] = None + release_state: str + created_at: str + updated_at: str + + +class ControlResponse(BaseModel): + id: str + framework_id: str + control_id: str + title: str + objective: str + rationale: str + scope: dict + requirements: list + 
test_procedure: list + evidence: list + severity: str + risk_score: Optional[float] = None + implementation_effort: Optional[str] = None + evidence_confidence: Optional[float] = None + open_anchors: list + release_state: str + tags: list + license_rule: Optional[int] = None + source_original_text: Optional[str] = None + source_citation: Optional[dict] = None + customer_visible: Optional[bool] = None + verification_method: Optional[str] = None + category: Optional[str] = None + evidence_type: Optional[str] = None + target_audience: Optional[str] = None + generation_metadata: Optional[dict] = None + generation_strategy: Optional[str] = "ungrouped" + applicable_industries: Optional[list] = None + applicable_company_size: Optional[list] = None + scope_conditions: Optional[dict] = None + created_at: str + updated_at: str + + +class ControlCreateRequest(BaseModel): + framework_id: str # e.g. 'bp_security_v1' + control_id: str # e.g. 'AUTH-003' + title: str + objective: str + rationale: str + scope: dict = {} + requirements: list = [] + test_procedure: list = [] + evidence: list = [] + severity: str = "medium" + risk_score: Optional[float] = None + implementation_effort: Optional[str] = None + evidence_confidence: Optional[float] = None + open_anchors: list = [] + release_state: str = "draft" + tags: list = [] + license_rule: Optional[int] = None + source_original_text: Optional[str] = None + source_citation: Optional[dict] = None + customer_visible: Optional[bool] = True + verification_method: Optional[str] = None + category: Optional[str] = None + evidence_type: Optional[str] = None + target_audience: Optional[str] = None + generation_metadata: Optional[dict] = None + applicable_industries: Optional[list] = None + applicable_company_size: Optional[list] = None + scope_conditions: Optional[dict] = None + + +class ControlUpdateRequest(BaseModel): + title: Optional[str] = None + objective: Optional[str] = None + rationale: Optional[str] = None + scope: Optional[dict] = 
None + requirements: Optional[list] = None + test_procedure: Optional[list] = None + evidence: Optional[list] = None + severity: Optional[str] = None + risk_score: Optional[float] = None + implementation_effort: Optional[str] = None + evidence_confidence: Optional[float] = None + open_anchors: Optional[list] = None + release_state: Optional[str] = None + tags: Optional[list] = None + license_rule: Optional[int] = None + source_original_text: Optional[str] = None + source_citation: Optional[dict] = None + customer_visible: Optional[bool] = None + verification_method: Optional[str] = None + category: Optional[str] = None + evidence_type: Optional[str] = None + target_audience: Optional[str] = None + generation_metadata: Optional[dict] = None + applicable_industries: Optional[list] = None + applicable_company_size: Optional[list] = None + scope_conditions: Optional[dict] = None + + +class SimilarityCheckRequest(BaseModel): + source_text: str + candidate_text: str + + +class SimilarityCheckResponse(BaseModel): + max_exact_run: int + token_overlap: float + ngram_jaccard: float + embedding_cosine: float + lcs_ratio: float + status: str + details: dict + + +# ============================================================================= +# HELPERS +# ============================================================================= + +_CONTROL_COLS = """id, framework_id, control_id, title, objective, rationale, + scope, requirements, test_procedure, evidence, + severity, risk_score, implementation_effort, + evidence_confidence, open_anchors, release_state, tags, + license_rule, source_original_text, source_citation, + customer_visible, verification_method, category, evidence_type, + target_audience, generation_metadata, generation_strategy, + applicable_industries, applicable_company_size, scope_conditions, + parent_control_uuid, decomposition_method, pipeline_version, + (SELECT p.control_id FROM canonical_controls p WHERE p.id = canonical_controls.parent_control_uuid) AS 
parent_control_id, + (SELECT p.title FROM canonical_controls p WHERE p.id = canonical_controls.parent_control_uuid) AS parent_control_title, + created_at, updated_at""" + + +def _row_to_dict(row, columns: list[str]) -> dict[str, Any]: + """Generic row → dict converter.""" + return {col: (getattr(row, col).isoformat() if hasattr(getattr(row, col, None), 'isoformat') else getattr(row, col)) for col in columns} + + +# ============================================================================= +# FRAMEWORKS +# ============================================================================= + +@router.get("/frameworks") +async def list_frameworks(): + """List all registered control frameworks.""" + with SessionLocal() as db: + rows = db.execute( + text(""" + SELECT id, framework_id, name, version, description, + owner, policy_version, release_state, + created_at, updated_at + FROM canonical_control_frameworks + ORDER BY name + """) + ).fetchall() + + return [ + { + "id": str(r.id), + "framework_id": r.framework_id, + "name": r.name, + "version": r.version, + "description": r.description, + "owner": r.owner, + "policy_version": r.policy_version, + "release_state": r.release_state, + "created_at": r.created_at.isoformat() if r.created_at else None, + "updated_at": r.updated_at.isoformat() if r.updated_at else None, + } + for r in rows + ] + + +@router.get("/frameworks/{framework_id}") +async def get_framework(framework_id: str): + """Get a single framework by its framework_id.""" + with SessionLocal() as db: + row = db.execute( + text(""" + SELECT id, framework_id, name, version, description, + owner, policy_version, release_state, + created_at, updated_at + FROM canonical_control_frameworks + WHERE framework_id = :fid + """), + {"fid": framework_id}, + ).fetchone() + + if not row: + raise HTTPException(status_code=404, detail="Framework not found") + + return { + "id": str(row.id), + "framework_id": row.framework_id, + "name": row.name, + "version": row.version, + 
"description": row.description, + "owner": row.owner, + "policy_version": row.policy_version, + "release_state": row.release_state, + "created_at": row.created_at.isoformat() if row.created_at else None, + "updated_at": row.updated_at.isoformat() if row.updated_at else None, + } + + +@router.get("/frameworks/{framework_id}/controls") +async def list_framework_controls( + framework_id: str, + severity: Optional[str] = Query(None), + release_state: Optional[str] = Query(None), + verification_method: Optional[str] = Query(None), + category: Optional[str] = Query(None), + target_audience: Optional[str] = Query(None), +): + """List controls belonging to a framework.""" + with SessionLocal() as db: + # Resolve framework UUID + fw = db.execute( + text("SELECT id FROM canonical_control_frameworks WHERE framework_id = :fid"), + {"fid": framework_id}, + ).fetchone() + if not fw: + raise HTTPException(status_code=404, detail="Framework not found") + + query = f""" + SELECT {_CONTROL_COLS} + FROM canonical_controls + WHERE framework_id = :fw_id + """ + params: dict[str, Any] = {"fw_id": str(fw.id)} + + if severity: + query += " AND severity = :sev" + params["sev"] = severity + if release_state: + query += " AND release_state = :rs" + params["rs"] = release_state + if verification_method: + query += " AND verification_method = :vm" + params["vm"] = verification_method + if category: + query += " AND category = :cat" + params["cat"] = category + if target_audience: + query += " AND target_audience::jsonb @> (:ta)::jsonb" + params["ta"] = json.dumps([target_audience]) + + query += " ORDER BY control_id" + rows = db.execute(text(query), params).fetchall() + + return [_control_row(r) for r in rows] + + +# ============================================================================= +# CONTROLS +# ============================================================================= + +@router.get("/controls") +async def list_controls( + severity: Optional[str] = Query(None), + domain: 
Optional[str] = Query(None), + release_state: Optional[str] = Query(None), + verification_method: Optional[str] = Query(None), + category: Optional[str] = Query(None), + evidence_type: Optional[str] = Query(None, description="Filter: code, process, hybrid"), + target_audience: Optional[str] = Query(None), + source: Optional[str] = Query(None, description="Filter by source_citation->source"), + search: Optional[str] = Query(None, description="Full-text search in control_id, title, objective"), + control_type: Optional[str] = Query(None, description="Filter: atomic, rich, or all"), + exclude_duplicates: bool = Query(False, description="Exclude controls with release_state='duplicate'"), + sort: Optional[str] = Query("control_id", description="Sort field: control_id, created_at, severity"), + order: Optional[str] = Query("asc", description="Sort order: asc or desc"), + limit: Optional[int] = Query(None, ge=1, le=5000, description="Max results"), + offset: Optional[int] = Query(None, ge=0, description="Offset for pagination"), +): + """List canonical controls with filters, search, sorting and pagination.""" + query = f""" + SELECT {_CONTROL_COLS} + FROM canonical_controls + WHERE 1=1 + """ + params: dict[str, Any] = {} + + if exclude_duplicates: + query += " AND release_state != 'duplicate'" + + if severity: + query += " AND severity = :sev" + params["sev"] = severity + if domain: + query += " AND LEFT(control_id, LENGTH(:dom)) = :dom" + params["dom"] = domain.upper() + if release_state: + query += " AND release_state = :rs" + params["rs"] = release_state + if verification_method: + if verification_method == "__none__": + query += " AND verification_method IS NULL" + else: + query += " AND verification_method = :vm" + params["vm"] = verification_method + if category: + if category == "__none__": + query += " AND category IS NULL" + else: + query += " AND category = :cat" + params["cat"] = category + if evidence_type: + if evidence_type == "__none__": + query += " AND 
evidence_type IS NULL" + else: + query += " AND evidence_type = :et" + params["et"] = evidence_type + if target_audience: + query += " AND target_audience LIKE :ta_pattern" + params["ta_pattern"] = f'%"{target_audience}"%' + if source: + if source == "__none__": + query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')" + else: + query += " AND source_citation->>'source' = :src" + params["src"] = source + if control_type == "atomic": + query += " AND decomposition_method = 'pass0b'" + elif control_type == "rich": + query += " AND (decomposition_method IS NULL OR decomposition_method != 'pass0b')" + elif control_type == "eigenentwicklung": + query += """ AND generation_strategy = 'ungrouped' + AND (pipeline_version = '1' OR pipeline_version IS NULL) + AND source_citation IS NULL + AND parent_control_uuid IS NULL""" + if search: + query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)" + params["q"] = f"%{search}%" + + # Sorting + sort_col = "control_id" + if sort in ("created_at", "updated_at", "severity", "control_id"): + sort_col = sort + elif sort == "source": + sort_col = "source_citation->>'source'" + sort_dir = "DESC" if order and order.lower() == "desc" else "ASC" + if sort == "source": + # Group by source first, then by control_id within each source + query += f" ORDER BY {sort_col} {sort_dir} NULLS LAST, control_id ASC" + else: + query += f" ORDER BY {sort_col} {sort_dir}" + + if limit is not None: + query += " LIMIT :lim" + params["lim"] = limit + if offset is not None: + query += " OFFSET :off" + params["off"] = offset + + with SessionLocal() as db: + rows = db.execute(text(query), params).fetchall() + + return [_control_row(r) for r in rows] + + +@router.get("/controls-count") +async def count_controls( + severity: Optional[str] = Query(None), + domain: Optional[str] = Query(None), + release_state: Optional[str] = Query(None), + verification_method: Optional[str] = 
Query(None), + category: Optional[str] = Query(None), + evidence_type: Optional[str] = Query(None), + target_audience: Optional[str] = Query(None), + source: Optional[str] = Query(None), + search: Optional[str] = Query(None), + control_type: Optional[str] = Query(None), + exclude_duplicates: bool = Query(False, description="Exclude controls with release_state='duplicate'"), +): + """Count controls matching filters (for pagination).""" + query = "SELECT count(*) FROM canonical_controls WHERE 1=1" + params: dict[str, Any] = {} + + if exclude_duplicates: + query += " AND release_state != 'duplicate'" + + if severity: + query += " AND severity = :sev" + params["sev"] = severity + if domain: + query += " AND LEFT(control_id, LENGTH(:dom)) = :dom" + params["dom"] = domain.upper() + if release_state: + query += " AND release_state = :rs" + params["rs"] = release_state + if verification_method: + if verification_method == "__none__": + query += " AND verification_method IS NULL" + else: + query += " AND verification_method = :vm" + params["vm"] = verification_method + if category: + if category == "__none__": + query += " AND category IS NULL" + else: + query += " AND category = :cat" + params["cat"] = category + if evidence_type: + if evidence_type == "__none__": + query += " AND evidence_type IS NULL" + else: + query += " AND evidence_type = :et" + params["et"] = evidence_type + if target_audience: + query += " AND target_audience LIKE :ta_pattern" + params["ta_pattern"] = f'%"{target_audience}"%' + if source: + if source == "__none__": + query += " AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')" + else: + query += " AND source_citation->>'source' = :src" + params["src"] = source + if control_type == "atomic": + query += " AND decomposition_method = 'pass0b'" + elif control_type == "rich": + query += " AND (decomposition_method IS NULL OR decomposition_method != 'pass0b')" + elif control_type == "eigenentwicklung": 
+ query += """ AND generation_strategy = 'ungrouped' + AND (pipeline_version = '1' OR pipeline_version IS NULL) + AND source_citation IS NULL + AND parent_control_uuid IS NULL""" + if search: + query += " AND (control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)" + params["q"] = f"%{search}%" + + with SessionLocal() as db: + total = db.execute(text(query), params).scalar() + + return {"total": total} + + +@router.get("/controls-meta") +async def controls_meta( + severity: Optional[str] = Query(None), + domain: Optional[str] = Query(None), + release_state: Optional[str] = Query(None), + verification_method: Optional[str] = Query(None), + category: Optional[str] = Query(None), + evidence_type: Optional[str] = Query(None), + target_audience: Optional[str] = Query(None), + source: Optional[str] = Query(None), + search: Optional[str] = Query(None), + control_type: Optional[str] = Query(None), + exclude_duplicates: bool = Query(False), +): + """Return faceted metadata for filter dropdowns. + + Each facet's counts respect ALL active filters EXCEPT the facet's own, + so dropdowns always show how many items each option would yield. 
+ """ + + def _build_where(skip: Optional[str] = None) -> tuple[str, dict[str, Any]]: + clauses = ["1=1"] + p: dict[str, Any] = {} + + if exclude_duplicates: + clauses.append("release_state != 'duplicate'") + if severity and skip != "severity": + clauses.append("severity = :sev") + p["sev"] = severity + if domain and skip != "domain": + clauses.append("LEFT(control_id, LENGTH(:dom)) = :dom") + p["dom"] = domain.upper() + if release_state and skip != "release_state": + clauses.append("release_state = :rs") + p["rs"] = release_state + if verification_method and skip != "verification_method": + if verification_method == "__none__": + clauses.append("verification_method IS NULL") + else: + clauses.append("verification_method = :vm") + p["vm"] = verification_method + if category and skip != "category": + if category == "__none__": + clauses.append("category IS NULL") + else: + clauses.append("category = :cat") + p["cat"] = category + if evidence_type and skip != "evidence_type": + if evidence_type == "__none__": + clauses.append("evidence_type IS NULL") + else: + clauses.append("evidence_type = :et") + p["et"] = evidence_type + if target_audience and skip != "target_audience": + clauses.append("target_audience LIKE :ta_pattern") + p["ta_pattern"] = f'%"{target_audience}"%' + if source and skip != "source": + if source == "__none__": + clauses.append("(source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '')") + else: + clauses.append("source_citation->>'source' = :src") + p["src"] = source + if control_type and skip != "control_type": + if control_type == "atomic": + clauses.append("decomposition_method = 'pass0b'") + elif control_type == "rich": + clauses.append("(decomposition_method IS NULL OR decomposition_method != 'pass0b')") + elif control_type == "eigenentwicklung": + clauses.append("""generation_strategy = 'ungrouped' + AND (pipeline_version = '1' OR pipeline_version IS NULL) + AND source_citation IS NULL + AND 
parent_control_uuid IS NULL""") + if search and skip != "search": + clauses.append("(control_id ILIKE :q OR title ILIKE :q OR objective ILIKE :q)") + p["q"] = f"%{search}%" + + return " AND ".join(clauses), p + + with SessionLocal() as db: + # Total with ALL filters + w_all, p_all = _build_where() + total = db.execute(text(f"SELECT count(*) FROM canonical_controls WHERE {w_all}"), p_all).scalar() + + # Domain facet (skip domain filter so user sees all domains) + w_dom, p_dom = _build_where(skip="domain") + domains = db.execute(text(f""" + SELECT UPPER(SPLIT_PART(control_id, '-', 1)) as domain, count(*) as cnt + FROM canonical_controls WHERE {w_dom} + GROUP BY domain ORDER BY domain + """), p_dom).fetchall() + + # Source facet (skip source filter) + w_src, p_src = _build_where(skip="source") + sources = db.execute(text(f""" + SELECT source_citation->>'source' as src, count(*) as cnt + FROM canonical_controls + WHERE {w_src} + AND source_citation->>'source' IS NOT NULL AND source_citation->>'source' != '' + GROUP BY src ORDER BY cnt DESC + """), p_src).fetchall() + + no_source = db.execute(text(f""" + SELECT count(*) FROM canonical_controls + WHERE {w_src} + AND (source_citation IS NULL OR source_citation->>'source' IS NULL OR source_citation->>'source' = '') + """), p_src).scalar() + + # Type facet (skip control_type filter) + w_typ, p_typ = _build_where(skip="control_type") + atomic_count = db.execute(text(f""" + SELECT count(*) FROM canonical_controls + WHERE {w_typ} AND decomposition_method = 'pass0b' + """), p_typ).scalar() or 0 + + eigenentwicklung_count = db.execute(text(f""" + SELECT count(*) FROM canonical_controls + WHERE {w_typ} + AND generation_strategy = 'ungrouped' + AND (pipeline_version = '1' OR pipeline_version IS NULL) + AND source_citation IS NULL + AND parent_control_uuid IS NULL + """), p_typ).scalar() or 0 + + rich_count = db.execute(text(f""" + SELECT count(*) FROM canonical_controls + WHERE {w_typ} + AND (decomposition_method IS NULL OR 
decomposition_method != 'pass0b') + """), p_typ).scalar() or 0 + + # Severity facet (skip severity filter) + w_sev, p_sev = _build_where(skip="severity") + severity_counts = db.execute(text(f""" + SELECT severity, count(*) as cnt + FROM canonical_controls WHERE {w_sev} + GROUP BY severity ORDER BY severity + """), p_sev).fetchall() + + # Verification method facet (include NULLs as __none__) + w_vm, p_vm = _build_where(skip="verification_method") + vm_counts = db.execute(text(f""" + SELECT COALESCE(verification_method, '__none__') as vm, count(*) as cnt + FROM canonical_controls WHERE {w_vm} + GROUP BY vm ORDER BY vm + """), p_vm).fetchall() + + # Category facet (include NULLs as __none__) + w_cat, p_cat = _build_where(skip="category") + cat_counts = db.execute(text(f""" + SELECT COALESCE(category, '__none__') as cat, count(*) as cnt + FROM canonical_controls WHERE {w_cat} + GROUP BY cat ORDER BY cnt DESC + """), p_cat).fetchall() + + # Evidence type facet (include NULLs as __none__) + w_et, p_et = _build_where(skip="evidence_type") + et_counts = db.execute(text(f""" + SELECT COALESCE(evidence_type, '__none__') as et, count(*) as cnt + FROM canonical_controls WHERE {w_et} + GROUP BY et ORDER BY et + """), p_et).fetchall() + + # Release state facet + w_rs, p_rs = _build_where(skip="release_state") + rs_counts = db.execute(text(f""" + SELECT release_state, count(*) as cnt + FROM canonical_controls WHERE {w_rs} + GROUP BY release_state ORDER BY release_state + """), p_rs).fetchall() + + return { + "total": total, + "domains": [{"domain": r[0], "count": r[1]} for r in domains], + "sources": [{"source": r[0], "count": r[1]} for r in sources], + "no_source_count": no_source, + "type_counts": { + "rich": rich_count, + "atomic": atomic_count, + "eigenentwicklung": eigenentwicklung_count, + }, + "severity_counts": {r[0]: r[1] for r in severity_counts}, + "verification_method_counts": {r[0]: r[1] for r in vm_counts}, + "category_counts": {r[0]: r[1] for r in cat_counts}, + 
"evidence_type_counts": {r[0]: r[1] for r in et_counts}, + "release_state_counts": {r[0]: r[1] for r in rs_counts}, + } + + +@router.get("/controls/atomic-stats") +async def atomic_stats(): + """Return aggregated statistics for atomic controls (masters only).""" + with SessionLocal() as db: + total_active = db.execute(text(""" + SELECT count(*) FROM canonical_controls + WHERE decomposition_method = 'pass0b' + AND release_state NOT IN ('duplicate', 'deprecated', 'rejected') + """)).scalar() or 0 + + total_duplicate = db.execute(text(""" + SELECT count(*) FROM canonical_controls + WHERE decomposition_method = 'pass0b' + AND release_state = 'duplicate' + """)).scalar() or 0 + + by_domain = db.execute(text(""" + SELECT UPPER(SPLIT_PART(control_id, '-', 1)) AS domain, count(*) AS cnt + FROM canonical_controls + WHERE decomposition_method = 'pass0b' + AND release_state NOT IN ('duplicate', 'deprecated', 'rejected') + GROUP BY domain ORDER BY cnt DESC + """)).fetchall() + + by_regulation = db.execute(text(""" + SELECT cpl.source_regulation AS regulation, count(DISTINCT cc.id) AS cnt + FROM canonical_controls cc + JOIN control_parent_links cpl ON cpl.control_uuid = cc.id + WHERE cc.decomposition_method = 'pass0b' + AND cc.release_state NOT IN ('duplicate', 'deprecated', 'rejected') + AND cpl.source_regulation IS NOT NULL + GROUP BY cpl.source_regulation ORDER BY cnt DESC + """)).fetchall() + + avg_coverage = db.execute(text(""" + SELECT COALESCE(AVG(reg_count), 0) + FROM ( + SELECT cc.id, count(DISTINCT cpl.source_regulation) AS reg_count + FROM canonical_controls cc + LEFT JOIN control_parent_links cpl ON cpl.control_uuid = cc.id + WHERE cc.decomposition_method = 'pass0b' + AND cc.release_state NOT IN ('duplicate', 'deprecated', 'rejected') + GROUP BY cc.id + ) sub + """)).scalar() or 0 + + return { + "total_active": total_active, + "total_duplicate": total_duplicate, + "by_domain": [{"domain": r[0], "count": r[1]} for r in by_domain], + "by_regulation": [{"regulation": 
r[0], "count": r[1]} for r in by_regulation], + "avg_regulation_coverage": round(float(avg_coverage), 1), + } + + +@router.get("/controls/v1-enrichment-stats") +async def v1_enrichment_stats_endpoint(): + """ + Uebersicht: Wie viele v1 Controls haben regulatorische Abdeckung? + """ + from services.v1_enrichment import get_v1_enrichment_stats + return await get_v1_enrichment_stats() + + +@router.get("/controls/{control_id}") +async def get_control(control_id: str): + """Get a single canonical control by its control_id (e.g. AUTH-001).""" + with SessionLocal() as db: + row = db.execute( + text(f""" + SELECT {_CONTROL_COLS} + FROM canonical_controls + WHERE control_id = :cid + """), + {"cid": control_id.upper()}, + ).fetchone() + + if not row: + raise HTTPException(status_code=404, detail="Control not found") + + return _control_row(row) + + +@router.get("/controls/{control_id}/traceability") +async def get_control_traceability(control_id: str): + """Get the full traceability chain for a control. + + For atomic controls: shows all parent links with source regulations, + articles, and the obligation chain. + For rich controls: shows child atomic controls derived from them. 
+ """ + with SessionLocal() as db: + # Get control UUID + ctrl = db.execute( + text(""" + SELECT id, control_id, title, parent_control_uuid, + decomposition_method, source_citation + FROM canonical_controls WHERE control_id = :cid + """), + {"cid": control_id.upper()}, + ).fetchone() + + if not ctrl: + raise HTTPException(status_code=404, detail="Control not found") + + result: dict[str, Any] = { + "control_id": ctrl.control_id, + "title": ctrl.title, + "is_atomic": ctrl.decomposition_method == "pass0b", + } + + ctrl_uuid = str(ctrl.id) + + # Parent links (M:N) — for atomic controls + parent_links = db.execute( + text(""" + SELECT cpl.parent_control_uuid, cpl.link_type, + cpl.confidence, cpl.source_regulation, + cpl.source_article, cpl.obligation_candidate_id, + cc.control_id AS parent_control_id, + cc.title AS parent_title, + cc.source_citation AS parent_citation, + oc.obligation_text, oc.action, oc.object, + oc.normative_strength + FROM control_parent_links cpl + JOIN canonical_controls cc ON cc.id = cpl.parent_control_uuid + LEFT JOIN obligation_candidates oc ON oc.id = cpl.obligation_candidate_id + WHERE cpl.control_uuid = CAST(:uid AS uuid) + ORDER BY cpl.source_regulation, cpl.source_article + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["parent_links"] = [ + { + "parent_control_id": pl.parent_control_id, + "parent_title": pl.parent_title, + "link_type": pl.link_type, + "confidence": float(pl.confidence) if pl.confidence else 1.0, + "source_regulation": pl.source_regulation, + "source_article": pl.source_article, + "parent_citation": pl.parent_citation, + "obligation": { + "text": pl.obligation_text, + "action": pl.action, + "object": pl.object, + "normative_strength": pl.normative_strength, + } if pl.obligation_text else None, + } + for pl in parent_links + ] + + # Also include the 1:1 parent (backwards compat) if not already in links + if ctrl.parent_control_uuid: + parent_uuids_in_links = { + str(pl.parent_control_uuid) for pl in parent_links + } 
+ parent_uuid_str = str(ctrl.parent_control_uuid) + if parent_uuid_str not in parent_uuids_in_links: + legacy = db.execute( + text(""" + SELECT control_id, title, source_citation + FROM canonical_controls WHERE id = CAST(:uid AS uuid) + """), + {"uid": parent_uuid_str}, + ).fetchone() + if legacy: + result["parent_links"].insert(0, { + "parent_control_id": legacy.control_id, + "parent_title": legacy.title, + "link_type": "decomposition", + "confidence": 1.0, + "source_regulation": None, + "source_article": None, + "parent_citation": legacy.source_citation, + "obligation": None, + }) + + # Child controls — for rich controls + children = db.execute( + text(""" + SELECT control_id, title, category, severity, + decomposition_method + FROM canonical_controls + WHERE parent_control_uuid = CAST(:uid AS uuid) + ORDER BY control_id + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["children"] = [ + { + "control_id": ch.control_id, + "title": ch.title, + "category": ch.category, + "severity": ch.severity, + "decomposition_method": ch.decomposition_method, + } + for ch in children + ] + + # Unique source regulations count + regs = set() + for pl in result["parent_links"]: + if pl.get("source_regulation"): + regs.add(pl["source_regulation"]) + result["source_count"] = len(regs) + + return result + + +@router.get("/controls/{control_id}/provenance") +async def get_control_provenance(control_id: str): + """Get full provenance chain for a control — extends traceability with + obligations, document references, merged duplicates, and regulations summary. 
+ """ + with SessionLocal() as db: + ctrl = db.execute( + text(""" + SELECT id, control_id, title, parent_control_uuid, + decomposition_method, source_citation + FROM canonical_controls WHERE control_id = :cid + """), + {"cid": control_id.upper()}, + ).fetchone() + + if not ctrl: + raise HTTPException(status_code=404, detail="Control not found") + + ctrl_uuid = str(ctrl.id) + is_atomic = ctrl.decomposition_method == "pass0b" + + result: dict[str, Any] = { + "control_id": ctrl.control_id, + "title": ctrl.title, + "is_atomic": is_atomic, + } + + # --- Parent links (same as traceability) --- + parent_links = db.execute( + text(""" + SELECT cpl.parent_control_uuid, cpl.link_type, + cpl.confidence, cpl.source_regulation, + cpl.source_article, cpl.obligation_candidate_id, + cc.control_id AS parent_control_id, + cc.title AS parent_title, + cc.source_citation AS parent_citation, + oc.obligation_text, oc.action, oc.object, + oc.normative_strength + FROM control_parent_links cpl + JOIN canonical_controls cc ON cc.id = cpl.parent_control_uuid + LEFT JOIN obligation_candidates oc ON oc.id = cpl.obligation_candidate_id + WHERE cpl.control_uuid = CAST(:uid AS uuid) + ORDER BY cpl.source_regulation, cpl.source_article + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["parent_links"] = [ + { + "parent_control_id": pl.parent_control_id, + "parent_title": pl.parent_title, + "link_type": pl.link_type, + "confidence": float(pl.confidence) if pl.confidence else 1.0, + "source_regulation": pl.source_regulation, + "source_article": pl.source_article, + "parent_citation": pl.parent_citation, + "obligation": { + "text": pl.obligation_text, + "action": pl.action, + "object": pl.object, + "normative_strength": pl.normative_strength, + } if pl.obligation_text else None, + } + for pl in parent_links + ] + + # Legacy 1:1 parent (backwards compat) + if ctrl.parent_control_uuid: + parent_uuids_in_links = { + str(pl.parent_control_uuid) for pl in parent_links + } + parent_uuid_str = 
str(ctrl.parent_control_uuid) + if parent_uuid_str not in parent_uuids_in_links: + legacy = db.execute( + text(""" + SELECT control_id, title, source_citation + FROM canonical_controls WHERE id = CAST(:uid AS uuid) + """), + {"uid": parent_uuid_str}, + ).fetchone() + if legacy: + result["parent_links"].insert(0, { + "parent_control_id": legacy.control_id, + "parent_title": legacy.title, + "link_type": "decomposition", + "confidence": 1.0, + "source_regulation": None, + "source_article": None, + "parent_citation": legacy.source_citation, + "obligation": None, + }) + + # --- Children --- + children = db.execute( + text(""" + SELECT control_id, title, category, severity, + decomposition_method + FROM canonical_controls + WHERE parent_control_uuid = CAST(:uid AS uuid) + ORDER BY control_id + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["children"] = [ + { + "control_id": ch.control_id, + "title": ch.title, + "category": ch.category, + "severity": ch.severity, + "decomposition_method": ch.decomposition_method, + } + for ch in children + ] + + # Source count + regs = set() + for pl in result["parent_links"]: + if pl.get("source_regulation"): + regs.add(pl["source_regulation"]) + result["source_count"] = len(regs) + + # --- Obligations (for Rich Controls) --- + obligations = db.execute( + text(""" + SELECT candidate_id, obligation_text, action, object, + normative_strength, release_state + FROM obligation_candidates + WHERE parent_control_uuid = CAST(:uid AS uuid) + AND release_state NOT IN ('rejected', 'merged', 'duplicate') + ORDER BY candidate_id + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["obligations"] = [ + { + "candidate_id": ob.candidate_id, + "obligation_text": ob.obligation_text, + "action": ob.action, + "object": ob.object, + "normative_strength": ob.normative_strength, + "release_state": ob.release_state, + } + for ob in obligations + ] + result["obligation_count"] = len(obligations) + + # --- Document References --- + doc_refs = 
db.execute( + text(""" + SELECT DISTINCT oe.regulation_code, oe.article, oe.paragraph, + oe.extraction_method, oe.confidence + FROM obligation_extractions oe + WHERE oe.control_uuid = CAST(:uid AS uuid) + OR oe.obligation_id IN ( + SELECT oc.candidate_id FROM obligation_candidates oc + JOIN control_parent_links cpl ON cpl.obligation_candidate_id = oc.id + WHERE cpl.control_uuid = CAST(:uid AS uuid) + ) + ORDER BY oe.regulation_code, oe.article + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["document_references"] = [ + { + "regulation_code": dr.regulation_code, + "article": dr.article, + "paragraph": dr.paragraph, + "extraction_method": dr.extraction_method, + "confidence": float(dr.confidence) if dr.confidence else None, + } + for dr in doc_refs + ] + + # --- Merged Duplicates --- + merged = db.execute( + text(""" + SELECT cc.control_id, cc.title, + (SELECT cpl.source_regulation FROM control_parent_links cpl + WHERE cpl.control_uuid = cc.id LIMIT 1) AS source_regulation + FROM canonical_controls cc + WHERE cc.merged_into_uuid = CAST(:uid AS uuid) + AND cc.release_state = 'duplicate' + ORDER BY cc.control_id + """), + {"uid": ctrl_uuid}, + ).fetchall() + + result["merged_duplicates"] = [ + { + "control_id": m.control_id, + "title": m.title, + "source_regulation": m.source_regulation, + } + for m in merged + ] + result["merged_duplicates_count"] = len(merged) + + # --- Regulations Summary (aggregated from parent_links + doc_refs) --- + reg_map: dict[str, dict[str, Any]] = {} + for pl in result["parent_links"]: + reg = pl.get("source_regulation") + if not reg: + continue + if reg not in reg_map: + reg_map[reg] = {"articles": set(), "link_types": set()} + if pl.get("source_article"): + reg_map[reg]["articles"].add(pl["source_article"]) + reg_map[reg]["link_types"].add(pl.get("link_type", "decomposition")) + + for dr in result["document_references"]: + reg = dr.get("regulation_code") + if not reg: + continue + if reg not in reg_map: + reg_map[reg] = 
{"articles": set(), "link_types": set()} + if dr.get("article"): + reg_map[reg]["articles"].add(dr["article"]) + + result["regulations_summary"] = [ + { + "regulation_code": reg, + "articles": sorted(info["articles"]), + "link_types": sorted(info["link_types"]), + } + for reg, info in sorted(reg_map.items()) + ] + + return result + + +# ============================================================================= +# NORMATIVE STRENGTH BACKFILL +# ============================================================================= + +@router.post("/controls/backfill-normative-strength") +async def backfill_normative_strength( + dry_run: bool = Query(True, description="Nur zaehlen, nicht aendern"), +): + """ + Korrigiert normative_strength auf obligation_candidates basierend auf + dem source_type der Quell-Regulierung. + + Dreistufiges Modell: + - law (Gesetz): normative_strength bleibt unveraendert + - guideline (Leitlinie): max 'should' + - framework (Framework): max 'can' + + Fuer Controls mit mehreren Parent-Links gilt der hoechste source_type. + """ + from data.source_type_classification import ( + classify_source_regulation, + cap_normative_strength, + ) + + with SessionLocal() as db: + # 1. Alle Obligations mit source_citation des Parent Controls laden + obligations = db.execute(text(""" + SELECT oc.id, oc.candidate_id, oc.normative_strength, + cc.source_citation->>'source' AS parent_source + FROM obligation_candidates oc + JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid + WHERE oc.release_state NOT IN ('rejected', 'merged', 'duplicate') + AND oc.normative_strength IS NOT NULL + ORDER BY oc.candidate_id + """)).fetchall() + + # 2. 
Normative strength korrigieren basierend auf source_type
+        changes = []
+        stats = {"total": len(obligations), "unchanged": 0, "capped_to_should": 0, "capped_to_may": 0, "no_source": 0}
+
+        for obl in obligations:
+            if not obl.parent_source:
+                stats["no_source"] += 1
+                continue
+
+            source_type = classify_source_regulation(obl.parent_source)
+            new_strength = cap_normative_strength(obl.normative_strength, source_type)
+
+            if new_strength != obl.normative_strength:
+                changes.append({
+                    "id": str(obl.id),
+                    "candidate_id": obl.candidate_id,
+                    "old_strength": obl.normative_strength,
+                    "new_strength": new_strength,
+                    "source_type": source_type,
+                    "source_regulation": obl.parent_source,
+                })
+                if new_strength == "should":
+                    stats["capped_to_should"] += 1
+                elif new_strength == "may":
+                    stats["capped_to_may"] += 1
+            else:
+                stats["unchanged"] += 1
+
+        # 3. Aenderungen anwenden (wenn kein dry_run)
+        if not dry_run and changes:
+            for change in changes:
+                db.execute(text("""
+                    UPDATE obligation_candidates
+                    SET normative_strength = :new_strength
+                    WHERE id = CAST(:oid AS uuid)
+                """), {"new_strength": change["new_strength"], "oid": change["id"]})
+            db.commit()
+
+        return {
+            "dry_run": dry_run,
+            "stats": stats,
+            "total_changes": len(changes),
+            "sample_changes": changes[:20],
+        }
+
+
+# =============================================================================
+# OBLIGATION DEDUPLICATION
+# =============================================================================
+
+@router.post("/obligations/dedup")
+async def dedup_obligations(
+    dry_run: bool = Query(True, description="Nur zaehlen, nicht aendern"),
+    batch_size: int = Query(0, description="0 = alle auf einmal"),
+    offset: int = Query(0, description="Offset fuer Batch-Verarbeitung"),
+):
+    """
+    Markiert doppelte obligation_candidates als 'duplicate'.
+
+    Duplikate = mehrere Eintraege mit gleichem candidate_id.
+ Pro candidate_id wird der aelteste Eintrag (MIN(created_at)) behalten, + alle anderen erhalten release_state='duplicate' und merged_into_id + zeigt auf den behaltenen Eintrag. + """ + with SessionLocal() as db: + # 1. Finde alle candidate_ids mit mehr als einem Eintrag + # (nur noch nicht-deduplizierte beruecksichtigen) + dup_query = """ + SELECT candidate_id, count(*) as cnt + FROM obligation_candidates + WHERE release_state NOT IN ('rejected', 'merged', 'duplicate') + GROUP BY candidate_id + HAVING count(*) > 1 + ORDER BY candidate_id + """ + if batch_size > 0: + dup_query += f" LIMIT {batch_size} OFFSET {offset}" + + dup_groups = db.execute(text(dup_query)).fetchall() + + total_groups = db.execute(text(""" + SELECT count(*) FROM ( + SELECT candidate_id + FROM obligation_candidates + WHERE release_state NOT IN ('rejected', 'merged', 'duplicate') + GROUP BY candidate_id + HAVING count(*) > 1 + ) sub + """)).scalar() + + # 2. Pro Gruppe: aeltesten behalten, Rest als duplicate markieren + kept_count = 0 + duplicate_count = 0 + sample_changes: list[dict[str, Any]] = [] + + for grp in dup_groups: + cid = grp.candidate_id + + # Alle Eintraege fuer dieses candidate_id holen + entries = db.execute(text(""" + SELECT id, candidate_id, obligation_text, release_state, created_at + FROM obligation_candidates + WHERE candidate_id = :cid + AND release_state NOT IN ('rejected', 'merged', 'duplicate') + ORDER BY created_at ASC, id ASC + """), {"cid": cid}).fetchall() + + if len(entries) < 2: + continue + + keeper = entries[0] # aeltester Eintrag + duplicates = entries[1:] + kept_count += 1 + duplicate_count += len(duplicates) + + if len(sample_changes) < 20: + sample_changes.append({ + "candidate_id": cid, + "kept_id": str(keeper.id), + "kept_text": keeper.obligation_text[:100], + "duplicate_count": len(duplicates), + "duplicate_ids": [str(d.id) for d in duplicates], + }) + + if not dry_run: + for dup in duplicates: + db.execute(text(""" + UPDATE obligation_candidates + SET 
release_state = 'duplicate', + merged_into_id = CAST(:keeper_id AS uuid), + quality_flags = COALESCE(quality_flags, '{}'::jsonb) + || jsonb_build_object( + 'dedup_reason', 'duplicate of ' || :keeper_cid, + 'dedup_kept_id', :keeper_id_str, + 'dedup_at', NOW()::text + ) + WHERE id = CAST(:dup_id AS uuid) + """), { + "keeper_id": str(keeper.id), + "keeper_cid": cid, + "keeper_id_str": str(keeper.id), + "dup_id": str(dup.id), + }) + + if not dry_run and duplicate_count > 0: + db.commit() + + return { + "dry_run": dry_run, + "stats": { + "total_duplicate_groups": total_groups, + "processed_groups": len(dup_groups), + "kept": kept_count, + "marked_duplicate": duplicate_count, + }, + "sample_changes": sample_changes, + } + + +@router.get("/obligations/dedup-stats") +async def dedup_obligations_stats(): + """Statistiken ueber den aktuellen Dedup-Status der Obligations.""" + with SessionLocal() as db: + total = db.execute(text( + "SELECT count(*) FROM obligation_candidates" + )).scalar() + + by_state = db.execute(text(""" + SELECT release_state, count(*) as cnt + FROM obligation_candidates + GROUP BY release_state + ORDER BY release_state + """)).fetchall() + + dup_groups = db.execute(text(""" + SELECT count(*) FROM ( + SELECT candidate_id + FROM obligation_candidates + WHERE release_state NOT IN ('rejected', 'merged', 'duplicate') + GROUP BY candidate_id + HAVING count(*) > 1 + ) sub + """)).scalar() + + removable = db.execute(text(""" + SELECT COALESCE(sum(cnt - 1), 0) FROM ( + SELECT candidate_id, count(*) as cnt + FROM obligation_candidates + WHERE release_state NOT IN ('rejected', 'merged', 'duplicate') + GROUP BY candidate_id + HAVING count(*) > 1 + ) sub + """)).scalar() + + return { + "total_obligations": total, + "by_state": {r.release_state: r.cnt for r in by_state}, + "pending_duplicate_groups": dup_groups, + "pending_removable_duplicates": removable, + } + + +# ============================================================================= +# EVIDENCE TYPE 
BACKFILL +# ============================================================================= + +# Domains that are primarily technical (code-verifiable) +_CODE_DOMAINS = frozenset({ + "SEC", "AUTH", "CRYPT", "CRYP", "CRY", "NET", "LOG", "ACC", "APP", "SYS", + "CI", "CONT", "API", "CLOUD", "IAC", "SAST", "DAST", "DEP", "SBOM", + "WEB", "DEV", "SDL", "PKI", "HSM", "TEE", "TPM", "CRX", "CRF", + "FWU", "STO", "RUN", "VUL", "MAL", "PLT", "AUT", +}) + +# Domains that are primarily process-based (document-verifiable) +_PROCESS_DOMAINS = frozenset({ + "GOV", "ORG", "COMP", "LEGAL", "HR", "TRAIN", "AML", "FIN", + "RISK", "AUDIT", "AUD", "PROC", "DOC", "PHYS", "PHY", "PRIV", "DPO", + "BCDR", "BCP", "VENDOR", "SUPPLY", "SUP", "CERT", "POLICY", + "ENV", "HLT", "TRD", "LAB", "PER", "REL", "ISM", "COM", + "GAM", "RIS", "PCA", "GNT", "HCA", "RES", "ISS", +}) + +# Domains that are typically hybrid +_HYBRID_DOMAINS = frozenset({ + "DATA", "AI", "INC", "ID", "IAM", "IDF", "IDP", "IDA", "IDN", + "OPS", "MNT", "INT", "BCK", +}) + + +def _classify_evidence_type(control_id: str, category: str | None) -> str: + """Heuristic: classify a control as code/process/hybrid based on domain prefix.""" + domain = control_id.split("-")[0].upper() if control_id else "" + + if domain in _CODE_DOMAINS: + return "code" + if domain in _PROCESS_DOMAINS: + return "process" + if domain in _HYBRID_DOMAINS: + return "hybrid" + + # Fallback: use category if available + code_categories = {"encryption", "authentication", "network", "application", "system", "identity"} + process_categories = {"compliance", "personnel", "physical", "governance", "risk"} + if category in code_categories: + return "code" + if category in process_categories: + return "process" + + return "process" # Conservative default + + +@router.post("/controls/backfill-evidence-type") +async def backfill_evidence_type( + dry_run: bool = Query(True, description="Nur zaehlen, nicht aendern"), +): + """ + Klassifiziert Controls als 
code/process/hybrid basierend auf Domain-Prefix. + + Heuristik: + - SEC, AUTH, CRYPT, NET, LOG, ... → code + - GOV, ORG, COMP, LEGAL, HR, ... → process + - DATA, AI, INC → hybrid + """ + with SessionLocal() as db: + rows = db.execute(text(""" + SELECT id, control_id, category, evidence_type + FROM canonical_controls + WHERE release_state NOT IN ('rejected', 'merged') + ORDER BY control_id + """)).fetchall() + + changes = [] + stats = {"total": len(rows), "already_set": 0, "code": 0, "process": 0, "hybrid": 0} + + for row in rows: + if row.evidence_type is not None: + stats["already_set"] += 1 + continue + + new_type = _classify_evidence_type(row.control_id, row.category) + stats[new_type] += 1 + changes.append({ + "id": str(row.id), + "control_id": row.control_id, + "evidence_type": new_type, + }) + + if not dry_run and changes: + for change in changes: + db.execute(text(""" + UPDATE canonical_controls + SET evidence_type = :et + WHERE id = CAST(:cid AS uuid) + """), {"et": change["evidence_type"], "cid": change["id"]}) + db.commit() + + return { + "dry_run": dry_run, + "stats": stats, + "total_changes": len(changes), + "sample_changes": changes[:20], + } + + +# ============================================================================= +# RATIONALE BACKFILL (LLM) +# ============================================================================= + +@router.post("/controls/backfill-rationale") +async def backfill_rationale( + dry_run: bool = Query(True, description="Nur zaehlen, nicht aendern"), + batch_size: int = Query(50, description="Parent-Controls pro Durchlauf"), + offset: int = Query(0, description="Offset fuer Paginierung (Parent-Index)"), +): + """ + Generiert sinnvolle Begruendungen fuer atomare Controls per LLM. + + Optimierung: Gruppiert nach Parent-Control (~7k Parents statt ~86k Einzel-Calls). + Pro Parent-Gruppe wird EIN LLM-Aufruf gemacht, der eine gemeinsame + Begruendung fuer alle Kinder erzeugt. + + Workflow: + 1. 
dry_run=true → Statistiken anzeigen + 2. dry_run=false&batch_size=50&offset=0 → Erste 50 Parents verarbeiten + 3. Wiederholen mit offset=50, 100, ... bis fertig + """ + from services.llm_provider import get_llm_provider + + with SessionLocal() as db: + # 1. Parent-Controls mit Kindern laden (nur wo rationale = Placeholder) + parents = db.execute(text(""" + SELECT p.id AS parent_uuid, p.control_id, p.title, p.category, + p.source_citation->>'source' AS source_name, + COUNT(c.id) AS child_count + FROM canonical_controls p + JOIN canonical_controls c ON c.parent_control_uuid = p.id + WHERE c.rationale = 'Aus Obligation abgeleitet.' + AND c.release_state NOT IN ('rejected', 'merged') + GROUP BY p.id, p.control_id, p.title, p.category, + p.source_citation->>'source' + ORDER BY p.control_id + """)).fetchall() + + total_parents = len(parents) + total_children = sum(p.child_count for p in parents) + + if dry_run: + return { + "dry_run": True, + "total_parents": total_parents, + "total_children": total_children, + "estimated_llm_calls": total_parents, + "sample_parents": [ + { + "control_id": p.control_id, + "title": p.title, + "source": p.source_name, + "child_count": p.child_count, + } + for p in parents[:10] + ], + } + + # 2. Batch auswählen + batch = parents[offset : offset + batch_size] + if not batch: + return { + "dry_run": False, + "message": "Kein weiterer Batch — alle Parents verarbeitet.", + "total_parents": total_parents, + "offset": offset, + "processed": 0, + } + + provider = get_llm_provider() + processed = 0 + children_updated = 0 + errors = [] + sample_rationales = [] + + for parent in batch: + parent_uuid = str(parent.parent_uuid) + source = parent.source_name or "Regulierung" + + # LLM-Prompt + prompt = ( + f"Du bist Compliance-Experte. 
Erklaere in 1-2 Saetzen auf Deutsch, " + f"WARUM aus dem uebergeordneten Control atomare Teilmassnahmen " + f"abgeleitet wurden.\n\n" + f"Uebergeordnetes Control: {parent.control_id} — {parent.title}\n" + f"Regulierung: {source}\n" + f"Kategorie: {parent.category or 'k.A.'}\n" + f"Anzahl atomarer Controls: {parent.child_count}\n\n" + f"Schreibe NUR die Begruendung (1-2 Saetze). Kein Markdown, " + f"keine Aufzaehlung, kein Praefix. " + f"Erklaere den regulatorischen Hintergrund und warum die " + f"Zerlegung in atomare, testbare Massnahmen notwendig ist." + ) + + try: + response = await provider.complete( + prompt=prompt, + max_tokens=256, + temperature=0.3, + ) + rationale = response.content.strip() + + # Bereinigen: Anfuehrungszeichen, Markdown entfernen + rationale = rationale.strip('"').strip("'").strip() + if rationale.startswith("Begründung:") or rationale.startswith("Begruendung:"): + rationale = rationale.split(":", 1)[1].strip() + + # Laenge begrenzen (max 500 Zeichen) + if len(rationale) > 500: + rationale = rationale[:497] + "..." + + if not rationale or len(rationale) < 10: + errors.append({ + "control_id": parent.control_id, + "error": "LLM-Antwort zu kurz oder leer", + }) + continue + + # Alle Kinder dieses Parents updaten + result = db.execute( + text(""" + UPDATE canonical_controls + SET rationale = :rationale + WHERE parent_control_uuid = CAST(:pid AS uuid) + AND rationale = 'Aus Obligation abgeleitet.' 
+ AND release_state NOT IN ('rejected', 'merged') + """), + {"rationale": rationale, "pid": parent_uuid}, + ) + children_updated += result.rowcount + processed += 1 + + if len(sample_rationales) < 5: + sample_rationales.append({ + "parent": parent.control_id, + "title": parent.title, + "rationale": rationale, + "children_updated": result.rowcount, + }) + + except Exception as e: + logger.error(f"LLM error for {parent.control_id}: {e}") + errors.append({ + "control_id": parent.control_id, + "error": str(e)[:200], + }) + # Rollback um DB-Session nach Fehler nutzbar zu halten + try: + db.rollback() + except Exception: + pass + + db.commit() + + return { + "dry_run": False, + "offset": offset, + "batch_size": batch_size, + "next_offset": offset + batch_size if offset + batch_size < total_parents else None, + "processed_parents": processed, + "children_updated": children_updated, + "total_parents": total_parents, + "total_children": total_children, + "errors": errors[:10], + "sample_rationales": sample_rationales, + } + + +# ============================================================================= +# CONTROL CRUD (CREATE / UPDATE / DELETE) +# ============================================================================= + +@router.post("/controls", status_code=201) +async def create_control(body: ControlCreateRequest): + """Create a new canonical control.""" + import json as _json + import re + # Validate control_id format + if not re.match(r"^[A-Z]{2,6}-[0-9]{3}$", body.control_id): + raise HTTPException(status_code=400, detail="control_id must match DOMAIN-NNN (e.g. 
AUTH-001)") + if body.severity not in ("low", "medium", "high", "critical"): + raise HTTPException(status_code=400, detail="severity must be low/medium/high/critical") + if body.risk_score is not None and not (0 <= body.risk_score <= 10): + raise HTTPException(status_code=400, detail="risk_score must be 0..10") + + with SessionLocal() as db: + # Resolve framework + fw = db.execute( + text("SELECT id FROM canonical_control_frameworks WHERE framework_id = :fid"), + {"fid": body.framework_id}, + ).fetchone() + if not fw: + raise HTTPException(status_code=404, detail=f"Framework '{body.framework_id}' not found") + + # Check duplicate + existing = db.execute( + text("SELECT id FROM canonical_controls WHERE framework_id = :fid AND control_id = :cid"), + {"fid": str(fw.id), "cid": body.control_id}, + ).fetchone() + if existing: + raise HTTPException(status_code=409, detail=f"Control '{body.control_id}' already exists") + + row = db.execute( + text(f""" + INSERT INTO canonical_controls ( + framework_id, control_id, title, objective, rationale, + scope, requirements, test_procedure, evidence, + severity, risk_score, implementation_effort, evidence_confidence, + open_anchors, release_state, tags, + license_rule, source_original_text, source_citation, + customer_visible, verification_method, category, evidence_type, + target_audience, generation_metadata, + applicable_industries, applicable_company_size, scope_conditions + ) VALUES ( + :fw_id, :cid, :title, :objective, :rationale, + CAST(:scope AS jsonb), CAST(:requirements AS jsonb), + CAST(:test_procedure AS jsonb), CAST(:evidence AS jsonb), + :severity, :risk_score, :effort, :confidence, + CAST(:anchors AS jsonb), :release_state, CAST(:tags AS jsonb), + :license_rule, :source_original_text, + CAST(:source_citation AS jsonb), + :customer_visible, :verification_method, :category, :evidence_type, + :target_audience, CAST(:generation_metadata AS jsonb), + CAST(:applicable_industries AS jsonb), + CAST(:applicable_company_size 
AS jsonb), + CAST(:scope_conditions AS jsonb) + ) + RETURNING {_CONTROL_COLS} + """), + { + "fw_id": str(fw.id), + "cid": body.control_id, + "title": body.title, + "objective": body.objective, + "rationale": body.rationale, + "scope": _json.dumps(body.scope), + "requirements": _json.dumps(body.requirements), + "test_procedure": _json.dumps(body.test_procedure), + "evidence": _json.dumps(body.evidence), + "severity": body.severity, + "risk_score": body.risk_score, + "effort": body.implementation_effort, + "confidence": body.evidence_confidence, + "anchors": _json.dumps(body.open_anchors), + "release_state": body.release_state, + "tags": _json.dumps(body.tags), + "license_rule": body.license_rule, + "source_original_text": body.source_original_text, + "source_citation": _json.dumps(body.source_citation) if body.source_citation else None, + "customer_visible": body.customer_visible, + "verification_method": body.verification_method, + "category": body.category, + "evidence_type": body.evidence_type, + "target_audience": body.target_audience, + "generation_metadata": _json.dumps(body.generation_metadata) if body.generation_metadata else None, + "applicable_industries": _json.dumps(body.applicable_industries) if body.applicable_industries else None, + "applicable_company_size": _json.dumps(body.applicable_company_size) if body.applicable_company_size else None, + "scope_conditions": _json.dumps(body.scope_conditions) if body.scope_conditions else None, + }, + ).fetchone() + db.commit() + + return _control_row(row) + + +@router.put("/controls/{control_id}") +async def update_control(control_id: str, body: ControlUpdateRequest): + """Update an existing canonical control (partial update).""" + import json as _json + + updates = body.dict(exclude_none=True) + if not updates: + raise HTTPException(status_code=400, detail="No fields to update") + + if "severity" in updates and updates["severity"] not in ("low", "medium", "high", "critical"): + raise 
HTTPException(status_code=400, detail="severity must be low/medium/high/critical") + if "risk_score" in updates and updates["risk_score"] is not None and not (0 <= updates["risk_score"] <= 10): + raise HTTPException(status_code=400, detail="risk_score must be 0..10") + + # Build dynamic SET clause + set_parts = [] + params: dict[str, Any] = {"cid": control_id.upper()} + json_fields = {"scope", "requirements", "test_procedure", "evidence", "open_anchors", "tags", + "source_citation", "generation_metadata"} + + for key, val in updates.items(): + col = key + if key in json_fields: + set_parts.append(f"{col} = CAST(:{key} AS jsonb)") + params[key] = _json.dumps(val) + else: + set_parts.append(f"{col} = :{key}") + params[key] = val + + set_parts.append("updated_at = NOW()") + + with SessionLocal() as db: + row = db.execute( + text(f""" + UPDATE canonical_controls + SET {', '.join(set_parts)} + WHERE control_id = :cid + RETURNING {_CONTROL_COLS} + """), + params, + ).fetchone() + if not row: + raise HTTPException(status_code=404, detail="Control not found") + db.commit() + + return _control_row(row) + + +@router.delete("/controls/{control_id}", status_code=204) +async def delete_control(control_id: str): + """Delete a canonical control.""" + with SessionLocal() as db: + result = db.execute( + text("DELETE FROM canonical_controls WHERE control_id = :cid"), + {"cid": control_id.upper()}, + ) + if result.rowcount == 0: + raise HTTPException(status_code=404, detail="Control not found") + db.commit() + + return None + + +# ============================================================================= +# SIMILARITY CHECK +# ============================================================================= + +@router.post("/controls/{control_id}/similarity-check") +async def similarity_check(control_id: str, body: SimilarityCheckRequest): + """Run the too-close detector against a source/candidate text pair.""" + report = await check_similarity(body.source_text, body.candidate_text) + 
return { + "control_id": control_id.upper(), + "max_exact_run": report.max_exact_run, + "token_overlap": report.token_overlap, + "ngram_jaccard": report.ngram_jaccard, + "embedding_cosine": report.embedding_cosine, + "lcs_ratio": report.lcs_ratio, + "status": report.status, + "details": report.details, + } + + +# ============================================================================= +# CATEGORIES +# ============================================================================= + +@router.get("/categories") +async def list_categories(): + """List all canonical control categories.""" + with SessionLocal() as db: + rows = db.execute( + text("SELECT category_id, label_de, label_en, sort_order FROM canonical_control_categories ORDER BY sort_order") + ).fetchall() + + return [ + { + "category_id": r.category_id, + "label_de": r.label_de, + "label_en": r.label_en, + "sort_order": r.sort_order, + } + for r in rows + ] + + +# ============================================================================= +# SIMILAR CONTROLS (Embedding-based dedup) +# ============================================================================= + +@router.get("/controls/{control_id}/similar") +async def find_similar_controls( + control_id: str, + threshold: float = Query(0.85, ge=0.5, le=1.0), + limit: int = Query(20, ge=1, le=100), +): + """Find controls similar to the given one using embedding cosine similarity.""" + with SessionLocal() as db: + # Get the target control's embedding + target = db.execute( + text(""" + SELECT id, control_id, title, objective + FROM canonical_controls + WHERE control_id = :cid + """), + {"cid": control_id.upper()}, + ).fetchone() + + if not target: + raise HTTPException(status_code=404, detail="Control not found") + + # Find similar controls using pg_vector cosine distance if available, + # otherwise fall back to text-based matching via objective similarity + try: + rows = db.execute( + text(""" + SELECT c.control_id, c.title, c.severity, c.release_state, 
+ c.tags, c.license_rule, c.verification_method, c.category, + 1 - (c.embedding <=> t.embedding) AS similarity + FROM canonical_controls c, canonical_controls t + WHERE t.control_id = :cid + AND c.control_id != :cid + AND c.release_state != 'deprecated' + AND c.embedding IS NOT NULL + AND t.embedding IS NOT NULL + AND 1 - (c.embedding <=> t.embedding) >= :threshold + ORDER BY similarity DESC + LIMIT :lim + """), + {"cid": control_id.upper(), "threshold": threshold, "lim": limit}, + ).fetchall() + + return [ + { + "control_id": r.control_id, + "title": r.title, + "severity": r.severity, + "release_state": r.release_state, + "tags": r.tags or [], + "license_rule": r.license_rule, + "verification_method": r.verification_method, + "category": r.category, + "similarity": round(float(r.similarity), 4), + } + for r in rows + ] + except Exception as e: + logger.warning("Embedding similarity query failed (no embedding column?): %s", e) + return [] + + +# ============================================================================= +# SOURCES & LICENSES +# ============================================================================= + +@router.get("/sources") +async def list_sources(): + """List all registered sources with permission flags.""" + with SessionLocal() as db: + return get_source_permissions(db) + + +@router.get("/licenses") +async def list_licenses(): + """Return the license matrix.""" + with SessionLocal() as db: + return get_license_matrix(db) + + +# ============================================================================= +# V1 ENRICHMENT (Eigenentwicklung → Regulatorische Abdeckung) +# ============================================================================= + +@router.post("/controls/enrich-v1-matches") +async def enrich_v1_matches_endpoint( + dry_run: bool = Query(True, description="Nur zaehlen, nicht schreiben"), + batch_size: int = Query(100, description="Controls pro Durchlauf"), + offset: int = Query(0, description="Offset fuer Paginierung"), 
+): + """ + Findet regulatorische Abdeckung fuer v1 Eigenentwicklung Controls. + + Eigenentwicklung = generation_strategy='ungrouped', pipeline_version=1, + source_citation IS NULL, parent_control_uuid IS NULL. + + Workflow: + 1. dry_run=true → Statistiken anzeigen + 2. dry_run=false&batch_size=100&offset=0 → Erste 100 verarbeiten + 3. Wiederholen mit next_offset bis fertig + """ + from services.v1_enrichment import enrich_v1_matches + return await enrich_v1_matches( + dry_run=dry_run, + batch_size=batch_size, + offset=offset, + ) + + +@router.get("/controls/{control_id}/v1-matches") +async def get_v1_matches_endpoint(control_id: str): + """ + Gibt regulatorische Matches fuer ein v1 Control zurueck. + + Returns: + Liste von Matches mit Control-Details, Source, Score. + """ + from services.v1_enrichment import get_v1_matches + + # Resolve control_id to UUID + with SessionLocal() as db: + row = db.execute(text(""" + SELECT id FROM canonical_controls WHERE control_id = :cid + """), {"cid": control_id}).fetchone() + + if not row: + raise HTTPException(status_code=404, detail=f"Control {control_id} not found") + + return await get_v1_matches(str(row.id)) + + +# ============================================================================= +# INTERNAL HELPERS +# ============================================================================= + +def _control_row(r) -> dict: + return { + "id": str(r.id), + "framework_id": str(r.framework_id), + "control_id": r.control_id, + "title": r.title, + "objective": r.objective, + "rationale": r.rationale, + "scope": r.scope, + "requirements": r.requirements, + "test_procedure": r.test_procedure, + "evidence": r.evidence, + "severity": r.severity, + "risk_score": float(r.risk_score) if r.risk_score is not None else None, + "implementation_effort": r.implementation_effort, + "evidence_confidence": float(r.evidence_confidence) if r.evidence_confidence is not None else None, + "open_anchors": r.open_anchors, + "release_state": 
r.release_state, + "tags": r.tags or [], + "license_rule": r.license_rule, + "source_original_text": r.source_original_text, + "source_citation": r.source_citation, + "customer_visible": r.customer_visible, + "verification_method": r.verification_method, + "category": r.category, + "evidence_type": getattr(r, "evidence_type", None), + "target_audience": r.target_audience, + "generation_metadata": r.generation_metadata, + "generation_strategy": getattr(r, "generation_strategy", "ungrouped"), + "applicable_industries": getattr(r, "applicable_industries", None), + "applicable_company_size": getattr(r, "applicable_company_size", None), + "scope_conditions": getattr(r, "scope_conditions", None), + "parent_control_uuid": str(r.parent_control_uuid) if getattr(r, "parent_control_uuid", None) else None, + "parent_control_id": getattr(r, "parent_control_id", None), + "parent_control_title": getattr(r, "parent_control_title", None), + "decomposition_method": getattr(r, "decomposition_method", None), + "pipeline_version": getattr(r, "pipeline_version", None), + "created_at": r.created_at.isoformat() if r.created_at else None, + "updated_at": r.updated_at.isoformat() if r.updated_at else None, + } diff --git a/control-pipeline/api/control_generator_routes.py b/control-pipeline/api/control_generator_routes.py new file mode 100644 index 0000000..efbd311 --- /dev/null +++ b/control-pipeline/api/control_generator_routes.py @@ -0,0 +1,1100 @@ +""" +FastAPI routes for the Control Generator Pipeline. 
+ +Endpoints: + POST /v1/canonical/generate — Start generation run + GET /v1/canonical/generate/status/{job_id} — Job status + GET /v1/canonical/generate/jobs — All jobs + GET /v1/canonical/generate/review-queue — Controls needing review + POST /v1/canonical/generate/review/{control_id} — Complete review + GET /v1/canonical/generate/processed-stats — Processing stats per collection + GET /v1/canonical/blocked-sources — Blocked sources list + POST /v1/canonical/blocked-sources/cleanup — Start cleanup workflow +""" + +import asyncio +import json +import logging +from typing import Optional, List + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy import text + +from db.session import SessionLocal +from services.control_generator import ( + ControlGeneratorPipeline, + GeneratorConfig, + ALL_COLLECTIONS, + VALID_CATEGORIES, + VALID_DOMAINS, + _classify_regulation, + _detect_category, + _detect_domain, + _llm_local, + _parse_llm_json, + CATEGORY_LIST_STR, +) +from services.citation_backfill import CitationBackfill, BackfillResult +from services.rag_client import get_rag_client + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/v1/canonical", tags=["control-generator"]) + + +# ============================================================================= +# REQUEST / RESPONSE MODELS +# ============================================================================= + +class GenerateRequest(BaseModel): + domain: Optional[str] = None + collections: Optional[List[str]] = None + max_controls: int = 50 + max_chunks: int = 1000 # Default: process max 1000 chunks per job (respects document boundaries) + batch_size: int = 5 + skip_web_search: bool = False + dry_run: bool = False + regulation_filter: Optional[List[str]] = None # Only process these regulation_code prefixes + skip_prefilter: bool = False # Skip local LLM pre-filter, send all chunks to API + + +class GenerateResponse(BaseModel): + job_id: str + 
status: str + message: str + total_chunks_scanned: int = 0 + controls_generated: int = 0 + controls_verified: int = 0 + controls_needs_review: int = 0 + controls_too_close: int = 0 + controls_duplicates_found: int = 0 + controls_qa_fixed: int = 0 + errors: list = [] + controls: list = [] + + +class ReviewRequest(BaseModel): + action: str # "approve", "reject", "needs_rework" + release_state: Optional[str] = None # Override release_state + notes: Optional[str] = None + + +class ProcessedStats(BaseModel): + collection: str + total_chunks_estimated: int + processed_chunks: int + pending_chunks: int + direct_adopted: int + llm_reformed: int + skipped: int + + +class BlockedSourceResponse(BaseModel): + id: str + regulation_code: str + document_title: str + reason: str + deletion_status: str + qdrant_collection: Optional[str] = None + marked_at: str + + +# ============================================================================= +# ENDPOINTS +# ============================================================================= + +async def _run_pipeline_background(config: GeneratorConfig, job_id: str): + """Run the pipeline in the background. 
Uses its own DB session.""" + db = SessionLocal() + try: + config.existing_job_id = job_id + pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client()) + result = await pipeline.run(config) + logger.info( + "Background generation job %s completed: %d controls from %d chunks", + job_id, result.controls_generated, result.total_chunks_scanned, + ) + except Exception as e: + logger.error("Background generation job %s failed: %s", job_id, e) + # Update job as failed + try: + db.execute( + text(""" + UPDATE canonical_generation_jobs + SET status = 'failed', errors = :errors, completed_at = NOW() + WHERE id = CAST(:job_id AS uuid) + """), + {"job_id": job_id, "errors": json.dumps([str(e)])}, + ) + db.commit() + except Exception: + pass + finally: + db.close() + + +@router.post("/generate", response_model=GenerateResponse) +async def start_generation(req: GenerateRequest): + """Start a control generation run (runs in background). + + Returns immediately with job_id. Use GET /generate/status/{job_id} to poll progress. 
+ """ + config = GeneratorConfig( + collections=req.collections, + domain=req.domain, + batch_size=req.batch_size, + max_controls=req.max_controls, + max_chunks=req.max_chunks, + skip_web_search=req.skip_web_search, + dry_run=req.dry_run, + regulation_filter=req.regulation_filter, + skip_prefilter=req.skip_prefilter, + ) + + if req.dry_run: + # Dry run: execute synchronously and return controls + db = SessionLocal() + try: + pipeline = ControlGeneratorPipeline(db=db, rag_client=get_rag_client()) + result = await pipeline.run(config) + return GenerateResponse( + job_id=result.job_id, + status=result.status, + message=f"Dry run: {result.controls_generated} controls from {result.total_chunks_scanned} chunks", + total_chunks_scanned=result.total_chunks_scanned, + controls_generated=result.controls_generated, + controls_verified=result.controls_verified, + controls_needs_review=result.controls_needs_review, + controls_too_close=result.controls_too_close, + controls_duplicates_found=result.controls_duplicates_found, + errors=result.errors, + controls=result.controls, + ) + except Exception as e: + logger.error("Dry run failed: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + finally: + db.close() + + # Create job record first so we can return the ID + db = SessionLocal() + try: + result = db.execute( + text(""" + INSERT INTO canonical_generation_jobs (status, config) + VALUES ('running', :config) + RETURNING id + """), + {"config": json.dumps(config.model_dump())}, + ) + db.commit() + row = result.fetchone() + job_id = str(row[0]) if row else None + except Exception as e: + logger.error("Failed to create job: %s", e) + raise HTTPException(status_code=500, detail=f"Failed to create job: {e}") + finally: + db.close() + + if not job_id: + raise HTTPException(status_code=500, detail="Failed to create job record") + + # Launch pipeline in background + asyncio.create_task(_run_pipeline_background(config, job_id)) + + return GenerateResponse( + job_id=job_id, + 
status="running", + message="Generation started in background. Poll /generate/status/{job_id} for progress.", + ) + + +@router.get("/generate/status/{job_id}") +async def get_job_status(job_id: str): + """Get status of a generation job.""" + db = SessionLocal() + try: + result = db.execute( + text("SELECT * FROM canonical_generation_jobs WHERE id = CAST(:id AS uuid)"), + {"id": job_id}, + ) + row = result.fetchone() + if not row: + raise HTTPException(status_code=404, detail="Job not found") + + cols = result.keys() + job = dict(zip(cols, row)) + # Serialize datetime fields + for key in ("started_at", "completed_at", "created_at"): + if job.get(key): + job[key] = str(job[key]) + job["id"] = str(job["id"]) + return job + finally: + db.close() + + +@router.get("/generate/jobs") +async def list_jobs( + limit: int = Query(20, ge=1, le=100), + offset: int = Query(0, ge=0), +): + """List all generation jobs.""" + db = SessionLocal() + try: + result = db.execute( + text(""" + SELECT id, status, total_chunks_scanned, controls_generated, + controls_verified, controls_needs_review, controls_too_close, + controls_duplicates_found, created_at, completed_at + FROM canonical_generation_jobs + ORDER BY created_at DESC + LIMIT :limit OFFSET :offset + """), + {"limit": limit, "offset": offset}, + ) + jobs = [] + cols = result.keys() + for row in result: + job = dict(zip(cols, row)) + job["id"] = str(job["id"]) + for key in ("created_at", "completed_at"): + if job.get(key): + job[key] = str(job[key]) + jobs.append(job) + return {"jobs": jobs, "total": len(jobs)} + finally: + db.close() + + +@router.get("/generate/review-queue") +async def get_review_queue( + release_state: str = Query("needs_review", pattern="^(needs_review|too_close|duplicate)$"), + limit: int = Query(50, ge=1, le=200), +): + """Get controls that need manual review.""" + db = SessionLocal() + try: + result = db.execute( + text(""" + SELECT c.id, c.control_id, c.title, c.objective, c.severity, + c.release_state, 
c.license_rule, c.customer_visible, + c.generation_metadata, c.open_anchors, c.tags, + c.created_at + FROM canonical_controls c + WHERE c.release_state = :state + ORDER BY c.created_at DESC + LIMIT :limit + """), + {"state": release_state, "limit": limit}, + ) + controls = [] + cols = result.keys() + for row in result: + ctrl = dict(zip(cols, row)) + ctrl["id"] = str(ctrl["id"]) + ctrl["created_at"] = str(ctrl["created_at"]) + # Parse JSON fields + for jf in ("generation_metadata", "open_anchors", "tags"): + if isinstance(ctrl.get(jf), str): + try: + ctrl[jf] = json.loads(ctrl[jf]) + except (json.JSONDecodeError, TypeError): + pass + controls.append(ctrl) + return {"controls": controls, "total": len(controls)} + finally: + db.close() + + +@router.post("/generate/review/{control_id}") +async def review_control(control_id: str, req: ReviewRequest): + """Complete review of a generated control.""" + db = SessionLocal() + try: + # Validate control exists and is in reviewable state + result = db.execute( + text("SELECT id, release_state FROM canonical_controls WHERE control_id = :cid"), + {"cid": control_id}, + ) + row = result.fetchone() + if not row: + raise HTTPException(status_code=404, detail="Control not found") + + current_state = row[1] + if current_state not in ("needs_review", "too_close", "duplicate"): + raise HTTPException(status_code=400, detail=f"Control is in state '{current_state}', not reviewable") + + # Determine new state + if req.action == "approve": + new_state = req.release_state or "draft" + elif req.action == "reject": + new_state = "deprecated" + elif req.action == "needs_rework": + new_state = "needs_review" + else: + raise HTTPException(status_code=400, detail=f"Unknown action: {req.action}") + + if new_state not in ("draft", "review", "approved", "deprecated", "needs_review", "too_close", "duplicate"): + raise HTTPException(status_code=400, detail=f"Invalid release_state: {new_state}") + + db.execute( + text(""" + UPDATE canonical_controls + 
SET release_state = :state, updated_at = NOW() + WHERE control_id = :cid + """), + {"state": new_state, "cid": control_id}, + ) + db.commit() + + return {"control_id": control_id, "release_state": new_state, "action": req.action} + finally: + db.close() + + +class BulkReviewRequest(BaseModel): + release_state: str # Filter: which controls to bulk-review + action: str # "approve" or "reject" + new_state: Optional[str] = None # Override target state + + +@router.post("/generate/bulk-review") +async def bulk_review(req: BulkReviewRequest): + """Bulk review all controls matching a release_state filter. + + Example: reject all needs_review → sets them to deprecated. + """ + if req.release_state not in ("needs_review", "too_close", "duplicate"): + raise HTTPException(status_code=400, detail=f"Invalid filter state: {req.release_state}") + + if req.action == "approve": + target = req.new_state or "draft" + elif req.action == "reject": + target = "deprecated" + else: + raise HTTPException(status_code=400, detail=f"Unknown action: {req.action}") + + if target not in ("draft", "review", "approved", "deprecated", "needs_review"): + raise HTTPException(status_code=400, detail=f"Invalid target state: {target}") + + db = SessionLocal() + try: + result = db.execute( + text(""" + UPDATE canonical_controls + SET release_state = :target, updated_at = NOW() + WHERE release_state = :source + RETURNING control_id + """), + {"source": req.release_state, "target": target}, + ) + affected = [row[0] for row in result] + db.commit() + + return { + "action": req.action, + "source_state": req.release_state, + "target_state": target, + "affected_count": len(affected), + } + finally: + db.close() + + +class QAReclassifyRequest(BaseModel): + limit: int = 100 # How many controls to reclassify per run + dry_run: bool = True # Preview only by default + filter_category: Optional[str] = None # Only reclassify controls of this category + filter_domain_prefix: Optional[str] = None # Only reclassify 
controls with this prefix + + +@router.post("/generate/qa-reclassify") +async def qa_reclassify(req: QAReclassifyRequest): + """Run QA reclassification on existing controls using local LLM. + + Finds controls where keyword-detection disagrees with current category/domain, + then uses Ollama to determine the correct classification. + """ + db = SessionLocal() + try: + # Load controls to check + where_clauses = ["release_state NOT IN ('deprecated')"] + params = {"limit": req.limit} + if req.filter_category: + where_clauses.append("category = :cat") + params["cat"] = req.filter_category + if req.filter_domain_prefix: + where_clauses.append("control_id LIKE :prefix") + params["prefix"] = f"{req.filter_domain_prefix}-%" + + where_sql = " AND ".join(where_clauses) + rows = db.execute( + text(f""" + SELECT id, control_id, title, objective, category, + COALESCE(requirements::text, '[]') as requirements, + COALESCE(source_original_text, '') as source_text + FROM canonical_controls + WHERE {where_sql} + ORDER BY created_at DESC + LIMIT :limit + """), + params, + ).fetchall() + + results = {"checked": 0, "mismatches": 0, "fixes": [], "errors": []} + + for row in rows: + results["checked"] += 1 + control_id = row[1] + title = row[2] + objective = row[3] or "" + current_category = row[4] + source_text = row[6] or objective + + # Keyword detection on source text + kw_category = _detect_category(source_text) or _detect_category(objective) + kw_domain = _detect_domain(source_text) + current_prefix = control_id.split("-")[0] if "-" in control_id else "" + + # Skip if keyword detection agrees with current classification + if kw_category == current_category and kw_domain == current_prefix: + continue + + results["mismatches"] += 1 + + # Ask Ollama to arbitrate + try: + reqs_text = "" + try: + reqs = json.loads(row[5]) + if isinstance(reqs, list): + reqs_text = ", ".join(str(r) for r in reqs[:3]) + except Exception: + pass + + prompt = f"""Pruefe dieses Compliance-Control auf korrekte 
Klassifizierung. + +Titel: {title[:100]} +Ziel: {objective[:200]} +Anforderungen: {reqs_text[:200]} + +Aktuelle Zuordnung: domain={current_prefix}, category={current_category} +Keyword-Erkennung: domain={kw_domain}, category={kw_category} + +Welche Zuordnung ist korrekt? Antworte NUR als JSON: +{{"domain": "KUERZEL", "category": "kategorie_name", "reason": "kurze Begruendung"}} + +Domains: AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe, ENV=Umwelt, HLT=Gesundheit +Kategorien: {CATEGORY_LIST_STR}""" + + raw = await _llm_local(prompt) + data = _parse_llm_json(raw) + if not data: + continue + + qa_domain = data.get("domain", "").upper() + qa_category = data.get("category", "") + reason = data.get("reason", "") + + fix_entry = { + "control_id": control_id, + "title": title[:80], + "old_category": current_category, + "old_domain": current_prefix, + "new_category": qa_category if qa_category in VALID_CATEGORIES else current_category, + "new_domain": qa_domain if qa_domain in VALID_DOMAINS else current_prefix, + "reason": reason, + } + + category_changed = qa_category in VALID_CATEGORIES and qa_category != current_category + + if category_changed and not req.dry_run: + db.execute( + text(""" + UPDATE canonical_controls + SET category = :category, updated_at = NOW() + WHERE id = :id + """), + {"id": row[0], "category": qa_category}, + ) + fix_entry["applied"] = True + else: + fix_entry["applied"] = False + + results["fixes"].append(fix_entry) + + except Exception as e: + results["errors"].append({"control_id": control_id, "error": str(e)}) + + if not req.dry_run: + db.commit() + + return results + finally: + db.close() + + +@router.get("/generate/processed-stats") +async def get_processed_stats(): + """Get processing statistics per collection.""" + db = SessionLocal() + try: + 
result = db.execute( + text(""" + SELECT + collection, + COUNT(*) as processed_chunks, + COUNT(*) FILTER (WHERE processing_path = 'structured') as direct_adopted, + COUNT(*) FILTER (WHERE processing_path = 'llm_reform') as llm_reformed, + COUNT(*) FILTER (WHERE processing_path = 'skipped') as skipped + FROM canonical_processed_chunks + GROUP BY collection + ORDER BY collection + """) + ) + stats = [] + cols = result.keys() + for row in result: + stat = dict(zip(cols, row)) + stat["total_chunks_estimated"] = 0 # Would need Qdrant API to get total + stat["pending_chunks"] = 0 + stats.append(stat) + return {"stats": stats} + finally: + db.close() + + +# ============================================================================= +# BLOCKED SOURCES +# ============================================================================= + +@router.get("/blocked-sources") +async def list_blocked_sources(): + """List all blocked (Rule 3) sources.""" + db = SessionLocal() + try: + result = db.execute( + text(""" + SELECT id, regulation_code, document_title, reason, + deletion_status, qdrant_collection, marked_at + FROM canonical_blocked_sources + ORDER BY marked_at DESC + """) + ) + sources = [] + cols = result.keys() + for row in result: + src = dict(zip(cols, row)) + src["id"] = str(src["id"]) + src["marked_at"] = str(src["marked_at"]) + sources.append(src) + return {"sources": sources} + finally: + db.close() + + +@router.post("/blocked-sources/cleanup") +async def start_cleanup(): + """Start cleanup workflow for blocked sources. + + This marks all pending blocked sources for deletion. + Actual RAG chunk deletion and file removal is a separate manual step. 
+ """ + db = SessionLocal() + try: + result = db.execute( + text(""" + UPDATE canonical_blocked_sources + SET deletion_status = 'marked_for_deletion' + WHERE deletion_status = 'pending' + RETURNING regulation_code + """) + ) + marked = [row[0] for row in result] + db.commit() + + return { + "status": "marked_for_deletion", + "marked_count": len(marked), + "regulation_codes": marked, + "message": "Sources marked for deletion. Run manual cleanup to remove RAG chunks and files.", + } + finally: + db.close() + + +# ============================================================================= +# CUSTOMER VIEW FILTER +# ============================================================================= + +@router.get("/controls-customer") +async def get_controls_customer_view( + severity: Optional[str] = Query(None), + domain: Optional[str] = Query(None), +): + """Get controls filtered for customer visibility. + + Rule 3 controls have source_citation and source_original_text hidden. + generation_metadata is NEVER shown to customers. 
+ """ + db = SessionLocal() + try: + query = """ + SELECT c.id, c.control_id, c.title, c.objective, c.rationale, + c.scope, c.requirements, c.test_procedure, c.evidence, + c.severity, c.risk_score, c.implementation_effort, + c.open_anchors, c.release_state, c.tags, + c.license_rule, c.customer_visible, + c.source_original_text, c.source_citation, + c.created_at, c.updated_at + FROM canonical_controls c + WHERE c.release_state IN ('draft', 'approved') + """ + params: dict = {} + + if severity: + query += " AND c.severity = :severity" + params["severity"] = severity + if domain: + query += " AND c.control_id LIKE :domain" + params["domain"] = f"{domain.upper()}-%" + + query += " ORDER BY c.control_id" + + result = db.execute(text(query), params) + controls = [] + cols = result.keys() + for row in result: + ctrl = dict(zip(cols, row)) + ctrl["id"] = str(ctrl["id"]) + for key in ("created_at", "updated_at"): + if ctrl.get(key): + ctrl[key] = str(ctrl[key]) + + # Parse JSON fields + for jf in ("scope", "requirements", "test_procedure", "evidence", + "open_anchors", "tags", "source_citation"): + if isinstance(ctrl.get(jf), str): + try: + ctrl[jf] = json.loads(ctrl[jf]) + except (json.JSONDecodeError, TypeError): + pass + + # Customer visibility rules: + # - NEVER show generation_metadata + # - Rule 3: NEVER show source_citation or source_original_text + ctrl.pop("generation_metadata", None) + if not ctrl.get("customer_visible", True): + ctrl["source_citation"] = None + ctrl["source_original_text"] = None + + controls.append(ctrl) + + return {"controls": controls, "total": len(controls)} + finally: + db.close() + + +# ============================================================================= +# CITATION BACKFILL +# ============================================================================= + +class BackfillRequest(BaseModel): + dry_run: bool = True # Default to dry_run for safety + limit: int = 0 # 0 = all controls + + +class BackfillResponse(BaseModel): + status: 
str + total_controls: int = 0 + matched_hash: int = 0 + matched_regex: int = 0 + matched_llm: int = 0 + unmatched: int = 0 + updated: int = 0 + errors: list = [] + + +_backfill_status: dict = {} + + +async def _run_backfill_background(dry_run: bool, limit: int, backfill_id: str): + """Run backfill in background with own DB session.""" + db = SessionLocal() + try: + backfill = CitationBackfill(db=db, rag_client=get_rag_client()) + result = await backfill.run(dry_run=dry_run, limit=limit) + _backfill_status[backfill_id] = { + "status": "completed", + "total_controls": result.total_controls, + "matched_hash": result.matched_hash, + "matched_regex": result.matched_regex, + "matched_llm": result.matched_llm, + "unmatched": result.unmatched, + "updated": result.updated, + "errors": result.errors[:50], + } + logger.info("Backfill %s completed: %d updated", backfill_id, result.updated) + except Exception as e: + logger.error("Backfill %s failed: %s", backfill_id, e) + _backfill_status[backfill_id] = {"status": "failed", "errors": [str(e)]} + finally: + db.close() + + +@router.post("/generate/backfill-citations", response_model=BackfillResponse) +async def start_backfill(req: BackfillRequest): + """Backfill article/paragraph into existing control source_citations. + + Uses 3-tier matching: hash lookup → regex parse → Ollama LLM. + Default is dry_run=True (preview only, no DB changes). 
+ """ + import uuid + backfill_id = str(uuid.uuid4())[:8] + _backfill_status[backfill_id] = {"status": "running"} + + # Always run in background (RAG index build takes minutes) + asyncio.create_task(_run_backfill_background(req.dry_run, req.limit, backfill_id)) + return BackfillResponse( + status=f"running (id={backfill_id})", + ) + + +@router.get("/generate/backfill-status/{backfill_id}") +async def get_backfill_status(backfill_id: str): + """Get status of a backfill job.""" + status = _backfill_status.get(backfill_id) + if not status: + raise HTTPException(status_code=404, detail="Backfill job not found") + return status + + +# ============================================================================= +# DOMAIN + TARGET AUDIENCE BACKFILL +# ============================================================================= + +class DomainBackfillRequest(BaseModel): + dry_run: bool = True + job_id: Optional[str] = None # Only backfill controls from this job + limit: int = 0 # 0 = all + +_domain_backfill_status: dict = {} + + +async def _run_domain_backfill(req: DomainBackfillRequest, backfill_id: str): + """Backfill domain, category, and target_audience for existing controls using Anthropic.""" + import os + import httpx + + ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "") + ANTHROPIC_MODEL = os.getenv("CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6") + + if not ANTHROPIC_API_KEY: + _domain_backfill_status[backfill_id] = { + "status": "failed", "error": "ANTHROPIC_API_KEY not set" + } + return + + db = SessionLocal() + try: + # Find controls needing backfill + where_clauses = ["(target_audience IS NULL OR target_audience = '[]' OR target_audience = 'null')"] + params: dict = {} + if req.job_id: + where_clauses.append("generation_metadata->>'job_id' = :job_id") + params["job_id"] = req.job_id + + query = f""" + SELECT id, control_id, title, objective, category, source_original_text, tags + FROM canonical_controls + WHERE {' AND '.join(where_clauses)} + ORDER 
BY control_id + """ + if req.limit > 0: + query += f" LIMIT {req.limit}" + + result = db.execute(text(query), params) + controls = [dict(zip(result.keys(), row)) for row in result] + + total = len(controls) + updated = 0 + errors = [] + + _domain_backfill_status[backfill_id] = { + "status": "running", "total": total, "updated": 0, "errors": [] + } + + # Process in batches of 10 + BATCH_SIZE = 10 + for batch_start in range(0, total, BATCH_SIZE): + batch = controls[batch_start:batch_start + BATCH_SIZE] + + entries = [] + for idx, ctrl in enumerate(batch): + text_for_analysis = ctrl.get("objective") or ctrl.get("title") or "" + original = ctrl.get("source_original_text") or "" + if original: + text_for_analysis += f"\n\nQuelltext-Auszug: {original[:500]}" + entries.append( + f"--- CONTROL {idx + 1}: {ctrl['control_id']} ---\n" + f"Titel: {ctrl.get('title', '')}\n" + f"Objective: {text_for_analysis[:800]}\n" + f"Tags: {json.dumps(ctrl.get('tags', []))}" + ) + + prompt = f"""Analysiere die folgenden {len(batch)} Controls und bestimme fuer jedes: +1. domain: Das Fachgebiet (AUTH, CRYP, NET, DATA, LOG, ACC, SEC, INC, AI, COMP, GOV, LAB, FIN, TRD, ENV, HLT) +2. category: Die Kategorie (encryption, authentication, network, data_protection, logging, incident, continuity, compliance, supply_chain, physical, personnel, application, system, risk, governance, hardware, identity, public_administration, labor_law, finance, trade_regulation, environmental, health) +3. target_audience: Liste der Zielgruppen (moegliche Werte: "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "vertrieb", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst") + +Antworte mit einem JSON-Array mit {len(batch)} Objekten. 
Jedes Objekt hat: +- control_index: 1-basierter Index +- domain: Fachgebiet-Kuerzel +- category: Kategorie +- target_audience: Liste der Zielgruppen + +{"".join(entries)}""" + + try: + headers = { + "x-api-key": ANTHROPIC_API_KEY, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } + payload = { + "model": ANTHROPIC_MODEL, + "max_tokens": 4096, + "system": "Du bist ein Compliance-Experte. Klassifiziere Controls nach Fachgebiet und Zielgruppe. Antworte NUR mit validem JSON.", + "messages": [{"role": "user", "content": prompt}], + } + + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + "https://api.anthropic.com/v1/messages", + headers=headers, + json=payload, + ) + if resp.status_code != 200: + errors.append(f"Anthropic API {resp.status_code} at batch {batch_start}") + continue + + raw = resp.json().get("content", [{}])[0].get("text", "") + + # Parse response + import re + bracket_match = re.search(r"\[.*\]", raw, re.DOTALL) + if not bracket_match: + errors.append(f"No JSON array in response at batch {batch_start}") + continue + + results_list = json.loads(bracket_match.group(0)) + + for item in results_list: + idx = item.get("control_index", 0) - 1 + if idx < 0 or idx >= len(batch): + continue + ctrl = batch[idx] + ctrl_id = str(ctrl["id"]) + + new_domain = item.get("domain", "") + new_category = item.get("category", "") + new_audience = item.get("target_audience", []) + + if not isinstance(new_audience, list): + new_audience = [] + + # Build new control_id from domain if domain changed + old_prefix = ctrl["control_id"].split("-")[0] if ctrl["control_id"] else "" + new_prefix = new_domain.upper()[:4] if new_domain else old_prefix + + if not req.dry_run: + update_parts = [] + update_params: dict = {"ctrl_id": ctrl_id} + + if new_category: + update_parts.append("category = :category") + update_params["category"] = new_category + + if new_audience: + update_parts.append("target_audience = :target_audience") 
+ update_params["target_audience"] = json.dumps(new_audience) + + # Note: We do NOT rename control_ids here — that would + # break references and cause unique constraint violations. + + if update_parts: + update_parts.append("updated_at = NOW()") + db.execute( + text(f"UPDATE canonical_controls SET {', '.join(update_parts)} WHERE id = CAST(:ctrl_id AS uuid)"), + update_params, + ) + updated += 1 + + if not req.dry_run: + db.commit() + + except Exception as e: + errors.append(f"Batch {batch_start}: {str(e)}") + db.rollback() + + _domain_backfill_status[backfill_id] = { + "status": "running", "total": total, "updated": updated, + "progress": f"{min(batch_start + BATCH_SIZE, total)}/{total}", + "errors": errors[-10:], + } + + _domain_backfill_status[backfill_id] = { + "status": "completed", "total": total, "updated": updated, + "errors": errors[-50:], + } + logger.info("Domain backfill %s completed: %d/%d updated", backfill_id, updated, total) + + except Exception as e: + logger.error("Domain backfill %s failed: %s", backfill_id, e) + _domain_backfill_status[backfill_id] = {"status": "failed", "error": str(e)} + finally: + db.close() + + +@router.post("/generate/backfill-domain") +async def start_domain_backfill(req: DomainBackfillRequest): + """Backfill domain, category, and target_audience for controls using Anthropic API. + + Finds controls where target_audience is NULL and enriches them. + Default is dry_run=True (preview only). + """ + import uuid + backfill_id = str(uuid.uuid4())[:8] + _domain_backfill_status[backfill_id] = {"status": "starting"} + asyncio.create_task(_run_domain_backfill(req, backfill_id)) + return {"status": "running", "backfill_id": backfill_id, + "message": f"Domain backfill started. 
Poll /generate/domain-backfill-status/{backfill_id}"}
+
+
+@router.get("/generate/domain-backfill-status/{backfill_id}")
+async def get_domain_backfill_status(backfill_id: str):
+    """Get status of a domain backfill job."""
+    status = _domain_backfill_status.get(backfill_id)
+    if not status:
+        raise HTTPException(status_code=404, detail="Domain backfill job not found")
+    return status
+
+
+# ---------------------------------------------------------------------------
+# Source-Type Backfill — Classify law vs guideline vs standard vs restricted
+# ---------------------------------------------------------------------------
+
+class SourceTypeBackfillRequest(BaseModel):
+    dry_run: bool = True
+
+
+_source_type_backfill_status: dict = {}
+
+
+async def _run_source_type_backfill(dry_run: bool, backfill_id: str):
+    """Backfill source_type into source_citation JSONB for all controls."""
+    db = SessionLocal()
+    try:
+        # Find controls with source_citation that lack source_type
+        rows = db.execute(text("""
+            SELECT control_id, source_citation, generation_metadata
+            FROM compliance.canonical_controls
+            WHERE source_citation IS NOT NULL
+              AND (source_citation->>'source_type' IS NULL
+                   OR source_citation->>'source_type' = '')
+        """)).fetchall()
+
+        total = len(rows)
+        updated = 0
+        already_correct = 0
+        errors = []
+
+        _source_type_backfill_status[backfill_id] = {
+            "status": "running", "total": total, "updated": 0, "dry_run": dry_run,
+        }
+
+        for row in rows:
+            cid = row[0]
+            citation = row[1] if isinstance(row[1], dict) else json.loads(row[1] or "{}")
+            metadata = row[2] if isinstance(row[2], dict) else json.loads(row[2] or "{}")
+
+            # Get regulation_code from metadata
+            reg_code = metadata.get("source_regulation", "")
+            if not reg_code:
+                # Try to infer from source name
+                errors.append(f"{cid}: no source_regulation in metadata")
+                continue
+
+            # Classify
+            license_info = _classify_regulation(reg_code)
+            source_type = license_info.get("source_type", "restricted")
+
+            # Update citation
+ citation["source_type"] = source_type + + if not dry_run: + db.execute(text(""" + UPDATE compliance.canonical_controls + SET source_citation = :citation + WHERE control_id = :cid + """), {"citation": json.dumps(citation), "cid": cid}) + if updated % 100 == 0: + db.commit() + updated += 1 + + if not dry_run: + db.commit() + + # Count distribution + dist_query = db.execute(text(""" + SELECT source_citation->>'source_type' as st, COUNT(*) + FROM compliance.canonical_controls + WHERE source_citation IS NOT NULL + AND source_citation->>'source_type' IS NOT NULL + GROUP BY st + """)).fetchall() if not dry_run else [] + + distribution = {r[0]: r[1] for r in dist_query} + + _source_type_backfill_status[backfill_id] = { + "status": "completed", "total": total, "updated": updated, + "dry_run": dry_run, "distribution": distribution, + "errors": errors[:50], + } + logger.info("Source-type backfill %s completed: %d/%d updated (dry_run=%s)", + backfill_id, updated, total, dry_run) + + except Exception as e: + logger.error("Source-type backfill %s failed: %s", backfill_id, e) + _source_type_backfill_status[backfill_id] = {"status": "failed", "error": str(e)} + finally: + db.close() + + +@router.post("/generate/backfill-source-type") +async def start_source_type_backfill(req: SourceTypeBackfillRequest): + """Backfill source_type (law/guideline/standard/restricted) into source_citation JSONB. + + Classifies each control's source as binding law, authority guideline, + voluntary standard, or restricted norm based on regulation_code. + Default is dry_run=True (preview only). + """ + import uuid + backfill_id = str(uuid.uuid4())[:8] + _source_type_backfill_status[backfill_id] = {"status": "starting"} + asyncio.create_task(_run_source_type_backfill(req.dry_run, backfill_id)) + return { + "status": "running", + "backfill_id": backfill_id, + "message": f"Source-type backfill started. 
Poll /generate/source-type-backfill-status/{backfill_id}", + } + + +@router.get("/generate/source-type-backfill-status/{backfill_id}") +async def get_source_type_backfill_status(backfill_id: str): + """Get status of a source-type backfill job.""" + status = _source_type_backfill_status.get(backfill_id) + if not status: + raise HTTPException(status_code=404, detail="Source-type backfill job not found") + return status diff --git a/control-pipeline/config.py b/control-pipeline/config.py new file mode 100644 index 0000000..5e4a2da --- /dev/null +++ b/control-pipeline/config.py @@ -0,0 +1,67 @@ +import os + + +class Settings: + """Environment-based configuration for control-pipeline.""" + + # Database (compliance schema) + DATABASE_URL: str = os.getenv( + "DATABASE_URL", + "postgresql://breakpilot:breakpilot123@localhost:5432/breakpilot_db", + ) + SCHEMA_SEARCH_PATH: str = os.getenv( + "SCHEMA_SEARCH_PATH", "compliance,core,public" + ) + + # Qdrant (vector search for dedup) + QDRANT_URL: str = os.getenv("QDRANT_URL", "http://localhost:6333") + QDRANT_API_KEY: str = os.getenv("QDRANT_API_KEY", "") + + # Embedding Service + EMBEDDING_SERVICE_URL: str = os.getenv( + "EMBEDDING_SERVICE_URL", "http://embedding-service:8087" + ) + + # LLM - Anthropic + ANTHROPIC_API_KEY: str = os.getenv("ANTHROPIC_API_KEY", "") + CONTROL_GEN_ANTHROPIC_MODEL: str = os.getenv( + "CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6" + ) + DECOMPOSITION_LLM_MODEL: str = os.getenv( + "DECOMPOSITION_LLM_MODEL", "claude-haiku-4-5-20251001" + ) + CONTROL_GEN_LLM_TIMEOUT: int = int( + os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180") + ) + + # LLM - Ollama (fallback) + OLLAMA_URL: str = os.getenv( + "OLLAMA_URL", "http://host.docker.internal:11434" + ) + CONTROL_GEN_OLLAMA_MODEL: str = os.getenv( + "CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b" + ) + + # SDK Service (for RAG search proxy) + SDK_URL: str = os.getenv( + "SDK_URL", "http://ai-compliance-sdk:8090" + ) + + # Auth + JWT_SECRET: str = 
os.getenv("JWT_SECRET", "") + + # Server + PORT: int = int(os.getenv("PORT", "8098")) + LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") + ENVIRONMENT: str = os.getenv("ENVIRONMENT", "development") + + # Pipeline + DECOMPOSITION_BATCH_SIZE: int = int( + os.getenv("DECOMPOSITION_BATCH_SIZE", "5") + ) + DECOMPOSITION_LLM_TIMEOUT: int = int( + os.getenv("DECOMPOSITION_LLM_TIMEOUT", "120") + ) + + +settings = Settings() diff --git a/control-pipeline/data/__init__.py b/control-pipeline/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/control-pipeline/data/source_type_classification.py b/control-pipeline/data/source_type_classification.py new file mode 100644 index 0000000..fbfbe25 --- /dev/null +++ b/control-pipeline/data/source_type_classification.py @@ -0,0 +1,205 @@ +""" +Source-Type-Klassifikation fuer Regulierungen und Frameworks. + +Dreistufiges Modell der normativen Verbindlichkeit: + + Stufe 1 — GESETZ (law): + Rechtlich bindend. Bussgeld bei Verstoss. + Beispiele: DSGVO, NIS2, AI Act, CRA + + Stufe 2 — LEITLINIE (guideline): + Offizielle Auslegungshilfe von Aufsichtsbehoerden. + Beweislastumkehr: Wer abweicht, muss begruenden warum. + Beispiele: EDPB-Leitlinien, BSI-Standards, WP29-Dokumente + + Stufe 3 — FRAMEWORK (framework): + Freiwillige Best Practices, nicht rechtsverbindlich. + Aber: Koennen als "Stand der Technik" herangezogen werden. 
+ Beispiele: ENISA, NIST, OWASP, OECD, CISA + +Mapping: source_regulation (aus control_parent_links) -> source_type +""" + +# --- Typ-Definitionen --- +SOURCE_TYPE_LAW = "law" # Gesetz/Verordnung/Richtlinie — normative_strength bleibt +SOURCE_TYPE_GUIDELINE = "guideline" # Leitlinie/Standard — max "should" +SOURCE_TYPE_FRAMEWORK = "framework" # Framework/Best Practice — max "may" + +# Max erlaubte normative_strength pro source_type +# DB-Constraint erlaubt: must, should, may (NICHT "can") +NORMATIVE_STRENGTH_CAP: dict[str, str] = { + SOURCE_TYPE_LAW: "must", # keine Begrenzung + SOURCE_TYPE_GUIDELINE: "should", # max "should" + SOURCE_TYPE_FRAMEWORK: "may", # max "may" (= "kann") +} + +# Reihenfolge fuer Vergleiche (hoeher = staerker) +STRENGTH_ORDER: dict[str, int] = { + "may": 1, # KANN (DB-Wert) + "can": 1, # Alias — wird in cap_normative_strength zu "may" normalisiert + "should": 2, + "must": 3, +} + + +def cap_normative_strength(original: str, source_type: str) -> str: + """ + Begrenzt die normative_strength basierend auf dem source_type. + + Beispiel: + cap_normative_strength("must", "framework") -> "may" + cap_normative_strength("should", "law") -> "should" + cap_normative_strength("must", "guideline") -> "should" + """ + cap = NORMATIVE_STRENGTH_CAP.get(source_type, "must") + cap_level = STRENGTH_ORDER.get(cap, 3) + original_level = STRENGTH_ORDER.get(original, 3) + if original_level > cap_level: + return cap + return original + + +def get_highest_source_type(source_types: list[str]) -> str: + """ + Bestimmt den hoechsten source_type aus einer Liste. + Ein Gesetz uebertrumpft alles. 
+ + Beispiel: + get_highest_source_type(["framework", "law"]) -> "law" + get_highest_source_type(["framework", "guideline"]) -> "guideline" + """ + type_order = {SOURCE_TYPE_FRAMEWORK: 1, SOURCE_TYPE_GUIDELINE: 2, SOURCE_TYPE_LAW: 3} + if not source_types: + return SOURCE_TYPE_FRAMEWORK + return max(source_types, key=lambda t: type_order.get(t, 0)) + + +# ============================================================================ +# Klassifikation: source_regulation -> source_type +# +# Diese Map wird fuer den Backfill und zukuenftige Pipeline-Runs verwendet. +# Neue Regulierungen hier eintragen! +# ============================================================================ + +SOURCE_REGULATION_CLASSIFICATION: dict[str, str] = { + # --- EU-Verordnungen (unmittelbar bindend) --- + "DSGVO (EU) 2016/679": SOURCE_TYPE_LAW, + "KI-Verordnung (EU) 2024/1689": SOURCE_TYPE_LAW, + "Cyber Resilience Act (CRA)": SOURCE_TYPE_LAW, + "NIS2-Richtlinie (EU) 2022/2555": SOURCE_TYPE_LAW, + "Data Act": SOURCE_TYPE_LAW, + "Data Governance Act (DGA)": SOURCE_TYPE_LAW, + "Markets in Crypto-Assets (MiCA)": SOURCE_TYPE_LAW, + "Maschinenverordnung (EU) 2023/1230": SOURCE_TYPE_LAW, + "Batterieverordnung (EU) 2023/1542": SOURCE_TYPE_LAW, + "AML-Verordnung": SOURCE_TYPE_LAW, + + # --- EU-Richtlinien (nach nationaler Umsetzung bindend) --- + # Fuer Compliance-Zwecke wie Gesetze behandeln + + # --- Nationale Gesetze --- + "Bundesdatenschutzgesetz (BDSG)": SOURCE_TYPE_LAW, + "Telekommunikationsgesetz": SOURCE_TYPE_LAW, + "Telekommunikationsgesetz Oesterreich": SOURCE_TYPE_LAW, + "Gewerbeordnung (GewO)": SOURCE_TYPE_LAW, + "Handelsgesetzbuch (HGB)": SOURCE_TYPE_LAW, + "Abgabenordnung (AO)": SOURCE_TYPE_LAW, + "IFRS-Übernahmeverordnung": SOURCE_TYPE_LAW, + "Österreichisches Datenschutzgesetz (DSG)": SOURCE_TYPE_LAW, + "LOPDGDD - Ley Orgánica de Protección de Datos (Spanien)": SOURCE_TYPE_LAW, + "Loi Informatique et Libertés (Frankreich)": SOURCE_TYPE_LAW, + "Információs önrendelkezési jog törvény 
(Ungarn)": SOURCE_TYPE_LAW, + "EU Blue Guide 2022": SOURCE_TYPE_LAW, + + # --- EDPB/WP29 Leitlinien (offizielle Auslegungshilfe) --- + "EDPB Leitlinien 01/2019 (Zertifizierung)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 01/2020 (Datentransfers)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 01/2020 (Vernetzte Fahrzeuge)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 01/2022 (BCR)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 01/2024 (Berechtigtes Interesse)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 04/2019 (Data Protection by Design)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 05/2020 - Einwilligung": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 07/2020 (Datentransfers)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 08/2020 (Social Media)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 09/2022 (Data Breach)": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien 09/2022 - Meldung von Datenschutzverletzungen": SOURCE_TYPE_GUIDELINE, + "EDPB Empfehlungen 01/2020 - Ergaenzende Massnahmen fuer Datentransfers": SOURCE_TYPE_GUIDELINE, + "EDPB Leitlinien - Berechtigtes Interesse (Art. 
6(1)(f))": SOURCE_TYPE_GUIDELINE, + "WP244 Leitlinien (Profiling)": SOURCE_TYPE_GUIDELINE, + "WP251 Leitlinien (Profiling)": SOURCE_TYPE_GUIDELINE, + "WP260 Leitlinien (Transparenz)": SOURCE_TYPE_GUIDELINE, + + # --- BSI Standards (behoerdliche technische Richtlinien) --- + "BSI-TR-03161-1": SOURCE_TYPE_GUIDELINE, + "BSI-TR-03161-2": SOURCE_TYPE_GUIDELINE, + "BSI-TR-03161-3": SOURCE_TYPE_GUIDELINE, + + # --- ENISA (EU-Agentur, aber Empfehlungen nicht rechtsverbindlich) --- + "ENISA Cybersecurity State 2024": SOURCE_TYPE_FRAMEWORK, + "ENISA ICS/SCADA Dependencies": SOURCE_TYPE_FRAMEWORK, + "ENISA Supply Chain Good Practices": SOURCE_TYPE_FRAMEWORK, + "ENISA Threat Landscape Supply Chain": SOURCE_TYPE_FRAMEWORK, + + # --- NIST (US-Standards, international als Best Practice) --- + "NIST AI Risk Management Framework": SOURCE_TYPE_FRAMEWORK, + "NIST Cybersecurity Framework 2.0": SOURCE_TYPE_FRAMEWORK, + "NIST SP 800-207 (Zero Trust)": SOURCE_TYPE_FRAMEWORK, + "NIST SP 800-218 (SSDF)": SOURCE_TYPE_FRAMEWORK, + "NIST SP 800-53 Rev. 5": SOURCE_TYPE_FRAMEWORK, + "NIST SP 800-63-3": SOURCE_TYPE_FRAMEWORK, + + # --- OWASP (Community-Standards) --- + "OWASP API Security Top 10 (2023)": SOURCE_TYPE_FRAMEWORK, + "OWASP ASVS 4.0": SOURCE_TYPE_FRAMEWORK, + "OWASP MASVS 2.0": SOURCE_TYPE_FRAMEWORK, + "OWASP SAMM 2.0": SOURCE_TYPE_FRAMEWORK, + "OWASP Top 10 (2021)": SOURCE_TYPE_FRAMEWORK, + + # --- Sonstige Frameworks --- + "OECD KI-Empfehlung": SOURCE_TYPE_FRAMEWORK, + "CISA Secure by Design": SOURCE_TYPE_FRAMEWORK, +} + + +def classify_source_regulation(source_regulation: str) -> str: + """ + Klassifiziert eine source_regulation als law, guideline oder framework. + + Verwendet exaktes Matching gegen die Map. Bei unbekannten Quellen + wird anhand von Schluesselwoertern geraten, Fallback ist 'framework' + (konservativstes Ergebnis). 
+ """ + if not source_regulation: + return SOURCE_TYPE_FRAMEWORK + + # Exaktes Match + if source_regulation in SOURCE_REGULATION_CLASSIFICATION: + return SOURCE_REGULATION_CLASSIFICATION[source_regulation] + + # Heuristik fuer unbekannte Quellen + lower = source_regulation.lower() + + # Gesetze erkennen + law_indicators = [ + "verordnung", "richtlinie", "gesetz", "directive", "regulation", + "(eu)", "(eg)", "act", "ley", "loi", "törvény", "código", + ] + if any(ind in lower for ind in law_indicators): + return SOURCE_TYPE_LAW + + # Leitlinien erkennen + guideline_indicators = [ + "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung", + ] + if any(ind in lower for ind in guideline_indicators): + return SOURCE_TYPE_GUIDELINE + + # Frameworks erkennen + framework_indicators = [ + "enisa", "nist", "owasp", "oecd", "cisa", "framework", "iso", + ] + if any(ind in lower for ind in framework_indicators): + return SOURCE_TYPE_FRAMEWORK + + # Konservativ: unbekannt = framework (geringste Verbindlichkeit) + return SOURCE_TYPE_FRAMEWORK diff --git a/control-pipeline/db/__init__.py b/control-pipeline/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/control-pipeline/db/session.py b/control-pipeline/db/session.py new file mode 100644 index 0000000..0004552 --- /dev/null +++ b/control-pipeline/db/session.py @@ -0,0 +1,37 @@ +"""Database session factory for control-pipeline. + +Connects to the shared PostgreSQL with search_path set to compliance schema. 
+""" + +from sqlalchemy import create_engine, event +from sqlalchemy.orm import sessionmaker + +from config import settings + +engine = create_engine( + settings.DATABASE_URL, + pool_pre_ping=True, + pool_size=5, + max_overflow=10, + echo=False, +) + + +@event.listens_for(engine, "connect") +def set_search_path(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute(f"SET search_path TO {settings.SCHEMA_SEARCH_PATH}") + cursor.close() + dbapi_connection.commit() + + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def get_db(): + """FastAPI dependency for DB sessions.""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/control-pipeline/main.py b/control-pipeline/main.py new file mode 100644 index 0000000..117b499 --- /dev/null +++ b/control-pipeline/main.py @@ -0,0 +1,88 @@ +import logging +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from config import settings +from db.session import engine + +logging.basicConfig( + level=getattr(logging, settings.LOG_LEVEL, logging.INFO), + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", +) +logger = logging.getLogger("control-pipeline") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Startup: verify DB and Qdrant connectivity.""" + logger.info("Control-Pipeline starting up ...") + + # Verify database connection + try: + with engine.connect() as conn: + conn.execute(__import__("sqlalchemy").text("SELECT 1")) + logger.info("Database connection OK") + except Exception as exc: + logger.error("Database connection failed: %s", exc) + + yield + + logger.info("Control-Pipeline shutting down ...") + + +app = FastAPI( + title="BreakPilot Control Pipeline", + description="Control generation, decomposition, and deduplication pipeline for the BreakPilot compliance platform.", + version="1.0.0", + lifespan=lifespan, +) 
+ +# CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Routers +from api import router as api_router # noqa: E402 + +app.include_router(api_router) + + +# Health +@app.get("/health") +async def health(): + """Liveness probe.""" + db_ok = False + try: + with engine.connect() as conn: + conn.execute(__import__("sqlalchemy").text("SELECT 1")) + db_ok = True + except Exception: + pass + + status = "healthy" if db_ok else "degraded" + return { + "status": status, + "service": "control-pipeline", + "version": "1.0.0", + "dependencies": { + "postgres": "ok" if db_ok else "unavailable", + }, + } + + +if __name__ == "__main__": + uvicorn.run( + "main:app", + host="0.0.0.0", + port=settings.PORT, + reload=False, + log_level="info", + ) diff --git a/control-pipeline/requirements.txt b/control-pipeline/requirements.txt new file mode 100644 index 0000000..5533c45 --- /dev/null +++ b/control-pipeline/requirements.txt @@ -0,0 +1,22 @@ +# Web Framework +fastapi>=0.123.0 +uvicorn[standard]>=0.27.0 + +# Database +SQLAlchemy>=2.0.36 +psycopg2-binary>=2.9.10 + +# HTTP Client +httpx>=0.28.0 + +# Validation +pydantic>=2.5.0 + +# AI - Anthropic Claude +anthropic>=0.75.0 + +# Vector DB (dedup) +qdrant-client>=1.7.0 + +# Auth +python-jose[cryptography]>=3.3.0 diff --git a/control-pipeline/services/__init__.py b/control-pipeline/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/control-pipeline/services/anchor_finder.py b/control-pipeline/services/anchor_finder.py new file mode 100644 index 0000000..fe3ebde --- /dev/null +++ b/control-pipeline/services/anchor_finder.py @@ -0,0 +1,187 @@ +""" +Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls. 
+ +Two-stage search: + Stage A: RAG-internal search for open-source chunks matching the control topic + Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed) + +Only open-source references (Rule 1+2) are accepted as anchors. +""" + +import logging +from dataclasses import dataclass +from typing import List, Optional + +import httpx + +from .rag_client import ComplianceRAGClient, get_rag_client +from .control_generator import ( + GeneratedControl, + REGULATION_LICENSE_MAP, + _RULE2_PREFIXES, + _RULE3_PREFIXES, + _classify_regulation, +) + +logger = logging.getLogger(__name__) + +# Regulation codes that are safe to reference as open anchors (Rule 1+2) +_OPEN_SOURCE_RULES = {1, 2} + + +@dataclass +class OpenAnchor: + framework: str + ref: str + url: str + + +class AnchorFinder: + """Finds open-source references to anchor generated controls.""" + + def __init__(self, rag_client: Optional[ComplianceRAGClient] = None): + self.rag = rag_client or get_rag_client() + + async def find_anchors( + self, + control: GeneratedControl, + skip_web: bool = False, + min_anchors: int = 2, + ) -> List[OpenAnchor]: + """Find open-source anchors for a control.""" + # Stage A: RAG-internal search + anchors = await self._search_rag_for_open_anchors(control) + + # Stage B: Web search if not enough anchors + if len(anchors) < min_anchors and not skip_web: + web_anchors = await self._search_web(control) + # Deduplicate by framework+ref + existing_keys = {(a.framework, a.ref) for a in anchors} + for wa in web_anchors: + if (wa.framework, wa.ref) not in existing_keys: + anchors.append(wa) + + return anchors + + async def _search_rag_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]: + """Search RAG for chunks from open sources matching the control topic.""" + # Build search query from control title + first 3 tags + tags_str = " ".join(control.tags[:3]) if control.tags else "" + query = f"{control.title} {tags_str}".strip() + + results = await 
self.rag.search_with_rerank( + query=query, + collection="bp_compliance_ce", + top_k=15, + ) + + anchors: List[OpenAnchor] = [] + seen: set[str] = set() + + for r in results: + if not r.regulation_code: + continue + + # Only accept open-source references + license_info = _classify_regulation(r.regulation_code) + if license_info.get("rule") not in _OPEN_SOURCE_RULES: + continue + + # Build reference key for dedup + ref = r.article or r.category or "" + key = f"{r.regulation_code}:{ref}" + if key in seen: + continue + seen.add(key) + + framework_name = license_info.get("name", r.regulation_name or r.regulation_short or r.regulation_code) + url = r.source_url or self._build_reference_url(r.regulation_code, ref) + + anchors.append(OpenAnchor( + framework=framework_name, + ref=ref, + url=url, + )) + + if len(anchors) >= 5: + break + + return anchors + + async def _search_web(self, control: GeneratedControl) -> List[OpenAnchor]: + """Search DuckDuckGo Instant Answer API for open references.""" + keywords = f"{control.title} security control OWASP NIST" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + "https://api.duckduckgo.com/", + params={ + "q": keywords, + "format": "json", + "no_html": "1", + "skip_disambig": "1", + }, + ) + if resp.status_code != 200: + return [] + + data = resp.json() + anchors: List[OpenAnchor] = [] + + # Parse RelatedTopics + for topic in data.get("RelatedTopics", [])[:10]: + url = topic.get("FirstURL", "") + text = topic.get("Text", "") + + if not url: + continue + + # Only accept known open-source domains + framework = self._identify_framework_from_url(url) + if framework: + anchors.append(OpenAnchor( + framework=framework, + ref=text[:100] if text else url, + url=url, + )) + + if len(anchors) >= 3: + break + + return anchors + + except Exception as e: + logger.warning("Web anchor search failed: %s", e) + return [] + + @staticmethod + def _identify_framework_from_url(url: str) -> Optional[str]: + 
"""Identify if a URL belongs to a known open-source framework.""" + url_lower = url.lower() + if "owasp.org" in url_lower: + return "OWASP" + if "nist.gov" in url_lower or "csrc.nist.gov" in url_lower: + return "NIST" + if "enisa.europa.eu" in url_lower: + return "ENISA" + if "cisa.gov" in url_lower: + return "CISA" + if "eur-lex.europa.eu" in url_lower: + return "EU Law" + return None + + @staticmethod + def _build_reference_url(regulation_code: str, ref: str) -> str: + """Build a reference URL for known frameworks.""" + code = regulation_code.lower() + if code.startswith("owasp"): + return "https://owasp.org/www-project-application-security-verification-standard/" + if code.startswith("nist"): + return "https://csrc.nist.gov/publications" + if code.startswith("enisa"): + return "https://www.enisa.europa.eu/publications" + if code.startswith("eu_"): + return "https://eur-lex.europa.eu/" + if code == "cisa_secure_by_design": + return "https://www.cisa.gov/securebydesign" + return "" diff --git a/control-pipeline/services/batch_dedup_runner.py b/control-pipeline/services/batch_dedup_runner.py new file mode 100644 index 0000000..fa7b18b --- /dev/null +++ b/control-pipeline/services/batch_dedup_runner.py @@ -0,0 +1,618 @@ +"""Batch Dedup Runner — Orchestrates deduplication of ~85k atomare Controls. + +Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls via: + Phase 1: Intra-Group Dedup — same merge_group_hint → pick best, link rest + (85k → ~52k, mostly title-identical short-circuit, no embeddings) + Phase 2: Cross-Group Dedup — embed masters, search Qdrant for similar + masters with different hints (52k → ~18-25k) + +All Pass 0b controls have pattern_id=NULL. The primary grouping key is +merge_group_hint (format: "action_type:norm_obj:trigger_key"), which +encodes the normalized action, object, and trigger. 
+ +Usage: + runner = BatchDedupRunner(db) + stats = await runner.run(dry_run=True) # preview + stats = await runner.run(dry_run=False) # execute + stats = await runner.run(hint_filter="implement:multi_factor_auth:none") +""" + +import json +import logging +import time +from collections import defaultdict + +from sqlalchemy import text + +from services.control_dedup import ( + canonicalize_text, + ensure_qdrant_collection, + get_embedding, + normalize_action, + normalize_object, + qdrant_search_cross_regulation, + qdrant_upsert, + LINK_THRESHOLD, + REVIEW_THRESHOLD, +) + +logger = logging.getLogger(__name__) + +DEDUP_COLLECTION = "atomic_controls_dedup" + + +# ── Quality Score ──────────────────────────────────────────────────────── + + +def quality_score(control: dict) -> float: + """Score a control by richness of requirements, tests, evidence, and objective. + + Higher score = better candidate for master control. + """ + score = 0.0 + + reqs = control.get("requirements") or "[]" + if isinstance(reqs, str): + try: + reqs = json.loads(reqs) + except (json.JSONDecodeError, TypeError): + reqs = [] + score += len(reqs) * 2.0 + + tests = control.get("test_procedure") or "[]" + if isinstance(tests, str): + try: + tests = json.loads(tests) + except (json.JSONDecodeError, TypeError): + tests = [] + score += len(tests) * 1.5 + + evidence = control.get("evidence") or "[]" + if isinstance(evidence, str): + try: + evidence = json.loads(evidence) + except (json.JSONDecodeError, TypeError): + evidence = [] + score += len(evidence) * 1.0 + + objective = control.get("objective") or "" + score += min(len(objective) / 200, 3.0) + + return score + + +# ── Batch Dedup Runner ─────────────────────────────────────────────────── + + +class BatchDedupRunner: + """Batch dedup orchestrator for existing Pass 0b atomic controls.""" + + def __init__(self, db, collection: str = DEDUP_COLLECTION): + self.db = db + self.collection = collection + self.stats = { + "total_controls": 0, + 
"unique_hints": 0, + "phase1_groups_processed": 0, + "masters": 0, + "linked": 0, + "review": 0, + "new_controls": 0, + "parent_links_transferred": 0, + "cross_group_linked": 0, + "cross_group_review": 0, + "errors": 0, + "skipped_title_identical": 0, + } + self._progress_phase = "" + self._progress_count = 0 + self._progress_total = 0 + + async def run( + self, + dry_run: bool = False, + hint_filter: str = None, + ) -> dict: + """Run the full batch dedup pipeline. + + Args: + dry_run: If True, compute stats but don't modify DB/Qdrant. + hint_filter: If set, only process groups matching this hint prefix. + + Returns: + Stats dict with counts. + """ + start = time.monotonic() + logger.info("BatchDedup starting (dry_run=%s, hint_filter=%s)", + dry_run, hint_filter) + + if not dry_run: + await ensure_qdrant_collection(collection=self.collection) + + # Phase 1: Intra-group dedup (same merge_group_hint) + self._progress_phase = "phase1" + groups = self._load_merge_groups(hint_filter) + self._progress_total = self.stats["total_controls"] + + for hint, controls in groups: + try: + await self._process_hint_group(hint, controls, dry_run) + self.stats["phase1_groups_processed"] += 1 + except Exception as e: + logger.error("BatchDedup Phase 1 error on hint %s: %s", hint, e) + self.stats["errors"] += 1 + try: + self.db.rollback() + except Exception: + pass + + logger.info( + "BatchDedup Phase 1 done: %d masters, %d linked, %d review", + self.stats["masters"], self.stats["linked"], self.stats["review"], + ) + + # Phase 2: Cross-group dedup via embeddings + if not dry_run: + self._progress_phase = "phase2" + await self._run_cross_group_pass() + + elapsed = time.monotonic() - start + self.stats["elapsed_seconds"] = round(elapsed, 1) + logger.info("BatchDedup completed in %.1fs: %s", elapsed, self.stats) + return self.stats + + def _load_merge_groups(self, hint_filter: str = None) -> list: + """Load all Pass 0b controls grouped by merge_group_hint, largest first.""" + conditions = 
[ + "decomposition_method = 'pass0b'", + "release_state != 'deprecated'", + "release_state != 'duplicate'", + ] + params = {} + + if hint_filter: + conditions.append("generation_metadata->>'merge_group_hint' LIKE :hf") + params["hf"] = f"{hint_filter}%" + + where = " AND ".join(conditions) + rows = self.db.execute(text(f""" + SELECT id::text, control_id, title, objective, + pattern_id, requirements::text, test_procedure::text, + evidence::text, release_state, + generation_metadata->>'merge_group_hint' as merge_group_hint, + generation_metadata->>'action_object_class' as action_object_class + FROM canonical_controls + WHERE {where} + ORDER BY control_id + """), params).fetchall() + + by_hint = defaultdict(list) + for r in rows: + by_hint[r[9] or ""].append({ + "uuid": r[0], + "control_id": r[1], + "title": r[2], + "objective": r[3], + "pattern_id": r[4], + "requirements": r[5], + "test_procedure": r[6], + "evidence": r[7], + "release_state": r[8], + "merge_group_hint": r[9] or "", + "action_object_class": r[10] or "", + }) + + self.stats["total_controls"] = len(rows) + self.stats["unique_hints"] = len(by_hint) + + sorted_groups = sorted(by_hint.items(), key=lambda x: len(x[1]), reverse=True) + logger.info("BatchDedup loaded %d controls in %d hint groups", + len(rows), len(sorted_groups)) + return sorted_groups + + def _sub_group_by_merge_hint(self, controls: list) -> dict: + """Group controls by merge_group_hint composite key.""" + groups = defaultdict(list) + for c in controls: + hint = c["merge_group_hint"] + if hint: + groups[hint].append(c) + else: + groups[f"__no_hint_{c['uuid']}"].append(c) + return dict(groups) + + async def _process_hint_group( + self, + hint: str, + controls: list, + dry_run: bool, + ): + """Process all controls sharing the same merge_group_hint. + + Within a hint group, all controls share action+object+trigger. + The best-quality control becomes master, rest are linked as duplicates. 
+ """ + if len(controls) < 2: + # Singleton → always master + self.stats["masters"] += 1 + if not dry_run: + await self._embed_and_index(controls[0]) + self._progress_count += 1 + self._log_progress(hint) + return + + # Sort by quality score (best first) + sorted_group = sorted(controls, key=quality_score, reverse=True) + master = sorted_group[0] + self.stats["masters"] += 1 + + if not dry_run: + await self._embed_and_index(master) + + for candidate in sorted_group[1:]: + # All share the same hint → check title similarity + if candidate["title"].strip().lower() == master["title"].strip().lower(): + # Identical title → direct link (no embedding needed) + self.stats["linked"] += 1 + self.stats["skipped_title_identical"] += 1 + if not dry_run: + await self._mark_duplicate(master, candidate, confidence=1.0) + else: + # Different title within same hint → still likely duplicate + # Use embedding to verify + await self._check_and_link_within_group(master, candidate, dry_run) + + self._progress_count += 1 + self._log_progress(hint) + + async def _check_and_link_within_group( + self, + master: dict, + candidate: dict, + dry_run: bool, + ): + """Check if candidate (same hint group) is duplicate of master via embedding.""" + parts = candidate["merge_group_hint"].split(":", 2) + action = parts[0] if len(parts) > 0 else "" + obj = parts[1] if len(parts) > 1 else "" + + canonical = canonicalize_text(action, obj, candidate["title"]) + embedding = await get_embedding(canonical) + + if not embedding: + # Can't embed → link anyway (same hint = same action+object) + self.stats["linked"] += 1 + if not dry_run: + await self._mark_duplicate(master, candidate, confidence=0.90) + return + + # Search the dedup collection (unfiltered — pattern_id is NULL) + results = await qdrant_search_cross_regulation( + embedding, top_k=3, collection=self.collection, + ) + + if not results: + # No Qdrant matches yet (master might not be indexed yet) → link to master + self.stats["linked"] += 1 + if not 
dry_run: + await self._mark_duplicate(master, candidate, confidence=0.90) + return + + best = results[0] + best_score = best.get("score", 0.0) + best_payload = best.get("payload", {}) + best_uuid = best_payload.get("control_uuid", "") + + if best_score > LINK_THRESHOLD: + self.stats["linked"] += 1 + if not dry_run: + await self._mark_duplicate_to(best_uuid, candidate, confidence=best_score) + elif best_score > REVIEW_THRESHOLD: + self.stats["review"] += 1 + if not dry_run: + self._write_review(candidate, best_payload, best_score) + else: + # Very different despite same hint → new master + self.stats["new_controls"] += 1 + if not dry_run: + await self._index_with_embedding(candidate, embedding) + + async def _run_cross_group_pass(self): + """Phase 2: Find cross-group duplicates among surviving masters. + + After Phase 1, ~52k masters remain. Many have similar semantics + despite different merge_group_hints (e.g. different German spellings). + This pass embeds all masters and finds near-duplicates via Qdrant. 
+ """ + logger.info("BatchDedup Phase 2: Cross-group pass starting...") + + rows = self.db.execute(text(""" + SELECT id::text, control_id, title, + generation_metadata->>'merge_group_hint' as merge_group_hint + FROM canonical_controls + WHERE decomposition_method = 'pass0b' + AND release_state != 'duplicate' + AND release_state != 'deprecated' + ORDER BY control_id + """)).fetchall() + + self._progress_total = len(rows) + self._progress_count = 0 + logger.info("BatchDedup Cross-group: %d masters to check", len(rows)) + cross_linked = 0 + cross_review = 0 + + for i, r in enumerate(rows): + uuid = r[0] + hint = r[3] or "" + parts = hint.split(":", 2) + action = parts[0] if len(parts) > 0 else "" + obj = parts[1] if len(parts) > 1 else "" + + canonical = canonicalize_text(action, obj, r[2]) + embedding = await get_embedding(canonical) + if not embedding: + continue + + results = await qdrant_search_cross_regulation( + embedding, top_k=5, collection=self.collection, + ) + if not results: + continue + + # Find best match from a DIFFERENT hint group + for match in results: + match_score = match.get("score", 0.0) + match_payload = match.get("payload", {}) + match_uuid = match_payload.get("control_uuid", "") + + # Skip self-match + if match_uuid == uuid: + continue + + # Must be a different hint group (otherwise already handled in Phase 1) + match_action = match_payload.get("action_normalized", "") + match_object = match_payload.get("object_normalized", "") + # Simple check: different control UUID is enough + if match_score > LINK_THRESHOLD: + # Mark the worse one as duplicate + try: + self.db.execute(text(""" + UPDATE canonical_controls + SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) + WHERE id = CAST(:dup AS uuid) + AND release_state != 'duplicate' + """), {"master": match_uuid, "dup": uuid}) + + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence) + VALUES (CAST(:cu AS uuid), 
CAST(:pu AS uuid), 'cross_regulation', :conf) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), {"cu": match_uuid, "pu": uuid, "conf": match_score}) + + # Transfer parent links + transferred = self._transfer_parent_links(match_uuid, uuid) + self.stats["parent_links_transferred"] += transferred + + self.db.commit() + cross_linked += 1 + except Exception as e: + logger.error("BatchDedup cross-group link error %s→%s: %s", + uuid, match_uuid, e) + self.db.rollback() + self.stats["errors"] += 1 + break # Only one cross-link per control + elif match_score > REVIEW_THRESHOLD: + self._write_review( + {"control_id": r[1], "title": r[2], "objective": "", + "merge_group_hint": hint, "pattern_id": None}, + match_payload, match_score, + ) + cross_review += 1 + break + + self._progress_count = i + 1 + if (i + 1) % 500 == 0: + logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review", + i + 1, len(rows), cross_linked, cross_review) + + self.stats["cross_group_linked"] = cross_linked + self.stats["cross_group_review"] = cross_review + logger.info("BatchDedup Cross-group complete: %d linked, %d review", + cross_linked, cross_review) + + # ── Qdrant Helpers ─────────────────────────────────────────────────── + + async def _embed_and_index(self, control: dict): + """Compute embedding and index a control in the dedup Qdrant collection.""" + parts = control["merge_group_hint"].split(":", 2) + action = parts[0] if len(parts) > 0 else "" + obj = parts[1] if len(parts) > 1 else "" + + norm_action = normalize_action(action) + norm_object = normalize_object(obj) + canonical = canonicalize_text(action, obj, control["title"]) + embedding = await get_embedding(canonical) + + if not embedding: + return + + await qdrant_upsert( + point_id=control["uuid"], + embedding=embedding, + payload={ + "control_uuid": control["uuid"], + "control_id": control["control_id"], + "title": control["title"], + "pattern_id": control.get("pattern_id"), + "action_normalized": 
norm_action, + "object_normalized": norm_object, + "canonical_text": canonical, + "merge_group_hint": control["merge_group_hint"], + }, + collection=self.collection, + ) + + async def _index_with_embedding(self, control: dict, embedding: list): + """Index a control with a pre-computed embedding.""" + parts = control["merge_group_hint"].split(":", 2) + action = parts[0] if len(parts) > 0 else "" + obj = parts[1] if len(parts) > 1 else "" + + norm_action = normalize_action(action) + norm_object = normalize_object(obj) + canonical = canonicalize_text(action, obj, control["title"]) + + await qdrant_upsert( + point_id=control["uuid"], + embedding=embedding, + payload={ + "control_uuid": control["uuid"], + "control_id": control["control_id"], + "title": control["title"], + "pattern_id": control.get("pattern_id"), + "action_normalized": norm_action, + "object_normalized": norm_object, + "canonical_text": canonical, + "merge_group_hint": control["merge_group_hint"], + }, + collection=self.collection, + ) + + # ── DB Write Helpers ───────────────────────────────────────────────── + + async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float): + """Mark candidate as duplicate of master, transfer parent links.""" + try: + self.db.execute(text(""" + UPDATE canonical_controls + SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) + WHERE id = CAST(:cand AS uuid) + """), {"master": master["uuid"], "cand": candidate["uuid"]}) + + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence) + VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), {"master": master["uuid"], "cand_parent": candidate["uuid"], "conf": confidence}) + + transferred = self._transfer_parent_links(master["uuid"], candidate["uuid"]) + self.stats["parent_links_transferred"] += transferred + + self.db.commit() + 
except Exception as e: + logger.error("BatchDedup _mark_duplicate error %s→%s: %s", + candidate["uuid"], master["uuid"], e) + self.db.rollback() + raise + + async def _mark_duplicate_to(self, master_uuid: str, candidate: dict, confidence: float): + """Mark candidate as duplicate of a Qdrant-matched master.""" + try: + self.db.execute(text(""" + UPDATE canonical_controls + SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid) + WHERE id = CAST(:cand AS uuid) + """), {"master": master_uuid, "cand": candidate["uuid"]}) + + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence) + VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), {"master": master_uuid, "cand_parent": candidate["uuid"], "conf": confidence}) + + transferred = self._transfer_parent_links(master_uuid, candidate["uuid"]) + self.stats["parent_links_transferred"] += transferred + + self.db.commit() + except Exception as e: + logger.error("BatchDedup _mark_duplicate_to error %s→%s: %s", + candidate["uuid"], master_uuid, e) + self.db.rollback() + raise + + def _transfer_parent_links(self, master_uuid: str, duplicate_uuid: str) -> int: + """Move existing parent links from duplicate to master.""" + rows = self.db.execute(text(""" + SELECT parent_control_uuid::text, link_type, confidence, + source_regulation, source_article, obligation_candidate_id::text + FROM control_parent_links + WHERE control_uuid = CAST(:dup AS uuid) + AND link_type = 'decomposition' + """), {"dup": duplicate_uuid}).fetchall() + + transferred = 0 + for r in rows: + parent_uuid = r[0] + if parent_uuid == master_uuid: + continue + self.db.execute(text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence, + source_regulation, source_article, obligation_candidate_id) + VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 
:lt, :conf, + :sr, :sa, CAST(:oci AS uuid)) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), { + "cu": master_uuid, + "pu": parent_uuid, + "lt": r[1], + "conf": float(r[2]) if r[2] else 1.0, + "sr": r[3], + "sa": r[4], + "oci": r[5], + }) + transferred += 1 + + return transferred + + def _write_review(self, candidate: dict, matched_payload: dict, score: float): + """Write a dedup review entry for borderline matches.""" + try: + self.db.execute(text(""" + INSERT INTO control_dedup_reviews + (candidate_control_id, candidate_title, candidate_objective, + matched_control_uuid, matched_control_id, + similarity_score, dedup_stage, dedup_details) + VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci, + :ss, 'batch_dedup', CAST(:dd AS jsonb)) + """), { + "ccid": candidate["control_id"], + "ct": candidate["title"], + "co": candidate.get("objective", ""), + "mcu": matched_payload.get("control_uuid"), + "mci": matched_payload.get("control_id"), + "ss": score, + "dd": json.dumps({ + "merge_group_hint": candidate.get("merge_group_hint", ""), + "pattern_id": candidate.get("pattern_id"), + }), + }) + self.db.commit() + except Exception as e: + logger.error("BatchDedup _write_review error: %s", e) + self.db.rollback() + raise + + # ── Progress ───────────────────────────────────────────────────────── + + def _log_progress(self, hint: str): + """Log progress every 500 controls.""" + if self._progress_count > 0 and self._progress_count % 500 == 0: + logger.info( + "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d", + self._progress_phase, self._progress_count, self._progress_total, + self.stats["masters"], self.stats["linked"], self.stats["review"], + ) + + def get_status(self) -> dict: + """Return current progress stats (for status endpoint).""" + return { + "phase": self._progress_phase, + "progress": self._progress_count, + "total": self._progress_total, + **self.stats, + } diff --git a/control-pipeline/services/citation_backfill.py 
b/control-pipeline/services/citation_backfill.py new file mode 100644 index 0000000..9222445 --- /dev/null +++ b/control-pipeline/services/citation_backfill.py @@ -0,0 +1,438 @@ +""" +Citation Backfill Service — enrich existing controls with article/paragraph provenance. + +3-tier matching strategy: + Tier 1 — Hash match: sha256(source_original_text) → RAG chunk lookup + Tier 2 — Regex parse: split concatenated "DSGVO Art. 35" → regulation + article + Tier 3 — Ollama LLM: ask local LLM to identify article/paragraph from text +""" + +import hashlib +import json +import logging +import os +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional + +import httpx +from sqlalchemy import text +from sqlalchemy.orm import Session + +from .rag_client import ComplianceRAGClient, RAGSearchResult + +logger = logging.getLogger(__name__) + +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") +OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b") +LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180")) + +ALL_COLLECTIONS = [ + "bp_compliance_ce", + "bp_compliance_gesetze", + "bp_compliance_datenschutz", + "bp_dsfa_corpus", + "bp_legal_templates", +] + +BACKFILL_SYSTEM_PROMPT = ( + "Du bist ein Rechtsexperte. Deine Aufgabe ist es, aus einem Gesetzestext " + "den genauen Artikel und Absatz zu bestimmen. Antworte NUR mit validem JSON." +) + +# Regex to split concatenated source like "DSGVO Art. 35" or "NIS2 Artikel 21 Abs. 
2" +_SOURCE_ARTICLE_RE = re.compile( + r"^(.+?)\s+(Art(?:ikel)?\.?\s*\d+.*)$", re.IGNORECASE +) + + +@dataclass +class MatchResult: + article: str + paragraph: str + method: str # "hash", "regex", "llm" + + +@dataclass +class BackfillResult: + total_controls: int = 0 + matched_hash: int = 0 + matched_regex: int = 0 + matched_llm: int = 0 + unmatched: int = 0 + updated: int = 0 + errors: list = field(default_factory=list) + + +class CitationBackfill: + """Backfill article/paragraph into existing control source_citations.""" + + def __init__(self, db: Session, rag_client: ComplianceRAGClient): + self.db = db + self.rag = rag_client + self._rag_index: dict[str, RAGSearchResult] = {} + + async def run(self, dry_run: bool = True, limit: int = 0) -> BackfillResult: + """Main entry: iterate controls missing article/paragraph, match to RAG, update.""" + result = BackfillResult() + + # Load controls needing backfill + controls = self._load_controls_needing_backfill(limit) + result.total_controls = len(controls) + logger.info("Backfill: %d controls need article/paragraph enrichment", len(controls)) + + if not controls: + return result + + # Collect hashes we need to find — only build index for controls with source text + needed_hashes: set[str] = set() + for ctrl in controls: + src = ctrl.get("source_original_text") + if src: + needed_hashes.add(hashlib.sha256(src.encode()).hexdigest()) + + if needed_hashes: + # Build targeted RAG index — only scroll collections that our controls reference + logger.info("Building targeted RAG hash index for %d source texts...", len(needed_hashes)) + await self._build_rag_index_targeted(controls) + logger.info("RAG index built: %d chunks indexed, %d hashes needed", len(self._rag_index), len(needed_hashes)) + else: + logger.info("No source_original_text found — skipping RAG index build") + + # Process each control + for i, ctrl in enumerate(controls): + if i > 0 and i % 100 == 0: + logger.info("Backfill progress: %d/%d processed", i, 
result.total_controls) + + try: + match = await self._match_control(ctrl) + if match: + if match.method == "hash": + result.matched_hash += 1 + elif match.method == "regex": + result.matched_regex += 1 + elif match.method == "llm": + result.matched_llm += 1 + + if not dry_run: + self._update_control(ctrl, match) + result.updated += 1 + else: + logger.debug( + "DRY RUN: Would update %s with article=%s paragraph=%s (method=%s)", + ctrl["control_id"], match.article, match.paragraph, match.method, + ) + else: + result.unmatched += 1 + + except Exception as e: + error_msg = f"Error backfilling {ctrl.get('control_id', '?')}: {e}" + logger.error(error_msg) + result.errors.append(error_msg) + + if not dry_run: + try: + self.db.commit() + except Exception as e: + logger.error("Backfill commit failed: %s", e) + result.errors.append(f"Commit failed: {e}") + + logger.info( + "Backfill complete: %d total, hash=%d regex=%d llm=%d unmatched=%d updated=%d", + result.total_controls, result.matched_hash, result.matched_regex, + result.matched_llm, result.unmatched, result.updated, + ) + return result + + def _load_controls_needing_backfill(self, limit: int = 0) -> list[dict]: + """Load controls where source_citation exists but lacks separate 'article' key.""" + query = """ + SELECT id, control_id, source_citation, source_original_text, + generation_metadata, license_rule + FROM canonical_controls + WHERE license_rule IN (1, 2) + AND source_citation IS NOT NULL + AND ( + source_citation->>'article' IS NULL + OR source_citation->>'article' = '' + ) + ORDER BY control_id + """ + if limit > 0: + query += f" LIMIT {limit}" + + result = self.db.execute(text(query)) + cols = result.keys() + controls = [] + for row in result: + ctrl = dict(zip(cols, row)) + ctrl["id"] = str(ctrl["id"]) + # Parse JSON fields + for jf in ("source_citation", "generation_metadata"): + if isinstance(ctrl.get(jf), str): + try: + ctrl[jf] = json.loads(ctrl[jf]) + except (json.JSONDecodeError, TypeError): + 
ctrl[jf] = {} + controls.append(ctrl) + return controls + + async def _build_rag_index_targeted(self, controls: list[dict]): + """Build RAG index by scrolling only collections relevant to our controls. + + Uses regulation codes from generation_metadata to identify which collections + to search, falling back to all collections only if needed. + """ + # Determine which collections are relevant based on regulation codes + regulation_to_collection = self._map_regulations_to_collections(controls) + collections_to_search = set(regulation_to_collection.values()) or set(ALL_COLLECTIONS) + + logger.info("Targeted index: searching %d collections: %s", + len(collections_to_search), ", ".join(collections_to_search)) + + for collection in collections_to_search: + offset = None + page = 0 + seen_offsets: set[str] = set() + while True: + chunks, next_offset = await self.rag.scroll( + collection=collection, offset=offset, limit=200, + ) + if not chunks: + break + for chunk in chunks: + if chunk.text and len(chunk.text.strip()) >= 50: + h = hashlib.sha256(chunk.text.encode()).hexdigest() + self._rag_index[h] = chunk + page += 1 + if page % 50 == 0: + logger.info("Indexing %s: page %d (%d chunks so far)", + collection, page, len(self._rag_index)) + if not next_offset: + break + if next_offset in seen_offsets: + logger.warning("Scroll loop in %s at page %d — stopping", collection, page) + break + seen_offsets.add(next_offset) + offset = next_offset + + logger.info("Indexed collection %s: %d pages", collection, page) + + def _map_regulations_to_collections(self, controls: list[dict]) -> dict[str, str]: + """Map regulation codes from controls to likely Qdrant collections.""" + # Heuristic: regulation code prefix → collection + collection_map = { + "eu_": "bp_compliance_gesetze", + "dsgvo": "bp_compliance_datenschutz", + "bdsg": "bp_compliance_gesetze", + "ttdsg": "bp_compliance_gesetze", + "nist_": "bp_compliance_ce", + "owasp": "bp_compliance_ce", + "bsi_": "bp_compliance_ce", + 
"enisa": "bp_compliance_ce", + "at_": "bp_compliance_recht", + "fr_": "bp_compliance_recht", + "es_": "bp_compliance_recht", + } + result: dict[str, str] = {} + for ctrl in controls: + meta = ctrl.get("generation_metadata") or {} + reg = meta.get("source_regulation", "") + if not reg: + continue + for prefix, coll in collection_map.items(): + if reg.startswith(prefix): + result[reg] = coll + break + else: + # Unknown regulation — search all + for coll in ALL_COLLECTIONS: + result[f"_all_{coll}"] = coll + return result + + async def _match_control(self, ctrl: dict) -> Optional[MatchResult]: + """3-tier matching: hash → regex → LLM.""" + + # Tier 1: Hash match against RAG index + source_text = ctrl.get("source_original_text") + if source_text: + h = hashlib.sha256(source_text.encode()).hexdigest() + chunk = self._rag_index.get(h) + if chunk and (chunk.article or chunk.paragraph): + return MatchResult( + article=chunk.article or "", + paragraph=chunk.paragraph or "", + method="hash", + ) + + # Tier 2: Regex parse concatenated source + citation = ctrl.get("source_citation") or {} + source_str = citation.get("source", "") + parsed = _parse_concatenated_source(source_str) + if parsed and parsed["article"]: + return MatchResult( + article=parsed["article"], + paragraph="", # Regex can't extract paragraph from concatenated format + method="regex", + ) + + # Tier 3: Ollama LLM + if source_text: + return await self._llm_match(ctrl) + + return None + + async def _llm_match(self, ctrl: dict) -> Optional[MatchResult]: + """Use Ollama to identify article/paragraph from source text.""" + citation = ctrl.get("source_citation") or {} + regulation_name = citation.get("source", "") + metadata = ctrl.get("generation_metadata") or {} + regulation_code = metadata.get("source_regulation", "") + source_text = ctrl.get("source_original_text", "") + + prompt = f"""Analysiere den folgenden Gesetzestext und bestimme den genauen Artikel und Absatz. 
+ +Gesetz: {regulation_name} (Code: {regulation_code}) + +Text: +--- +{source_text[:2000]} +--- + +Antworte NUR mit JSON: +{{"article": "Art. XX", "paragraph": "Abs. Y"}} + +Falls kein spezifischer Absatz erkennbar ist, setze paragraph auf "". +Falls kein Artikel erkennbar ist, setze article auf "". +Bei deutschen Gesetzen mit § verwende: "§ XX" statt "Art. XX".""" + + try: + raw = await _llm_ollama(prompt, BACKFILL_SYSTEM_PROMPT) + data = _parse_json(raw) + if data and (data.get("article") or data.get("paragraph")): + return MatchResult( + article=data.get("article", ""), + paragraph=data.get("paragraph", ""), + method="llm", + ) + except Exception as e: + logger.warning("LLM match failed for %s: %s", ctrl.get("control_id"), e) + + return None + + def _update_control(self, ctrl: dict, match: MatchResult): + """Update source_citation and generation_metadata in DB.""" + citation = ctrl.get("source_citation") or {} + + # Clean the source name: remove concatenated article if present + source_str = citation.get("source", "") + parsed = _parse_concatenated_source(source_str) + if parsed: + citation["source"] = parsed["name"] + + # Add separate article/paragraph fields + citation["article"] = match.article + citation["paragraph"] = match.paragraph + + # Update generation_metadata + metadata = ctrl.get("generation_metadata") or {} + if match.article: + metadata["source_article"] = match.article + metadata["source_paragraph"] = match.paragraph + metadata["backfill_method"] = match.method + metadata["backfill_at"] = datetime.now(timezone.utc).isoformat() + + self.db.execute( + text(""" + UPDATE canonical_controls + SET source_citation = :citation, + generation_metadata = :metadata, + updated_at = NOW() + WHERE id = CAST(:id AS uuid) + """), + { + "id": ctrl["id"], + "citation": json.dumps(citation), + "metadata": json.dumps(metadata), + }, + ) + + +def _parse_concatenated_source(source: str) -> Optional[dict]: + """Parse 'DSGVO Art. 35' → {name: 'DSGVO', article: 'Art. 
35'}. + + Also handles '§' format: 'BDSG § 42' → {name: 'BDSG', article: '§ 42'}. + """ + if not source: + return None + + # Try Art./Artikel pattern + m = _SOURCE_ARTICLE_RE.match(source) + if m: + return {"name": m.group(1).strip(), "article": m.group(2).strip()} + + # Try § pattern + m2 = re.match(r"^(.+?)\s+(§\s*\d+.*)$", source) + if m2: + return {"name": m2.group(1).strip(), "article": m2.group(2).strip()} + + return None + + +async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str: + """Call Ollama chat API for backfill matching.""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + payload = { + "model": OLLAMA_MODEL, + "messages": messages, + "stream": False, + "format": "json", + "options": {"num_predict": 256}, + "think": False, + } + + try: + async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: + resp = await client.post(f"{OLLAMA_URL}/api/chat", json=payload) + if resp.status_code != 200: + logger.error("Ollama backfill failed %d: %s", resp.status_code, resp.text[:300]) + return "" + data = resp.json() + msg = data.get("message", {}) + if isinstance(msg, dict): + return msg.get("content", "") + return data.get("response", str(msg)) + except Exception as e: + logger.error("Ollama backfill request failed: %s", e) + return "" + + +def _parse_json(raw: str) -> Optional[dict]: + """Extract JSON object from LLM output.""" + if not raw: + return None + # Try direct parse + try: + return json.loads(raw) + except json.JSONDecodeError: + pass + # Try extracting from markdown code block + m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL) + if m: + try: + return json.loads(m.group(1)) + except json.JSONDecodeError: + pass + # Try finding first { ... 
} + m = re.search(r"\{[^{}]*\}", raw) + if m: + try: + return json.loads(m.group(0)) + except json.JSONDecodeError: + pass + return None diff --git a/control-pipeline/services/control_composer.py b/control-pipeline/services/control_composer.py new file mode 100644 index 0000000..064528c --- /dev/null +++ b/control-pipeline/services/control_composer.py @@ -0,0 +1,546 @@ +"""Control Composer — Pattern + Obligation → Master Control. + +Takes an obligation (from ObligationExtractor) and a matched control pattern +(from PatternMatcher), then uses LLM to compose a structured, actionable +Master Control. Replaces the old Stage 3 (STRUCTURE/REFORM) with a +pattern-guided approach. + +Three composition modes based on license rules: + Rule 1: Obligation + Pattern + original text → full control + Rule 2: Obligation + Pattern + original text + citation → control + Rule 3: Obligation + Pattern (NO original text) → reformulated control + +Fallback: No pattern match → basic generation (tagged needs_pattern_assignment) + +Part of the Multi-Layer Control Architecture (Phase 6 of 8). 
+""" + +import json +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +from services.obligation_extractor import ( + ObligationMatch, + _llm_ollama, + _parse_json, +) +from services.pattern_matcher import ( + ControlPattern, + PatternMatchResult, +) + +logger = logging.getLogger(__name__) + +OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b") + +# Valid values for generated control fields +VALID_SEVERITIES = {"low", "medium", "high", "critical"} +VALID_EFFORTS = {"s", "m", "l", "xl"} +VALID_VERIFICATION = {"code_review", "document", "tool", "hybrid"} + + +@dataclass +class ComposedControl: + """A Master Control composed from an obligation + pattern.""" + + # Core fields (match canonical_controls schema) + control_id: str = "" + title: str = "" + objective: str = "" + rationale: str = "" + scope: dict = field(default_factory=dict) + requirements: list = field(default_factory=list) + test_procedure: list = field(default_factory=list) + evidence: list = field(default_factory=list) + severity: str = "medium" + risk_score: float = 5.0 + implementation_effort: str = "m" + open_anchors: list = field(default_factory=list) + release_state: str = "draft" + tags: list = field(default_factory=list) + # 3-Rule License fields + license_rule: Optional[int] = None + source_original_text: Optional[str] = None + source_citation: Optional[dict] = None + customer_visible: bool = True + # Classification + verification_method: Optional[str] = None + category: Optional[str] = None + target_audience: Optional[list] = None + # Pattern + Obligation linkage + pattern_id: Optional[str] = None + obligation_ids: list = field(default_factory=list) + # Metadata + generation_metadata: dict = field(default_factory=dict) + composition_method: str = "pattern_guided" # pattern_guided | fallback + + def to_dict(self) -> dict: + """Serialize for DB storage or API response.""" + return { + "control_id": self.control_id, + "title": 
self.title, + "objective": self.objective, + "rationale": self.rationale, + "scope": self.scope, + "requirements": self.requirements, + "test_procedure": self.test_procedure, + "evidence": self.evidence, + "severity": self.severity, + "risk_score": self.risk_score, + "implementation_effort": self.implementation_effort, + "open_anchors": self.open_anchors, + "release_state": self.release_state, + "tags": self.tags, + "license_rule": self.license_rule, + "source_original_text": self.source_original_text, + "source_citation": self.source_citation, + "customer_visible": self.customer_visible, + "verification_method": self.verification_method, + "category": self.category, + "target_audience": self.target_audience, + "pattern_id": self.pattern_id, + "obligation_ids": self.obligation_ids, + "generation_metadata": self.generation_metadata, + "composition_method": self.composition_method, + } + + +class ControlComposer: + """Composes Master Controls from obligations + patterns. + + Usage:: + + composer = ControlComposer() + + control = await composer.compose( + obligation=obligation_match, + pattern_result=pattern_match_result, + chunk_text="...", + license_rule=1, + source_citation={...}, + ) + """ + + async def compose( + self, + obligation: ObligationMatch, + pattern_result: PatternMatchResult, + chunk_text: Optional[str] = None, + license_rule: int = 3, + source_citation: Optional[dict] = None, + regulation_code: Optional[str] = None, + ) -> ComposedControl: + """Compose a Master Control from obligation + pattern. + + Args: + obligation: The extracted obligation (from ObligationExtractor). + pattern_result: The matched pattern (from PatternMatcher). + chunk_text: Original RAG chunk text (only used for Rules 1-2). + license_rule: 1=free, 2=citation, 3=restricted. + source_citation: Citation metadata for Rule 2. + regulation_code: Source regulation code. + + Returns: + ComposedControl ready for storage. 
+ """ + pattern = pattern_result.pattern if pattern_result else None + + if pattern: + control = await self._compose_with_pattern( + obligation, pattern, chunk_text, license_rule, source_citation, + ) + else: + control = await self._compose_fallback( + obligation, chunk_text, license_rule, source_citation, + ) + + # Set linkage fields + control.pattern_id = pattern.id if pattern else None + if obligation.obligation_id: + control.obligation_ids = [obligation.obligation_id] + + # Set license fields + control.license_rule = license_rule + if license_rule in (1, 2) and chunk_text: + control.source_original_text = chunk_text + if license_rule == 2 and source_citation: + control.source_citation = source_citation + if license_rule == 3: + control.customer_visible = False + control.source_original_text = None + control.source_citation = None + + # Build metadata + control.generation_metadata = { + "composition_method": control.composition_method, + "pattern_id": control.pattern_id, + "pattern_confidence": round(pattern_result.confidence, 3) if pattern_result else 0, + "pattern_method": pattern_result.method if pattern_result else "none", + "obligation_id": obligation.obligation_id, + "obligation_method": obligation.method, + "obligation_confidence": round(obligation.confidence, 3), + "license_rule": license_rule, + "regulation_code": regulation_code, + } + + # Validate and fix fields + _validate_control(control) + + return control + + async def compose_batch( + self, + items: list[dict], + ) -> list[ComposedControl]: + """Compose multiple controls. + + Args: + items: List of dicts with keys: obligation, pattern_result, + chunk_text, license_rule, source_citation, regulation_code. + + Returns: + List of ComposedControl instances. 
+ """ + results = [] + for item in items: + control = await self.compose( + obligation=item["obligation"], + pattern_result=item.get("pattern_result", PatternMatchResult()), + chunk_text=item.get("chunk_text"), + license_rule=item.get("license_rule", 3), + source_citation=item.get("source_citation"), + regulation_code=item.get("regulation_code"), + ) + results.append(control) + return results + + # ----------------------------------------------------------------------- + # Pattern-guided composition + # ----------------------------------------------------------------------- + + async def _compose_with_pattern( + self, + obligation: ObligationMatch, + pattern: ControlPattern, + chunk_text: Optional[str], + license_rule: int, + source_citation: Optional[dict], + ) -> ComposedControl: + """Use LLM to fill the pattern template with obligation-specific details.""" + prompt = _build_compose_prompt(obligation, pattern, chunk_text, license_rule) + system_prompt = _compose_system_prompt(license_rule) + + llm_result = await _llm_ollama(prompt, system_prompt) + if not llm_result: + return self._compose_from_template(obligation, pattern) + + parsed = _parse_json(llm_result) + if not parsed: + return self._compose_from_template(obligation, pattern) + + control = ComposedControl( + title=parsed.get("title", pattern.name_de)[:255], + objective=parsed.get("objective", pattern.objective_template), + rationale=parsed.get("rationale", pattern.rationale_template), + requirements=_ensure_list(parsed.get("requirements", pattern.requirements_template)), + test_procedure=_ensure_list(parsed.get("test_procedure", pattern.test_procedure_template)), + evidence=_ensure_list(parsed.get("evidence", pattern.evidence_template)), + severity=parsed.get("severity", pattern.severity_default), + implementation_effort=parsed.get("implementation_effort", pattern.implementation_effort_default), + category=parsed.get("category", pattern.category), + tags=_ensure_list(parsed.get("tags", pattern.tags)), + 
target_audience=_ensure_list(parsed.get("target_audience", [])), + verification_method=parsed.get("verification_method"), + open_anchors=_anchors_from_pattern(pattern), + composition_method="pattern_guided", + ) + + return control + + def _compose_from_template( + self, + obligation: ObligationMatch, + pattern: ControlPattern, + ) -> ComposedControl: + """Fallback: fill template directly without LLM (when LLM fails).""" + obl_title = obligation.obligation_title or "" + obl_text = obligation.obligation_text or "" + + title = f"{pattern.name_de}" + if obl_title: + title = f"{pattern.name_de} — {obl_title}" + + objective = pattern.objective_template + if obl_text and len(obl_text) > 20: + objective = f"{pattern.objective_template} Bezug: {obl_text[:200]}" + + return ComposedControl( + title=title[:255], + objective=objective, + rationale=pattern.rationale_template, + requirements=list(pattern.requirements_template), + test_procedure=list(pattern.test_procedure_template), + evidence=list(pattern.evidence_template), + severity=pattern.severity_default, + implementation_effort=pattern.implementation_effort_default, + category=pattern.category, + tags=list(pattern.tags), + open_anchors=_anchors_from_pattern(pattern), + composition_method="template_only", + ) + + # ----------------------------------------------------------------------- + # Fallback (no pattern) + # ----------------------------------------------------------------------- + + async def _compose_fallback( + self, + obligation: ObligationMatch, + chunk_text: Optional[str], + license_rule: int, + source_citation: Optional[dict], + ) -> ComposedControl: + """Generate a control without a pattern template (old-style).""" + prompt = _build_fallback_prompt(obligation, chunk_text, license_rule) + system_prompt = _compose_system_prompt(license_rule) + + llm_result = await _llm_ollama(prompt, system_prompt) + parsed = _parse_json(llm_result) if llm_result else {} + + obl_text = obligation.obligation_text or "" + + 
control = ComposedControl( + title=parsed.get("title", obl_text[:100] if obl_text else "Untitled Control")[:255], + objective=parsed.get("objective", obl_text[:500]), + rationale=parsed.get("rationale", "Aus gesetzlicher Pflicht abgeleitet."), + requirements=_ensure_list(parsed.get("requirements", [])), + test_procedure=_ensure_list(parsed.get("test_procedure", [])), + evidence=_ensure_list(parsed.get("evidence", [])), + severity=parsed.get("severity", "medium"), + implementation_effort=parsed.get("implementation_effort", "m"), + category=parsed.get("category"), + tags=_ensure_list(parsed.get("tags", [])), + target_audience=_ensure_list(parsed.get("target_audience", [])), + verification_method=parsed.get("verification_method"), + composition_method="fallback", + release_state="needs_review", + ) + + return control + + +# --------------------------------------------------------------------------- +# Prompt builders +# --------------------------------------------------------------------------- + + +def _compose_system_prompt(license_rule: int) -> str: + """Build the system prompt based on license rule.""" + if license_rule == 3: + return ( + "Du bist ein Security-Compliance-Experte. Deine Aufgabe ist es, " + "eigenstaendige Security Controls zu formulieren. " + "Du formulierst IMMER in eigenen Worten. " + "KOPIERE KEINE Saetze aus dem Quelltext. " + "Verwende eigene Begriffe und Struktur. " + "NENNE NICHT die Quelle. Keine proprietaeren Bezeichner. " + "Antworte NUR mit validem JSON." + ) + return ( + "Du bist ein Security-Compliance-Experte. " + "Erstelle ein praxisorientiertes, umsetzbares Security Control. " + "Antworte NUR mit validem JSON." 
+ ) + + +def _build_compose_prompt( + obligation: ObligationMatch, + pattern: ControlPattern, + chunk_text: Optional[str], + license_rule: int, +) -> str: + """Build the LLM prompt for pattern-guided composition.""" + obl_section = _obligation_section(obligation) + pattern_section = _pattern_section(pattern) + + if license_rule == 3: + context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)." + elif chunk_text: + context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}" + else: + context_section = "KONTEXT: Kein Originaltext verfuegbar." + + return f"""Erstelle ein PRAXISORIENTIERTES Security Control. + +{obl_section} + +{pattern_section} + +{context_section} + +AUFGABE: +Fuelle das Muster mit pflicht-spezifischen Details. +Das Ergebnis muss UMSETZBAR sein — keine Gesetzesparaphrase. +Formuliere konkret und handlungsorientiert. + +Antworte als JSON: +{{ + "title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)", + "objective": "Was soll erreicht werden? (1-3 Saetze)", + "rationale": "Warum ist das wichtig? (1-2 Saetze)", + "requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...], + "test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...], + "evidence": ["Nachweis 1", "Nachweis 2", ...], + "severity": "low|medium|high|critical", + "implementation_effort": "s|m|l|xl", + "category": "{pattern.category}", + "tags": ["tag1", "tag2"], + "target_audience": ["unternehmen", "behoerden", "entwickler"], + "verification_method": "code_review|document|tool|hybrid" +}}""" + + +def _build_fallback_prompt( + obligation: ObligationMatch, + chunk_text: Optional[str], + license_rule: int, +) -> str: + """Build the LLM prompt for fallback composition (no pattern).""" + obl_section = _obligation_section(obligation) + + if license_rule == 3: + context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)." 
+ elif chunk_text: + context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}" + else: + context_section = "KONTEXT: Kein Originaltext verfuegbar." + + return f"""Erstelle ein Security Control aus der folgenden Pflicht. + +{obl_section} + +{context_section} + +AUFGABE: +Formuliere ein umsetzbares Security Control. +Keine Gesetzesparaphrase — konkrete Massnahmen beschreiben. + +Antworte als JSON: +{{ + "title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)", + "objective": "Was soll erreicht werden? (1-3 Saetze)", + "rationale": "Warum ist das wichtig? (1-2 Saetze)", + "requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...], + "test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...], + "evidence": ["Nachweis 1", "Nachweis 2", ...], + "severity": "low|medium|high|critical", + "implementation_effort": "s|m|l|xl", + "category": "one of: authentication, encryption, data_protection, etc.", + "tags": ["tag1", "tag2"], + "target_audience": ["unternehmen"], + "verification_method": "code_review|document|tool|hybrid" +}}""" + + +def _obligation_section(obligation: ObligationMatch) -> str: + """Format the obligation for the prompt.""" + parts = ["PFLICHT (was das Gesetz verlangt):"] + if obligation.obligation_title: + parts.append(f" Titel: {obligation.obligation_title}") + if obligation.obligation_text: + parts.append(f" Beschreibung: {obligation.obligation_text[:500]}") + if obligation.obligation_id: + parts.append(f" ID: {obligation.obligation_id}") + if obligation.regulation_id: + parts.append(f" Rechtsgrundlage: {obligation.regulation_id}") + if not obligation.obligation_text and not obligation.obligation_title: + parts.append(" (Keine spezifische Pflicht extrahiert)") + return "\n".join(parts) + + +def _pattern_section(pattern: ControlPattern) -> str: + """Format the pattern for the prompt.""" + reqs = "\n ".join(f"- {r}" for r in pattern.requirements_template[:5]) + tests = "\n ".join(f"- {t}" for t in pattern.test_procedure_template[:3]) + 
return f"""MUSTER (wie man es typischerweise umsetzt): + Pattern: {pattern.name_de} ({pattern.id}) + Domain: {pattern.domain} + Ziel-Template: {pattern.objective_template} + Anforderungs-Template: + {reqs} + Pruefverfahren-Template: + {tests}""" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ensure_list(value) -> list: + """Ensure a value is a list of strings.""" + if isinstance(value, list): + return [str(v) for v in value if v] + if isinstance(value, str): + return [value] + return [] + + +def _anchors_from_pattern(pattern: ControlPattern) -> list: + """Convert pattern's open_anchor_refs to control anchor format.""" + anchors = [] + for ref in pattern.open_anchor_refs: + anchors.append({ + "framework": ref.get("framework", ""), + "control_id": ref.get("ref", ""), + "title": "", + "alignment_score": 0.8, + }) + return anchors + + +def _validate_control(control: ComposedControl) -> None: + """Validate and fix control field values.""" + # Severity + if control.severity not in VALID_SEVERITIES: + control.severity = "medium" + + # Implementation effort + if control.implementation_effort not in VALID_EFFORTS: + control.implementation_effort = "m" + + # Verification method + if control.verification_method and control.verification_method not in VALID_VERIFICATION: + control.verification_method = None + + # Risk score + if not (0 <= control.risk_score <= 10): + control.risk_score = _severity_to_risk(control.severity) + + # Title length + if len(control.title) > 255: + control.title = control.title[:252] + "..." + + # Ensure minimum content + if not control.objective: + control.objective = control.title + if not control.rationale: + control.rationale = "Aus regulatorischer Anforderung abgeleitet." 
+ if not control.requirements: + control.requirements = ["Anforderung gemaess Pflichtbeschreibung umsetzen"] + if not control.test_procedure: + control.test_procedure = ["Umsetzung der Anforderungen pruefen"] + if not control.evidence: + control.evidence = ["Dokumentation der Umsetzung"] + + +def _severity_to_risk(severity: str) -> float: + """Map severity to a default risk score.""" + return { + "critical": 9.0, + "high": 7.0, + "medium": 5.0, + "low": 3.0, + }.get(severity, 5.0) diff --git a/control-pipeline/services/control_dedup.py b/control-pipeline/services/control_dedup.py new file mode 100644 index 0000000..26a26f2 --- /dev/null +++ b/control-pipeline/services/control_dedup.py @@ -0,0 +1,745 @@ +"""Control Deduplication Engine — 4-Stage Matching Pipeline. + +Prevents duplicate atomic controls during Pass 0b by checking candidates +against existing controls before insertion. + +Stages: + 1. Pattern-Gate: pattern_id must match (hard gate) + 2. Action-Check: normalized action verb must match (hard gate) + 3. Object-Norm: normalized object must match (soft gate with high threshold) + 4. 
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional, Callable, Awaitable

import httpx

logger = logging.getLogger(__name__)

# ── Configuration ────────────────────────────────────────────────────

# Master switch for the whole dedup pipeline.
DEDUP_ENABLED = os.getenv("DEDUP_ENABLED", "true").lower() == "true"
# Cosine similarity above this → automatic LINK verdict.
LINK_THRESHOLD = float(os.getenv("DEDUP_LINK_THRESHOLD", "0.92"))
# Similarity between this and LINK_THRESHOLD → queue for human REVIEW.
REVIEW_THRESHOLD = float(os.getenv("DEDUP_REVIEW_THRESHOLD", "0.85"))
# Stricter link threshold used when the normalized objects differ.
LINK_THRESHOLD_DIFF_OBJECT = float(os.getenv("DEDUP_LINK_THRESHOLD_DIFF_OBJ", "0.95"))
# Stricter threshold for links across regulation boundaries.
CROSS_REG_LINK_THRESHOLD = float(os.getenv("DEDUP_CROSS_REG_THRESHOLD", "0.95"))
# Qdrant collection holding one point per indexed atomic control.
QDRANT_COLLECTION = os.getenv("DEDUP_QDRANT_COLLECTION", "atomic_controls")
QDRANT_URL = os.getenv("QDRANT_URL", "http://host.docker.internal:6333")
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")


# ── Result Dataclass ─────────────────────────────────────────────────

@dataclass
class DedupResult:
    """Outcome of the dedup check."""
    verdict: str  # "new" | "link" | "review"
    matched_control_uuid: Optional[str] = None  # UUID of the matched existing control
    matched_control_id: Optional[str] = None  # human-readable control id
    matched_title: Optional[str] = None
    stage: str = ""  # which pipeline stage decided the verdict
    similarity_score: float = 0.0  # best cosine similarity seen (0.0 if none)
    link_type: str = "dedup_merge"  # "dedup_merge" | "cross_regulation"
    details: dict = field(default_factory=dict)  # stage-specific diagnostics
# German → canonical English action verbs (English forms pass through).
_ACTION_SYNONYMS: dict[str, str] = {
    "implementieren": "implement",
    "umsetzen": "implement",
    "einrichten": "implement",
    "einführen": "implement",
    "aufbauen": "implement",
    "bereitstellen": "implement",
    "aktivieren": "implement",
    "konfigurieren": "configure",
    "einstellen": "configure",
    "parametrieren": "configure",
    "testen": "test",
    "prüfen": "test",
    "überprüfen": "test",
    "verifizieren": "test",
    "validieren": "test",
    "kontrollieren": "test",
    "auditieren": "audit",
    "dokumentieren": "document",
    "protokollieren": "log",
    "aufzeichnen": "log",
    "loggen": "log",
    "überwachen": "monitor",
    "monitoring": "monitor",
    "beobachten": "monitor",
    "schulen": "train",
    "trainieren": "train",
    "sensibilisieren": "train",
    "löschen": "delete",
    "entfernen": "delete",
    "verschlüsseln": "encrypt",
    "sperren": "block",
    "beschränken": "restrict",
    "einschränken": "restrict",
    "begrenzen": "restrict",
    "autorisieren": "authorize",
    "genehmigen": "authorize",
    "freigeben": "authorize",
    "authentifizieren": "authenticate",
    "identifizieren": "identify",
    "melden": "report",
    "benachrichtigen": "notify",
    "informieren": "notify",
    "aktualisieren": "update",
    "erneuern": "update",
    "sichern": "backup",
    "wiederherstellen": "restore",
    # English passthrough
    "implement": "implement",
    "configure": "configure",
    "test": "test",
    "verify": "test",
    "validate": "test",
    "audit": "audit",
    "document": "document",
    "log": "log",
    "monitor": "monitor",
    "train": "train",
    "delete": "delete",
    "encrypt": "encrypt",
    "restrict": "restrict",
    "authorize": "authorize",
    "authenticate": "authenticate",
    "report": "report",
    "update": "update",
    "backup": "backup",
    "restore": "restore",
}


def normalize_action(action: str) -> str:
    """Normalize an action verb to its canonical English form.

    Lookup order: exact synonym match, match after stripping a German
    inflection suffix, prefix-based fuzzy match against known verbs, then
    identity fallback (the lowercased input is returned unchanged).
    """
    if not action:
        return ""
    action = action.strip().lower()
    # Strip German infinitive/conjugation suffixes for the base-form lookup.
    # BUGFIX: `re` alternation is leftmost-wins, not longest-match — the
    # original order (en|t|st|e|te|tet|end) made "te" and "tet" unreachable
    # because "t"/"e" always matched first. List longest suffixes first.
    action_base = re.sub(r"(tet|end|st|te|en|e|t)$", "", action)
    # Try exact match first, then the stripped base form.
    if action in _ACTION_SYNONYMS:
        return _ACTION_SYNONYMS[action]
    if action_base in _ACTION_SYNONYMS:
        return _ACTION_SYNONYMS[action_base]
    # Fuzzy: accept a prefix relationship with any known verb.
    for verb, canonical in _ACTION_SYNONYMS.items():
        if action.startswith(verb) or verb.startswith(action):
            return canonical
    return action  # unknown verb: pass through unchanged
def normalize_object(obj: str) -> str:
    """Normalize a compliance object to its canonical token.

    Resolution order: exact synonym, longest known phrase contained in the
    text, then a generic fallback that strips articles/prepositions and joins
    the remaining longer tokens with underscores.
    """
    if not obj:
        return ""
    lowered = obj.strip().lower()
    if lowered in _OBJECT_SYNONYMS:
        return _OBJECT_SYNONYMS[lowered]
    # Substring match, longest phrases first (precomputed ordering).
    for phrase in _OBJECT_KEYS_SORTED:
        if phrase in lowered:
            return _OBJECT_SYNONYMS[phrase]
    # Fallback: drop German/English articles and prepositions.
    stopwords = (
        r"\b(der|die|das|den|dem|des|ein|eine|eines|einem|einen"
        r"|für|von|zu|auf|in|an|bei|mit|nach|über|unter|the|a|an"
        r"|for|of|to|on|in|at|by|with)\b"
    )
    remaining = [t for t in re.sub(stopwords, "", lowered).split() if len(t) > 2]
    if remaining:
        return "_".join(remaining[:4])
    return lowered.replace(" ", "_")


# ── Canonicalization ───────────────────────────────────────────────────

def canonicalize_text(action: str, obj: str, title: str = "") -> str:
    """Build a canonical English phrase for embedding comparison.

    Normalizes the action and object, then optionally appends up to five
    keywords from the title (common German filler words removed) so that
    embeddings stay stable across phrasings.
    """
    pieces = [normalize_action(action), normalize_object(obj)]
    if title:
        filler = (
            r"\b(und|oder|für|von|zu|der|die|das|den|dem|des|ein|eine"
            r"|bei|mit|nach|gemäß|gem\.|laut|entsprechend)\b"
        )
        keywords = [w for w in re.sub(filler, "", title.lower()).split() if len(w) > 3][:5]
        if keywords:
            pieces.append("for")
            pieces.extend(keywords)
    return " ".join(pieces)


# ── Embedding Helper ───────────────────────────────────────────────────

async def get_embedding(text: str) -> list[float]:
    """Fetch the embedding vector for one text from the embedding service.

    Returns an empty list on any failure (network, bad JSON, empty response).
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            vectors = resp.json().get("embeddings", [])
            return vectors[0] if vectors else []
    except Exception as e:
        logger.warning("Embedding failed: %s", e)
        return []


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two same-length vectors.

    Returns 0.0 for empty inputs, length mismatch, or a zero-norm vector.
    """
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_product = (sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5)
    if norm_product == 0:
        return 0.0
    return dot / norm_product
async def qdrant_search_cross_regulation(
    embedding: list[float],
    top_k: int = 5,
    collection: Optional[str] = None,
) -> list[dict]:
    """Similarity search across ALL regulations (no pattern_id filter).

    Used for cross-regulation linking (e.g. DSGVO Art. 25 ↔ NIS2 Art. 21).
    Returns an empty list on any transport or HTTP error.
    """
    if not embedding:
        return []
    target = collection or QDRANT_COLLECTION
    request_body: dict = {
        "vector": embedding,
        "limit": top_k,
        "with_payload": True,
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{QDRANT_URL}/collections/{target}/points/search",
                json=request_body,
            )
            if resp.status_code != 200:
                logger.warning("Qdrant cross-reg search failed: %d", resp.status_code)
                return []
            return resp.json().get("result", [])
    except Exception as e:
        logger.warning("Qdrant cross-reg search error: %s", e)
        return []


async def qdrant_upsert(
    point_id: str,
    embedding: list[float],
    payload: dict,
    collection: Optional[str] = None,
) -> bool:
    """Upsert a single point into a Qdrant collection.

    Returns True only when Qdrant answers HTTP 200; False on empty
    embedding, non-200 status, or transport error.
    """
    if not embedding:
        return False
    target = collection or QDRANT_COLLECTION
    point = {"id": point_id, "vector": embedding, "payload": payload}
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.put(
                f"{QDRANT_URL}/collections/{target}/points",
                json={"points": [point]},
            )
            return resp.status_code == 200
    except Exception as e:
        logger.warning("Qdrant upsert error: %s", e)
        return False
class ControlDedupChecker:
    """4-stage dedup checker for atomic controls.

    Stage 1 (pattern gate) and stage 4 (embedding similarity) are decisive;
    the normalized action/object of stages 2/3 are compared against the
    Qdrant payload of the best embedding match, since the DB rows do not
    carry them. Candidates judged "new" get a second, cross-regulation
    search with a stricter threshold.

    Usage:
        checker = ControlDedupChecker(db_session)
        result = await checker.check_duplicate(action, obj, title, pattern_id)
        if result.verdict == "link":
            checker.add_parent_link(result.matched_control_uuid, parent_uuid)
        elif result.verdict == "review":
            checker.write_review(candidate_id, title, objective, result)
        else:
            ...  # insert a new control
    """

    def __init__(
        self,
        db,
        embed_fn: Optional[Callable[[str], Awaitable[list[float]]]] = None,
        search_fn: Optional[Callable] = None,
    ):
        """Args:
            db: SQLAlchemy session (synchronous ``execute``/``commit`` API).
            embed_fn: optional override for the embedding call (testing).
            search_fn: optional override for the Qdrant search call (testing).
        """
        self.db = db
        self._embed = embed_fn or get_embedding
        self._search = search_fn or qdrant_search
        # pattern_id → existing-control rows, memoized per checker instance.
        self._cache: dict[str, list[dict]] = {}

    def _load_existing(self, pattern_id: str) -> list[dict]:
        """Load existing atomic controls sharing *pattern_id* (cached)."""
        if pattern_id in self._cache:
            return self._cache[pattern_id]
        from sqlalchemy import text
        # NOTE(review): `parent_control_uuid IS NOT NULL` restricts the pool
        # to controls that already have a parent link — confirm intended.
        rows = self.db.execute(text("""
            SELECT id::text, control_id, title, objective,
                   pattern_id,
                   generation_metadata->>'obligation_type' as obligation_type
            FROM canonical_controls
            WHERE parent_control_uuid IS NOT NULL
              AND release_state != 'deprecated'
              AND pattern_id = :pid
        """), {"pid": pattern_id}).fetchall()
        result = [
            {
                "uuid": r[0], "control_id": r[1], "title": r[2],
                "objective": r[3], "pattern_id": r[4],
                "obligation_type": r[5],
            }
            for r in rows
        ]
        self._cache[pattern_id] = result
        return result

    async def check_duplicate(
        self,
        action: str,
        obj: str,
        title: str,
        pattern_id: Optional[str],
    ) -> DedupResult:
        """Run the 4-stage dedup pipeline plus cross-regulation linking.

        Returns a DedupResult whose verdict is "new", "link" or "review".
        """
        # Without a pattern_id there is nothing meaningful to gate on.
        if not pattern_id:
            return DedupResult(verdict="new", stage="no_pattern")

        # Stage 1: Pattern gate — are there controls with this pattern at all?
        existing = self._load_existing(pattern_id)
        if not existing:
            return DedupResult(
                verdict="new", stage="pattern_gate",
                details={"reason": "no existing controls with this pattern_id"},
            )

        # Stages 2+3: normalize action/object. The DB rows do not store these,
        # so they are compared against the payload of the best Qdrant match.
        norm_action = normalize_action(action)
        norm_object = normalize_object(obj)

        # Stage 4: embedding similarity on the canonicalized text.
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            # Embedding service unavailable → fail open as "new".
            return DedupResult(
                verdict="new", stage="embedding_unavailable",
                details={"canonical_text": canonical},
            )

        results = await self._search(embedding, pattern_id, top_k=5)
        if not results:
            # No intra-pattern matches → try cross-regulation linking.
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="no_qdrant_matches",
                details={"canonical_text": canonical, "action": norm_action, "object": norm_object},
            ))

        best = results[0]
        best_score = best.get("score", 0.0)
        best_payload = best.get("payload", {})
        best_action = best_payload.get("action_normalized", "")
        best_object = best_payload.get("object_normalized", "")

        # Different action → NEW, even if the embedding score is high.
        if best_action and norm_action and best_action != norm_action:
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="action_mismatch",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_action": norm_action,
                    "existing_action": best_action,
                    "similarity": best_score,
                },
            ))

        # Different object → link only above the stricter threshold.
        if best_object and norm_object and best_object != norm_object:
            if best_score > LINK_THRESHOLD_DIFF_OBJECT:
                return DedupResult(
                    verdict="link", stage="embedding_diff_object",
                    matched_control_uuid=best_payload.get("control_uuid"),
                    matched_control_id=best_payload.get("control_id"),
                    matched_title=best_payload.get("title"),
                    similarity_score=best_score,
                    details={"candidate_object": norm_object, "existing_object": best_object},
                )
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="object_mismatch_below_threshold",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_object": norm_object,
                    "existing_object": best_object,
                    "threshold": LINK_THRESHOLD_DIFF_OBJECT,
                },
            ))

        # Same action and same object → tiered thresholds.
        if best_score > LINK_THRESHOLD:
            return DedupResult(
                verdict="link", stage="embedding_match",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
            )
        if best_score > REVIEW_THRESHOLD:
            return DedupResult(
                verdict="review", stage="embedding_review",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
            )
        return await self._check_cross_regulation(embedding, DedupResult(
            verdict="new", stage="embedding_below_threshold",
            similarity_score=best_score,
            details={"threshold": REVIEW_THRESHOLD},
        ))

    async def _check_cross_regulation(
        self,
        embedding: list[float],
        intra_result: DedupResult,
    ) -> DedupResult:
        """Second pass: cross-regulation linking for controls deemed "new".

        Searches Qdrant WITHOUT the pattern_id filter and only links above
        CROSS_REG_LINK_THRESHOLD to avoid false positives across regulation
        boundaries. Non-"new" verdicts pass through untouched.
        """
        if intra_result.verdict != "new" or not embedding:
            return intra_result

        cross_results = await qdrant_search_cross_regulation(embedding, top_k=5)
        if not cross_results:
            return intra_result

        best = cross_results[0]
        best_score = best.get("score", 0.0)
        if best_score > CROSS_REG_LINK_THRESHOLD:
            best_payload = best.get("payload", {})
            return DedupResult(
                verdict="link",
                stage="cross_regulation",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
                link_type="cross_regulation",
                details={
                    "cross_reg_score": best_score,
                    "cross_reg_threshold": CROSS_REG_LINK_THRESHOLD,
                },
            )

        return intra_result

    def add_parent_link(
        self,
        control_uuid: str,
        parent_control_uuid: str,
        link_type: str = "dedup_merge",
        confidence: float = 0.0,
        source_regulation: Optional[str] = None,
        source_article: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Add a parent link to an existing atomic control (idempotent).

        Commits immediately; a duplicate (control, parent) pair is a no-op
        via ON CONFLICT DO NOTHING.
        """
        from sqlalchemy import text
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence,
                 source_regulation, source_article, obligation_candidate_id)
            VALUES (:cu, :pu, :lt, :conf, :sr, :sa, :oci::uuid)
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {
            "cu": control_uuid,
            "pu": parent_control_uuid,
            "lt": link_type,
            "conf": confidence,
            "sr": source_regulation,
            "sa": source_article,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    def write_review(
        self,
        candidate_control_id: str,
        candidate_title: str,
        candidate_objective: str,
        result: DedupResult,
        parent_control_uuid: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Write a dedup review queue entry and commit."""
        # Local import: the module does not import json at top level.
        # (Replaces the previous `__import__("json")` hack.)
        import json
        from sqlalchemy import text
        self.db.execute(text("""
            INSERT INTO control_dedup_reviews
                (candidate_control_id, candidate_title, candidate_objective,
                 matched_control_uuid, matched_control_id,
                 similarity_score, dedup_stage, dedup_details,
                 parent_control_uuid, obligation_candidate_id)
            VALUES (:ccid, :ct, :co, :mcu::uuid, :mci, :ss, :ds,
                    :dd::jsonb, :pcu::uuid, :oci)
        """), {
            "ccid": candidate_control_id,
            "ct": candidate_title,
            "co": candidate_objective,
            "mcu": result.matched_control_uuid,
            "mci": result.matched_control_id,
            "ss": result.similarity_score,
            "ds": result.stage,
            "dd": json.dumps(result.details),
            "pcu": parent_control_uuid,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    async def index_control(
        self,
        control_uuid: str,
        control_id: str,
        title: str,
        action: str,
        obj: str,
        pattern_id: str,
        collection: Optional[str] = None,
    ) -> bool:
        """Index a new atomic control in Qdrant for future dedup checks.

        Returns False when the embedding could not be computed or the
        upsert failed.
        """
        norm_action = normalize_action(action)
        norm_object = normalize_object(obj)
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            return False
        return await qdrant_upsert(
            point_id=control_uuid,
            embedding=embedding,
            payload={
                "control_uuid": control_uuid,
                "control_id": control_id,
                "title": title,
                "pattern_id": pattern_id,
                "action_normalized": norm_action,
                "object_normalized": norm_object,
                "canonical_text": canonical,
            },
            collection=collection,
        )
"""
Control Generator Pipeline — RAG → Structure/Reform → Harmonize → Anchor → Store.

7-stage pipeline that generates canonical security controls from RAG chunks:
  1. RAG SCAN — Load unprocessed chunks (or new document versions)
  2. LICENSE CLASSIFY — Determine which of 3 license rules applies
  3a. STRUCTURE — Rule 1+2: Structure original text into control format
  3b. LLM REFORM — Rule 3: Fully reformulate (no original text, no source names)
  4. HARMONIZE — Check against existing controls for duplicates
  5. ANCHOR SEARCH — Find open-source references (OWASP, NIST, ENISA)
  6. STORE — Persist to DB with correct visibility flags
  7. MARK PROCESSED — Mark RAG chunks as processed (with version tracking)

Three License Rules:
  Rule 1 (free_use): Laws, Public Domain — original text allowed
  Rule 2 (citation_required): CC-BY, CC-BY-SA — original text with citation
  Rule 3 (restricted): BSI, ISO — full reformulation, no source names
"""

import hashlib
import json
import logging
import os
import re
import uuid
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from typing import Dict, List, Optional, Set

import httpx
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session

from .rag_client import ComplianceRAGClient, RAGSearchResult, get_rag_client
from .similarity_detector import check_similarity, SimilarityReport

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration — all overridable via environment variables.
# ---------------------------------------------------------------------------

SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
QDRANT_URL = os.getenv("QDRANT_URL", "http://host.docker.internal:6333")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
ANTHROPIC_MODEL = os.getenv("CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
# Seconds; applied to both Anthropic and Ollama HTTP calls.
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))

HARMONIZATION_THRESHOLD = 0.85  # Cosine similarity above this = duplicate

# Pipeline version — increment when generation rules change materially.
# v1: Original (local LLM prefilter, old prompt)
# v2: Anthropic decides relevance, null for non-requirement chunks, annexes protected
# v3: Scoped Control Applicability — applicable_industries, applicable_company_size, scope_conditions
PIPELINE_VERSION = 3

# RAG collections scanned when GeneratorConfig.collections is not set.
ALL_COLLECTIONS = [
    "bp_compliance_ce",
    "bp_compliance_gesetze",
    "bp_compliance_datenschutz",
    "bp_dsfa_corpus",
    "bp_legal_templates",
]

# ---------------------------------------------------------------------------
# License Mapping (3-Rule System)
#
# Keyed by regulation_code (lowercase). Each record carries:
#   license     — license identifier stored with generated controls
#   rule        — 1 (free use), 2 (citation required), 3 (restricted)
#   source_type — law | guideline | standard | restricted
#   name        — canonical display name (must match the DB source names)
#   attribution — only on Rule-2 entries; citation string for output
# ---------------------------------------------------------------------------

REGULATION_LICENSE_MAP: dict[str, dict] = {
    # RULE 1: FREE USE — Laws, Public Domain
    # source_type: "law" = binding legislation, "guideline" = authority guidance (soft law),
    # "standard" = voluntary framework/best practice, "restricted" = protected norm
    # EU Regulations — names MUST match canonical DB source names
    "eu_2016_679": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "DSGVO (EU) 2016/679"},
    "eu_2024_1689": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "KI-Verordnung (EU) 2024/1689"},
    "eu_2022_2555": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "NIS2-Richtlinie (EU) 2022/2555"},
    "eu_2024_2847": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Cyber Resilience Act (CRA)"},
    "eu_2023_1230": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Maschinenverordnung (EU) 2023/1230"},
    "eu_2022_2065": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Services Act (DSA)"},
    "eu_2022_1925": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Markets Act (DMA)"},
    "eu_2022_868": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Governance Act (DGA)"},
    "eu_2019_770": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digitale-Inhalte-Richtlinie"},
    "eu_2021_914": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Standardvertragsklauseln (SCC)"},
    "eu_2002_58": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "ePrivacy-Richtlinie"},
    "eu_2000_31": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "E-Commerce-Richtlinie"},
    "eu_2023_1803": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "IFRS-Übernahmeverordnung"},
    "eucsa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU Cybersecurity Act"},
    "dataact": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Data Act"},
    "dora": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Digital Operational Resilience Act"},
    "ehds": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Health Data Space"},
    "gpsr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung"},
    "eu_2023_988": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Allgemeine Produktsicherheitsverordnung (GPSR)"},
    "eu_2023_1542": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Batterieverordnung (EU) 2023/1542"},
    "mica": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Markets in Crypto-Assets (MiCA)"},
    "psd2": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "Zahlungsdiensterichtlinie 2"},
    "dpf": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "EU-US Data Privacy Framework"},
    "dsm": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "DSM-Urheberrechtsrichtlinie"},
    "amlr": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "AML-Verordnung"},
    "eu_blue_guide_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EU Blue Guide 2022"},
    # NIST (Public Domain — NOT laws, voluntary standards)
    "nist_sp_800_53": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-53 Rev. 5"},
    "nist_sp800_53r5": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-53 Rev. 5"},
    "nist_sp_800_63b": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-63-3"},
    "nist_sp800_63_3": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-63-3"},
    "nist_csf_2_0": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST Cybersecurity Framework 2.0"},
    "nist_sp_800_218": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-218 (SSDF)"},
    "nist_sp800_218": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-218 (SSDF)"},
    "nist_sp800_207": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST SP 800-207 (Zero Trust)"},
    "nist_ai_rmf": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST AI Risk Management Framework"},
    "nist_privacy_1_0": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NIST Privacy Framework 1.0"},
    "nistir_8259a": {"license": "NIST_PUBLIC_DOMAIN", "rule": 1, "source_type": "standard", "name": "NISTIR 8259A IoT Security"},
    "cisa_secure_by_design": {"license": "US_GOV_PUBLIC", "rule": 1, "source_type": "standard", "name": "CISA Secure by Design"},
    # German Laws
    "bdsg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Bundesdatenschutzgesetz (BDSG)"},
    "bdsg_2018_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Bundesdatenschutzgesetz (BDSG)"},
    "ttdsg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TTDSG"},
    "tdddg_25": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TDDDG"},
    "tkg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TKG"},
    "de_tkg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TKG"},
    "bgb_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "BGB"},
    "hgb": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Handelsgesetzbuch (HGB)"},
    "hgb_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Handelsgesetzbuch (HGB)"},
    "urhg_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "UrhG"},
    "uwg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "UWG"},
    "tmg_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "TMG"},
    "gewo": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Gewerbeordnung (GewO)"},
    "ao": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
    "ao_komplett": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Abgabenordnung (AO)"},
    "battdg": {"license": "DE_LAW", "rule": 1, "source_type": "law", "name": "Batteriegesetz"},
    # Austrian Laws
    "at_dsg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Österreichisches Datenschutzgesetz (DSG)"},
    "at_abgb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB"},
    "at_abgb_agb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT ABGB AGB-Recht"},
    "at_bao": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT BAO"},
    "at_bao_ret": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT BAO Retention"},
    "at_ecg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT E-Commerce-Gesetz"},
    "at_kschg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT Konsumentenschutzgesetz"},
    "at_medieng": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT Mediengesetz"},
    "at_tkg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "Telekommunikationsgesetz Oesterreich"},
    "at_ugb": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT UGB"},
    "at_ugb_ret": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT UGB Retention"},
    "at_uwg": {"license": "AT_LAW", "rule": 1, "source_type": "law", "name": "AT UWG"},
    # Other EU Member State Laws
    "fr_loi_informatique": {"license": "FR_LAW", "rule": 1, "source_type": "law", "name": "FR Loi Informatique"},
    "es_lopdgdd": {"license": "ES_LAW", "rule": 1, "source_type": "law", "name": "ES LOPDGDD"},
    "nl_uavg": {"license": "NL_LAW", "rule": 1, "source_type": "law", "name": "NL UAVG"},
    "it_codice_privacy": {"license": "IT_LAW", "rule": 1, "source_type": "law", "name": "IT Codice Privacy"},
    "hu_info_tv": {"license": "HU_LAW", "rule": 1, "source_type": "law", "name": "HU Információs törvény"},
    # EDPB Guidelines (EU Public Authority — soft law, not binding legislation)
    "edpb_01_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB 01/2020 Ergaenzende Massnahmen"},
    "edpb_02_2023": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB 02/2023 Technischer Anwendungsbereich"},
    "edpb_05_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB 05/2020 Einwilligung"},
    "edpb_09_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB 09/2022 Datenschutzverletzungen"},
    "edpb_bcr_01_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB BCR Leitlinien"},
    "edpb_breach_09_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Breach Notification"},
    "edpb_connected_vehicles_01_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Connected Vehicles"},
    "edpb_dpbd_04_2019": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Data Protection by Design"},
    "edpb_eprivacy_02_2023": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB ePrivacy"},
    "edpb_facial_recognition_05_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Facial Recognition"},
    "edpb_fines_04_2022": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Fines Calculation"},
    "edpb_legitimate_interest": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Legitimate Interest"},
    "edpb_legitimate_interest_01_2024": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Legitimate Interest 2024"},
    "edpb_social_media_08_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Social Media"},
    "edpb_transfers_01_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Transfers 01/2020"},
    "edpb_transfers_07_2020": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Transfers 07/2020"},
    "edpb_video_03_2019": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Video Surveillance"},
    "edps_dpia_list": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPS DPIA Liste"},
    "edpb_certification_01_2018": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Certification 01/2018"},
    "edpb_certification_01_2019": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "EDPB Certification 01/2019"},
    "eaa": {"license": "EU_LAW", "rule": 1, "source_type": "law", "name": "European Accessibility Act"},
    # WP29 (pre-EDPB) Guidelines — soft law
    "wp244_profiling": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "WP29 Profiling"},
    # NOTE(review): key says "profiling" but name says "Data Portability"
    # (that guideline is WP242) — confirm which document this code refers to.
    "wp251_profiling": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "WP29 Data Portability"},
    "wp260_transparency": {"license": "EU_PUBLIC", "rule": 1, "source_type": "guideline", "name": "WP29 Transparency"},

    # RULE 2: CITATION REQUIRED — CC-BY, CC-BY-SA (voluntary standards)
    "owasp_asvs": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP ASVS 4.0",
                   "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_masvs": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP MASVS 2.0",
                    "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_top10": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP Top 10 (2021)",
                    "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_top10_2021": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP Top 10 (2021)",
                         "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_api_top10_2023": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP API Security Top 10 (2023)",
                             "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_samm": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP SAMM 2.0",
                   "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "owasp_mobile_top10": {"license": "CC-BY-SA-4.0", "rule": 2, "source_type": "standard", "name": "OWASP Mobile Top 10",
                           "attribution": "OWASP Foundation, CC BY-SA 4.0"},
    "oecd_ai_principles": {"license": "OECD_PUBLIC", "rule": 2, "source_type": "standard", "name": "OECD KI-Empfehlung",
                           "attribution": "OECD"},

    # RULE 3: RESTRICTED — Full reformulation required
    # Names stored as INTERNAL_ONLY — never exposed to customers
}

# Prefix-based matching for wildcard entries (see _classify_regulation).
_RULE3_PREFIXES = ["bsi_", "iso_", "etsi_"]
_RULE2_PREFIXES = ["enisa_"]
def _classify_regulation(regulation_code: str) -> dict:
    """Map a regulation_code onto its license-rule record.

    Lookup order: exact key in REGULATION_LICENSE_MAP, then Rule-2 prefixes
    (ENISA), then Rule-3 prefixes (BSI/ISO/ETSI). Anything unrecognized is
    treated as restricted, the safe default.

    Returns a dict with keys: license, rule, name, source_type
    (law | guideline | standard | restricted).
    """
    normalized = regulation_code.lower().strip()

    exact = REGULATION_LICENSE_MAP.get(normalized)
    if exact is not None:
        return exact

    # Rule 2 prefixes (ENISA = standard, citation required)
    if any(normalized.startswith(p) for p in _RULE2_PREFIXES):
        return {"license": "CC-BY-4.0", "rule": 2, "source_type": "standard",
                "name": "ENISA", "attribution": "ENISA, CC BY 4.0"}

    # Rule 3 prefixes (BSI/ISO/ETSI = restricted)
    restricted_prefix = next(
        (p for p in _RULE3_PREFIXES if normalized.startswith(p)), None
    )
    if restricted_prefix is not None:
        return {"license": f"{restricted_prefix.rstrip('_').upper()}_RESTRICTED", "rule": 3,
                "source_type": "restricted", "name": "INTERNAL_ONLY"}

    # Unknown → treat as restricted (safe default)
    logger.warning("Unknown regulation_code %r — defaulting to Rule 3 (restricted)", normalized)
    return {"license": "UNKNOWN", "rule": 3, "source_type": "restricted", "name": "INTERNAL_ONLY"}
# ---------------------------------------------------------------------------
# Domain detection from content
#
# Keyword tables used by the _detect_* helpers: a simple substring-hit count
# per key decides the winner. Keywords are deliberately bilingual (EN + DE)
# because source texts mix both languages.
# ---------------------------------------------------------------------------

DOMAIN_KEYWORDS = {
    "AUTH": ["authentication", "login", "password", "credential", "mfa", "2fa",
             "session", "token", "oauth", "identity", "authentifizierung", "anmeldung"],
    "CRYP": ["encryption", "cryptography", "tls", "ssl", "certificate", "hashing",
             "aes", "rsa", "verschlüsselung", "kryptographie", "cipher", "schlüssel"],
    "NET": ["network", "firewall", "dns", "vpn", "proxy", "segmentation",
            "netzwerk", "routing", "port", "intrusion"],
    "DATA": ["data protection", "privacy", "personal data", "datenschutz",
             "personenbezogen", "dsgvo", "gdpr", "löschung", "verarbeitung"],
    "LOG": ["logging", "monitoring", "audit trail", "siem", "alert", "anomaly",
            "protokollierung", "überwachung"],
    "ACC": ["access control", "authorization", "rbac", "permission", "privilege",
            "zugriffskontrolle", "berechtigung", "autorisierung"],
    "SEC": ["vulnerability", "patch", "update", "hardening", "configuration",
            "schwachstelle", "härtung", "konfiguration"],
    "INC": ["incident", "response", "breach", "recovery", "backup",
            "vorfall", "wiederherstellung", "notfall"],
    "AI": ["artificial intelligence", "machine learning", "model", "bias",
           "ki", "künstliche intelligenz", "algorithmus", "training"],
    "COMP": ["compliance", "audit", "regulation", "standard", "certification",
             "konformität", "prüfung", "zertifizierung"],
    "GOV": ["behörde", "verwaltung", "öffentlich", "register", "gewerberegister",
            "handelsregister", "meldepflicht", "aufsicht", "genehmigung", "bescheid",
            "verwaltungsakt", "ordnungswidrig", "bußgeld", "staat", "ministerium",
            "bundesamt", "landesamt", "kommune", "gebietskörperschaft"],
    "LAB": ["arbeitnehmer", "arbeitgeber", "arbeitsschutz", "arbeitszeit", "betriebsrat",
            "kündigung", "beschäftigung", "mindestlohn", "arbeitsvertrag", "betriebsverfassung",
            "arbeitsrecht", "arbeitsstätte", "gefährdungsbeurteilung", "unterweisung"],
    "FIN": ["finanz", "bankwesen", "zahlungsverkehr", "geldwäsche", "bilanz", "rechnungslegung",
            "buchführung", "jahresabschluss", "steuererklärung", "kapitalmarkt", "wertpapier",
            "kreditinstitut", "finanzdienstleistung", "bankenaufsicht", "bafin"],
    "TRD": ["handelsrecht", "gewerbeordnung", "gewerbe", "handwerk", "gewerbeuntersagung",
            "gewerbebetrieb", "handelsgesetzbuch", "handelsregister", "kaufmann",
            "unternehmer", "wettbewerb", "verbraucherschutz", "produktsicherheit"],
    "ENV": ["umweltschutz", "emission", "abfall", "immission", "gewässerschutz",
            "naturschutz", "umweltverträglichkeit", "klimaschutz", "nachhaltigkeit",
            "entsorgung", "recycling", "umweltrecht"],
    "HLT": ["gesundheit", "medizinprodukt", "arzneimittel", "patient", "krankenhaus",
            "hygiene", "infektionsschutz", "medizin", "pflege", "therapie"],
}


# Control categories; VALID_CATEGORIES below is derived from these keys.
CATEGORY_KEYWORDS = {
    "encryption": ["encryption", "cryptography", "tls", "ssl", "certificate", "hashing",
                   "aes", "rsa", "verschlüsselung", "kryptographie", "cipher", "schlüssel"],
    "authentication": ["authentication", "login", "password", "credential", "mfa", "2fa",
                       "session", "oauth", "authentifizierung", "anmeldung", "passwort"],
    "network": ["network", "firewall", "dns", "vpn", "proxy", "segmentation",
                "netzwerk", "routing", "port", "intrusion", "ids", "ips"],
    "data_protection": ["data protection", "privacy", "personal data", "datenschutz",
                        "personenbezogen", "dsgvo", "gdpr", "löschung", "verarbeitung", "einwilligung"],
    "logging": ["logging", "monitoring", "audit trail", "siem", "alert", "anomaly",
                "protokollierung", "überwachung", "nachvollziehbar"],
    "incident": ["incident", "response", "breach", "recovery", "vorfall", "sicherheitsvorfall"],
    "continuity": ["backup", "disaster recovery", "notfall", "wiederherstellung", "notfallplan",
                   "business continuity", "ausfallsicherheit"],
    "compliance": ["compliance", "audit", "regulation", "certification", "konformität",
                   "prüfung", "zertifizierung", "nachweis"],
    "supply_chain": ["supplier", "vendor", "third party", "lieferant", "auftragnehmer",
                     "unterauftragnehmer", "supply chain", "dienstleister"],
    "physical": ["physical", "building", "access zone", "physisch", "gebäude", "zutritt",
                 "schließsystem", "rechenzentrum"],
    "personnel": ["training", "awareness", "employee", "schulung", "mitarbeiter",
                  "sensibilisierung", "personal", "unterweisung"],
    "application": ["application", "software", "code review", "sdlc", "secure coding",
                    "anwendung", "entwicklung", "software-entwicklung", "api"],
    "system": ["hardening", "patch", "configuration", "update", "härtung", "konfiguration",
               "betriebssystem", "system", "server"],
    "risk": ["risk assessment", "risk management", "risiko", "bewertung", "risikobewertung",
             "risikoanalyse", "bedrohung", "threat"],
    "governance": ["governance", "policy", "organization", "isms", "sicherheitsorganisation",
                   "richtlinie", "verantwortlichkeit", "rolle"],
    "hardware": ["hardware", "platform", "firmware", "bios", "tpm", "chip",
                 "plattform", "geräte"],
    "identity": ["identity", "iam", "directory", "ldap", "sso", "provisioning",
                 "identität", "identitätsmanagement", "benutzerverzeichnis"],
    "public_administration": ["behörde", "verwaltung", "öffentlich", "register", "gewerberegister",
                              "handelsregister", "meldepflicht", "aufsicht", "genehmigung", "bescheid",
                              "verwaltungsakt", "ordnungswidrig", "bußgeld", "amt"],
    "labor_law": ["arbeitnehmer", "arbeitgeber", "arbeitsschutz", "arbeitszeit", "betriebsrat",
                  "kündigung", "beschäftigung", "mindestlohn", "arbeitsvertrag", "betriebsverfassung"],
    "finance": ["finanz", "bankwesen", "zahlungsverkehr", "geldwäsche", "bilanz", "rechnungslegung",
                "buchführung", "jahresabschluss", "kapitalmarkt", "wertpapier", "bafin"],
    "trade_regulation": ["gewerbeordnung", "gewerbe", "handwerk", "gewerbeuntersagung",
                         "gewerbebetrieb", "handelsrecht", "kaufmann", "wettbewerb",
                         "verbraucherschutz", "produktsicherheit"],
    "environmental": ["umweltschutz", "emission", "abfall", "immission", "gewässerschutz",
                      "naturschutz", "klimaschutz", "nachhaltigkeit", "entsorgung"],
    "health": ["gesundheit", "medizinprodukt", "arzneimittel", "patient", "krankenhaus",
               "hygiene", "infektionsschutz", "pflege"],
}

# Derived closed sets used for validating LLM output.
VALID_CATEGORIES = set(CATEGORY_KEYWORDS.keys())
VALID_DOMAINS = {"AUTH", "CRYP", "NET", "DATA", "LOG", "ACC", "SEC", "INC",
                 "AI", "COMP", "GOV", "LAB", "FIN", "TRD", "ENV", "HLT"}
erwägung nachstehender gründe", + "erwägungsgrund", + "in anbetracht", + "daher sollte", + "aus diesem grund", + "es ist daher", + "folglich sollte", + "es sollte daher", + "in diesem zusammenhang", +] + + +def _detect_recital(text: str) -> Optional[dict]: + """Detect if source text is a recital (Erwägungsgrund) rather than an article. + + Returns a dict with detection details if recital markers are found, + or None if the text appears to be genuine article text. + + Detection criteria: + 1. Standalone recital numbers like (126)\\n in the text + 2. Recital-typical phrasing ("daher sollte", "erwägungsgrund", etc.) + """ + if not text: + return None + + # Check 1: Recital number markers + recital_matches = _RECITAL_RE.findall(text) + + # Check 2: Recital phrasing + text_lower = text.lower() + phrase_hits = [p for p in _RECITAL_PHRASES if p in text_lower] + + if not recital_matches and not phrase_hits: + return None + + # Require at least recital numbers OR >=2 phrase hits to be a suspect + if not recital_matches and len(phrase_hits) < 2: + return None + + return { + "recital_suspect": True, + "recital_numbers": recital_matches[:10], + "recital_phrases": phrase_hits[:5], + "detection_method": "regex+phrases" if recital_matches and phrase_hits + else "regex" if recital_matches else "phrases", + } + +CATEGORY_LIST_STR = ", ".join(sorted(VALID_CATEGORIES)) + +VERIFICATION_KEYWORDS = { + "code_review": ["source code", "code review", "static analysis", "sast", "dast", + "dependency check", "quellcode", "codeanalyse", "secure coding", + "software development", "api", "input validation", "output encoding"], + "document": ["policy", "procedure", "documentation", "training", "awareness", + "richtlinie", "dokumentation", "schulung", "nachweis", "vertrag", + "organizational", "process", "role", "responsibility"], + "tool": ["scanner", "monitoring", "siem", "ids", "ips", "firewall", "antivirus", + "vulnerability scan", "penetration test", "tool", "automated"], + "hybrid": [], # 
# Comma-joined category names, injected into LLM prompts.
CATEGORY_LIST_STR = ", ".join(sorted(VALID_CATEGORIES))

VERIFICATION_KEYWORDS = {
    "code_review": ["source code", "code review", "static analysis", "sast", "dast",
                    "dependency check", "quellcode", "codeanalyse", "secure coding",
                    "software development", "api", "input validation", "output encoding"],
    "document": ["policy", "procedure", "documentation", "training", "awareness",
                 "richtlinie", "dokumentation", "schulung", "nachweis", "vertrag",
                 "organizational", "process", "role", "responsibility"],
    "tool": ["scanner", "monitoring", "siem", "ids", "ips", "firewall", "antivirus",
             "vulnerability scan", "penetration test", "tool", "automated"],
    "hybrid": [],  # chosen when two methods score roughly equally (see below)
}


def _detect_category(text: str) -> Optional[str]:
    """Return the CATEGORY_KEYWORDS key with the most keyword hits, or None."""
    lowered = text.lower()
    hit_counts = {
        category: sum(1 for kw in kws if kw in lowered)
        for category, kws in CATEGORY_KEYWORDS.items()
    }
    if not hit_counts or not any(hit_counts.values()):
        return None
    return max(hit_counts, key=hit_counts.get)


def _detect_verification_method(text: str) -> Optional[str]:
    """Classify how a control is best verified: code_review / document / tool.

    Returns "hybrid" when the runner-up scores within 70% of the winner,
    and None when no keyword matches at all.
    """
    lowered = text.lower()
    hit_counts = {
        method: sum(1 for kw in kws if kw in lowered)
        for method, kws in VERIFICATION_KEYWORDS.items()
        if method != "hybrid"  # "hybrid" is derived, never scored directly
    }
    if not hit_counts or not any(hit_counts.values()):
        return None
    ranked = sorted(hit_counts.items(), key=lambda item: -item[1])
    best_method, best_hits = ranked[0]
    if len(ranked) >= 2:
        runner_up_hits = ranked[1][1]
        if best_hits > 0 and runner_up_hits > 0 and runner_up_hits >= best_hits * 0.7:
            return "hybrid"
    return best_method if best_hits > 0 else None


def _detect_domain(text: str) -> str:
    """Return the DOMAIN_KEYWORDS key with the most hits; "SEC" when none match."""
    lowered = text.lower()
    hit_counts = {
        domain: sum(1 for kw in kws if kw in lowered)
        for domain, kws in DOMAIN_KEYWORDS.items()
    }
    if not hit_counts or not any(hit_counts.values()):
        return "SEC"  # Default
    return max(hit_counts, key=hit_counts.get)
class GeneratorConfig(BaseModel):
    """Runtime options for one control-generator run."""
    # RAG collections to scan; None = ALL_COLLECTIONS
    collections: Optional[List[str]] = None
    # Optional domain restriction for generated controls
    domain: Optional[str] = None
    batch_size: int = 5
    max_controls: int = 0  # 0 = unlimited (process ALL chunks)
    max_chunks: int = 0  # 0 = unlimited; >0 = stop after N chunks (respects document boundaries)
    skip_processed: bool = True
    skip_web_search: bool = False
    dry_run: bool = False
    existing_job_id: Optional[str] = None  # If set, reuse this job instead of creating a new one
    regulation_filter: Optional[List[str]] = None  # Only process chunks matching these regulation_code prefixes
    skip_prefilter: bool = False  # If True, skip local LLM pre-filter (send all chunks to API)
@dataclass
class GeneratedControl:
    """One generated control plus all license/classification metadata."""
    control_id: str = ""
    title: str = ""
    objective: str = ""
    rationale: str = ""
    scope: dict = field(default_factory=dict)
    requirements: list = field(default_factory=list)
    test_procedure: list = field(default_factory=list)
    evidence: list = field(default_factory=list)
    severity: str = "medium"
    risk_score: float = 5.0
    implementation_effort: str = "m"
    open_anchors: list = field(default_factory=list)
    release_state: str = "draft"
    tags: list = field(default_factory=list)
    # 3-rule fields
    license_rule: Optional[int] = None
    source_original_text: Optional[str] = None
    source_citation: Optional[dict] = None
    customer_visible: bool = True
    generation_metadata: dict = field(default_factory=dict)
    generation_strategy: str = "ungrouped"  # ungrouped | document_grouped
    # Classification fields
    verification_method: Optional[str] = None  # code_review, document, tool, hybrid
    # NOTE(review): original comment said "one of 22 categories" but
    # CATEGORY_KEYWORDS defines 23 keys — confirm the intended count.
    category: Optional[str] = None
    target_audience: Optional[list] = None  # e.g. ["unternehmen", "behoerden", "entwickler"]
    # Scoped Control Applicability (v3)
    applicable_industries: Optional[list] = None  # e.g. ["all"] or ["Telekommunikation", "Energie"]
    applicable_company_size: Optional[list] = None  # e.g. ["all"] or ["medium", "large", "enterprise"]
    scope_conditions: Optional[dict] = None  # e.g. {"requires_any": ["uses_ai"], "description": "..."}
    # Anti-Fake-Evidence: truth tracking for generated controls
    truth_status: str = "generated"
    may_be_used_as_evidence: bool = False


@dataclass
class GeneratorResult:
    """Aggregated counters and outputs of one generator run."""
    job_id: str = ""
    status: str = "completed"
    total_chunks_scanned: int = 0
    controls_generated: int = 0
    controls_verified: int = 0
    controls_needs_review: int = 0
    controls_too_close: int = 0
    controls_duplicates_found: int = 0
    controls_qa_fixed: int = 0
    chunks_skipped_prefilter: int = 0
    errors: list = field(default_factory=list)
    controls: list = field(default_factory=list)


# ---------------------------------------------------------------------------
# LLM Client (via Go SDK)
# ---------------------------------------------------------------------------

async def _llm_chat(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call LLM — Anthropic Claude (primary) or Ollama (fallback).

    Anthropic is only attempted when ANTHROPIC_API_KEY is set; an empty
    Anthropic response triggers the Ollama fallback.
    """
    if ANTHROPIC_API_KEY:
        logger.info("Calling Anthropic API (model=%s)...", ANTHROPIC_MODEL)
        result = await _llm_anthropic(prompt, system_prompt)
        if result:
            logger.info("Anthropic API success (%d chars)", len(result))
            return result
        logger.warning("Anthropic failed, falling back to Ollama")

    logger.info("Calling Ollama (model=%s)...", OLLAMA_MODEL)
    return await _llm_ollama(prompt, system_prompt)


async def _llm_local(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call local Ollama LLM only (for pre-filtering and classification tasks)."""
    return await _llm_ollama(prompt, system_prompt)
PREFILTER_SYSTEM_PROMPT = """Du bist ein Compliance-Analyst. Deine Aufgabe: Prüfe ob ein Textabschnitt eine konkrete Sicherheitsanforderung, Datenschutzpflicht, oder technische/organisatorische Maßnahme enthält.

Antworte NUR mit einem JSON-Objekt: {"relevant": true/false, "reason": "kurze Begründung"}

Relevant = true wenn der Text mindestens EINE der folgenden enthält:
- Konkrete Pflicht/Anforderung ("muss", "soll", "ist sicherzustellen")
- Technische Sicherheitsmaßnahme (Verschlüsselung, Zugriffskontrolle, Logging)
- Organisatorische Maßnahme (Schulung, Dokumentation, Audit)
- Datenschutz-Vorgabe (Löschpflicht, Einwilligung, Zweckbindung)
- Risikomanagement-Anforderung

Relevant = false wenn der Text NUR enthält:
- Definitionen ohne Pflichten
- Inhaltsverzeichnisse oder Verweise
- Reine Begriffsbestimmungen
- Übergangsvorschriften ohne Substanz
- Adressaten/Geltungsbereich ohne Anforderung"""


async def _prefilter_chunk(chunk_text: str) -> tuple[bool, str]:
    """Cheap local-LLM gate: does this chunk contain an actionable requirement?

    Returns (is_relevant, reason). Fails open — on JSON-parse failure or any
    exception the chunk is treated as relevant, so nothing is dropped by
    accident. Much cheaper than sending every chunk to Anthropic.
    """
    question = f"""Prüfe ob dieser Textabschnitt eine konkrete Sicherheitsanforderung oder Compliance-Pflicht enthält.

Text:
---
{chunk_text[:1500]}
---

Antworte NUR mit JSON: {{"relevant": true/false, "reason": "kurze Begründung"}}"""

    try:
        reply = await _llm_local(question, PREFILTER_SYSTEM_PROMPT)
        parsed = _parse_llm_json(reply)
        if parsed:
            return parsed.get("relevant", True), parsed.get("reason", "")
        # Unparseable reply: assume relevant (don't skip).
        return True, "parse_failed"
    except Exception as e:
        logger.warning("Prefilter failed: %s — treating as relevant", e)
        return True, f"error: {e}"


async def _llm_anthropic(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call the Anthropic Messages API; return the reply text, or "" on any failure."""
    request_headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    body = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": 8192,
        "messages": [{"role": "user", "content": prompt}],
    }
    if system_prompt:
        body["system"] = system_prompt

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(
                "https://api.anthropic.com/v1/messages",
                headers=request_headers,
                json=body,
            )
            if resp.status_code != 200:
                logger.error("Anthropic API %d: %s", resp.status_code, resp.text[:300])
                return ""
            blocks = resp.json().get("content", [])
            # Messages API returns a list of content blocks; take the first text.
            if isinstance(blocks, list) and blocks:
                return blocks[0].get("text", "")
            return ""
    except Exception as e:
        logger.error("Anthropic request failed: %s (type: %s)", e, type(e).__name__)
        return ""
+ } + + try: + async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: + resp = await client.post(f"{OLLAMA_URL}/api/chat", json=payload) + if resp.status_code != 200: + logger.error("Ollama chat failed %d: %s", resp.status_code, resp.text[:300]) + return "" + data = resp.json() + msg = data.get("message", {}) + if isinstance(msg, dict): + return msg.get("content", "") + return data.get("response", str(msg)) + except Exception as e: + logger.error("Ollama request failed: %s", e) + return "" + + +async def _get_embedding(text: str) -> list[float]: + """Get embedding vector for text via embedding service.""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{EMBEDDING_URL}/embed", + json={"texts": [text]}, + ) + resp.raise_for_status() + embeddings = resp.json().get("embeddings", []) + return embeddings[0] if embeddings else [] + except Exception: + return [] + + +async def _get_embeddings_batch(texts: list[str], batch_size: int = 32) -> list[list[float]]: + """Get embedding vectors for multiple texts in batches.""" + all_embeddings: list[list[float]] = [] + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{EMBEDDING_URL}/embed", + json={"texts": batch}, + ) + resp.raise_for_status() + embeddings = resp.json().get("embeddings", []) + all_embeddings.extend(embeddings) + except Exception as e: + logger.warning("Batch embedding failed for %d texts: %s", len(batch), e) + all_embeddings.extend([[] for _ in batch]) + return all_embeddings + + +def _cosine_sim(a: list[float], b: list[float]) -> float: + """Compute cosine similarity between two vectors.""" + if not a or not b or len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) 
+ + +# --------------------------------------------------------------------------- +# JSON Parsing Helper +# --------------------------------------------------------------------------- + +def _parse_llm_json(raw: str) -> dict: + """Extract JSON from LLM response (handles markdown fences).""" + # Try extracting from ```json ... ``` blocks + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL) + text = match.group(1) if match else raw + + # Try parsing directly + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Try finding first { ... } block + brace_match = re.search(r"\{.*\}", text, re.DOTALL) + if brace_match: + try: + return json.loads(brace_match.group(0)) + except json.JSONDecodeError: + pass + + logger.warning("Failed to parse LLM JSON response") + return {} + + +def _parse_llm_json_array(raw: str) -> list[dict]: + """Extract a JSON array from LLM response — returns list of dicts.""" + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw, re.DOTALL) + text = match.group(1) if match else raw + + # Try parsing as array directly + try: + parsed = json.loads(text) + if isinstance(parsed, list): + return parsed + if isinstance(parsed, dict): + # Check if it wraps an array (e.g. {"controls": [...]}) + for key in ("controls", "results", "items", "data"): + if key in parsed and isinstance(parsed[key], list): + return parsed[key] + return [parsed] + except json.JSONDecodeError: + pass + + # Try finding [ ... ] block + bracket_match = re.search(r"\[.*\]", text, re.DOTALL) + if bracket_match: + try: + parsed = json.loads(bracket_match.group(0)) + if isinstance(parsed, list): + return parsed + except json.JSONDecodeError: + pass + + # Try finding multiple { ... 
} blocks (LLM sometimes returns separate objects) + objects = [] + for obj_match in re.finditer(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL): + try: + obj = json.loads(obj_match.group(0)) + if isinstance(obj, dict) and obj.get("title"): + objects.append(obj) + except json.JSONDecodeError: + continue + if objects: + logger.info("Parsed %d individual JSON objects from batch response", len(objects)) + return objects + + # Fallback: try single object + single = _parse_llm_json(raw) + if single: + logger.info("Batch parse fallback: extracted single object") + else: + logger.warning("Batch parse failed — logging first 500 chars: %s", raw[:500]) + return [single] if single else [] + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + +REFORM_SYSTEM_PROMPT = """Du bist ein Security-Compliance-Experte. Deine Aufgabe ist es, eigenständige +Security Controls zu formulieren. Du formulierst IMMER in eigenen Worten. +KOPIERE KEINE Sätze aus dem Quelltext. Verwende eigene Begriffe und Struktur. +NENNE NICHT die Quelle. Keine proprietären Bezeichner. + +WICHTIG — Truthfulness-Guardrail: +Deine Ausgabe ist ein ENTWURF. Formuliere NIEMALS Behauptungen über bereits erfolgte Umsetzung. +Verwende NICHT: "ist compliant", "erfüllt vollständig", "wurde geprüft", "wurde umgesetzt", +"ist auditiert", "vollständig implementiert", "nachweislich konform". +Verwende stattdessen: "soll umsetzen", "ist vorgesehen", "muss implementiert werden". + +Antworte NUR mit validem JSON. Bei mehreren Controls antworte mit einem JSON-Array.""" + +STRUCTURE_SYSTEM_PROMPT = """Du bist ein Security-Compliance-Experte. Strukturiere den gegebenen Text +als praxisorientiertes Security Control. Erstelle eine verständliche, umsetzbare Formulierung. + +WICHTIG — Truthfulness-Guardrail: +Deine Ausgabe ist ein ENTWURF. Formuliere NIEMALS Behauptungen über bereits erfolgte Umsetzung. 
+Verwende NICHT: "ist compliant", "erfüllt vollständig", "wurde geprüft", "wurde umgesetzt". +Verwende stattdessen: "soll umsetzen", "ist vorgesehen", "muss implementiert werden". + +Antworte NUR mit validem JSON. Bei mehreren Controls antworte mit einem JSON-Array.""" + +# Shared applicability prompt block — appended to all generation prompts (v3) +APPLICABILITY_PROMPT = """- applicable_industries: Liste der Branchen fuer die dieses Control relevant ist. + Verwende ["all"] wenn der Control branchenuebergreifend gilt. + Moegliche Werte: "all", "Technologie / IT", "IT Dienstleistungen", "E-Commerce / Handel", + "Finanzdienstleistungen", "Versicherungen", "Gesundheitswesen", "Pharma", "Bildung", + "Beratung / Consulting", "Marketing / Agentur", "Produktion / Industrie", + "Logistik / Transport", "Immobilien", "Bau", "Energie", "Automobil", + "Luft- / Raumfahrt", "Maschinenbau", "Anlagenbau", "Automatisierung", "Robotik", + "Messtechnik", "Agrar", "Chemie", "Minen / Bergbau", "Telekommunikation", + "Medien / Verlage", "Gastronomie / Hotellerie", "Recht / Kanzlei", + "Oeffentlicher Dienst", "Verteidigung / Ruestung", "Wasser- / Abwasserwirtschaft", + "Lebensmittel", "Digitale Infrastruktur", "Weltraum", "Post / Kurierdienste", + "Abfallwirtschaft", "Forschung" + Beispiel: TKG-Controls → ["Telekommunikation"] + Beispiel: DSGVO Art. 32 → ["all"] + Beispiel: NIS2 Art. 21 → ["Energie", "Gesundheitswesen", "Digitale Infrastruktur", "Logistik / Transport", ...] +- applicable_company_size: Ab welcher Unternehmensgroesse gilt dieses Control? + Verwende ["all"] wenn keine Groessenbeschraenkung. + Moegliche Werte: "all", "micro", "small", "medium", "large", "enterprise" + Groessen: micro (<10 MA), small (10-49), medium (50-249), large (250-999), enterprise (1000+) + Beispiel: NIS2 Art. 21 → ["medium", "large", "enterprise"] + Beispiel: DSGVO Art. 5 → ["all"] +- scope_conditions: Optionale Bedingungen aus dem Compliance-Scope des Unternehmens. 
+ null wenn keine besonderen Bedingungen. Sonst JSON-Objekt: + {"requires_any": ["signal1", "signal2"], "description": "Kurze Erklaerung wann relevant"} + Moegliche Signale: "uses_ai", "third_country_transfer", "processes_health_data", + "processes_minors_data", "automated_decisions", "employee_monitoring", + "video_surveillance", "financial_data", "is_kritis_operator", "payment_services" + Beispiel AI Act: {"requires_any": ["uses_ai"], "description": "Nur bei KI-Einsatz relevant"} + Beispiel SCC: {"requires_any": ["third_country_transfer"], "description": "Nur bei Drittlandtransfer"} + Beispiel DSGVO Art. 32 (allgemein): null""" + + +class ControlGeneratorPipeline: + """Orchestrates the 7-stage control generation pipeline.""" + + def __init__(self, db: Session, rag_client: Optional[ComplianceRAGClient] = None): + self.db = db + self.rag = rag_client or get_rag_client() + self._existing_controls: Optional[List[dict]] = None + self._existing_embeddings: Dict[str, List[float]] = {} + + # ── Stage 1: RAG Scan ────────────────────────────────────────────── + + async def _scan_rag(self, config: GeneratorConfig) -> list[RAGSearchResult]: + """Scroll through ALL chunks in RAG collections. + + Uses DIRECT Qdrant scroll API (bypasses Go SDK which has offset cycling bugs). + Filters out already-processed chunks by hash. 
+ """ + collections = config.collections or ALL_COLLECTIONS + all_results: list[RAGSearchResult] = [] + + # Pre-load all processed hashes for fast filtering + processed_hashes: set[str] = set() + if config.skip_processed: + try: + result = self.db.execute( + text("SELECT chunk_hash FROM canonical_processed_chunks") + ) + processed_hashes = {row[0] for row in result} + logger.info("Loaded %d processed chunk hashes", len(processed_hashes)) + except Exception as e: + logger.warning("Error loading processed hashes: %s", e) + + seen_hashes: set[str] = set() + + for collection in collections: + page = 0 + collection_total = 0 + collection_new = 0 + qdrant_offset = None # Qdrant uses point ID as offset + + while True: + # Direct Qdrant scroll API — bypasses Go SDK offset cycling bug + try: + scroll_body: dict = { + "limit": 250, + "with_payload": True, + "with_vector": False, + } + if qdrant_offset is not None: + scroll_body["offset"] = qdrant_offset + + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{QDRANT_URL}/collections/{collection}/points/scroll", + json=scroll_body, + ) + if resp.status_code != 200: + logger.error("Qdrant scroll %s failed: %d %s", collection, resp.status_code, resp.text[:200]) + break + data = resp.json().get("result", {}) + points = data.get("points", []) + next_page_offset = data.get("next_page_offset") + except Exception as e: + logger.error("Qdrant scroll error for %s: %s", collection, e) + break + + if not points: + break + + collection_total += len(points) + + for point in points: + payload = point.get("payload", {}) + # Different collections use different field names for text + chunk_text = (payload.get("chunk_text", "") + or payload.get("content", "") + or payload.get("text", "") + or payload.get("page_content", "")) + if not chunk_text or len(chunk_text.strip()) < 50: + continue + + h = hashlib.sha256(chunk_text.encode()).hexdigest() + + if h in seen_hashes: + continue + seen_hashes.add(h) + + if h in 
processed_hashes: + continue + + # Convert Qdrant point to RAGSearchResult + # Handle varying payload schemas across collections + reg_code = (payload.get("regulation_id", "") + or payload.get("regulation_code", "") + or payload.get("source_id", "") + or payload.get("source_code", "")) + + # Filter by regulation_code if configured + if config.regulation_filter: + if not reg_code: + continue # Skip chunks without regulation code + code_lower = reg_code.lower() + if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter): + continue + + reg_name = (payload.get("regulation_name_de", "") + or payload.get("regulation_name", "") + or payload.get("source_name", "") + or payload.get("guideline_name", "") + or payload.get("document_title", "") + or payload.get("filename", "")) + reg_short = (payload.get("regulation_short", "") + or reg_code) + chunk = RAGSearchResult( + text=chunk_text, + regulation_code=reg_code, + regulation_name=reg_name, + regulation_short=reg_short, + category=payload.get("category", "") or payload.get("data_type", ""), + article=payload.get("article", "") or payload.get("section_title", "") or payload.get("section", ""), + paragraph=payload.get("paragraph", ""), + source_url=payload.get("source_url", "") or payload.get("source", "") or payload.get("url", ""), + score=0.0, + collection=collection, + ) + all_results.append(chunk) + collection_new += 1 + + page += 1 + if page % 100 == 0: + logger.info( + "Scrolling %s (direct Qdrant): page %d, %d total chunks, %d new unprocessed", + collection, page, collection_total, collection_new, + ) + + # Stop conditions + if next_page_offset is None: + break # Qdrant returns null when no more pages + + qdrant_offset = next_page_offset + + logger.info( + "Collection %s: %d total chunks scrolled (direct Qdrant), %d new unprocessed", + collection, collection_total, collection_new, + ) + + if config.regulation_filter: + logger.info( + "RAG scroll complete: %d total unique seen, %d passed 
regulation_filter %s", + len(seen_hashes), len(all_results), config.regulation_filter, + ) + else: + logger.info( + "RAG scroll complete: %d total unique seen, %d new unprocessed to process", + len(seen_hashes), len(all_results), + ) + return all_results + + def _get_processed_hashes(self, hashes: list[str]) -> set[str]: + """Check which chunk hashes are already processed.""" + if not hashes: + return set() + try: + result = self.db.execute( + text("SELECT chunk_hash FROM canonical_processed_chunks WHERE chunk_hash = ANY(:hashes)"), + {"hashes": hashes}, + ) + return {row[0] for row in result} + except Exception as e: + logger.warning("Error checking processed chunks: %s", e) + return set() + + # ── Stage 2: License Classification ──────────────────────────────── + + def _classify_license(self, chunk: RAGSearchResult) -> dict: + """Determine which license rule applies to this chunk.""" + return _classify_regulation(chunk.regulation_code) + + # ── Stage 3a: Structure (Rule 1 — Free Use) ─────────────────────── + + async def _structure_free_use(self, chunk: RAGSearchResult, license_info: dict) -> GeneratedControl: + """Structure a freely usable text into control format.""" + source_name = license_info.get("name", chunk.regulation_name) + prompt = f"""Strukturiere den folgenden Gesetzestext als Security/Compliance Control. +Du DARFST den Originaltext verwenden (Quelle: {source_name}, {license_info.get('license', '')}). + +WICHTIG: Erstelle eine verständliche, praxisorientierte Formulierung. +Der Originaltext wird separat gespeichert — deine Formulierung soll klar und umsetzbar sein. + +Gib JSON zurück mit diesen Feldern: +- title: Kurzer prägnanter Titel (max 100 Zeichen) +- objective: Was soll erreicht werden? (1-3 Sätze) +- rationale: Warum ist das wichtig? 
(1-2 Sätze) +- requirements: Liste von konkreten Anforderungen (Strings) +- test_procedure: Liste von Prüfschritten (Strings) +- evidence: Liste von Nachweisdokumenten (Strings) +- severity: low/medium/high/critical +- tags: Liste von Tags +- domain: Fachgebiet als Kuerzel (AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden/Verwaltung, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe/Handelsrecht, ENV=Umwelt, HLT=Gesundheit) +- category: Inhaltliche Kategorie — MUSS zum domain passen. Moegliche Werte: {CATEGORY_LIST_STR} +- target_audience: Liste der Zielgruppen (z.B. "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst") +- source_article: Artikel-/Paragraphen-Referenz aus dem Text (z.B. "Artikel 10", "§ 42"). Leer lassen wenn nicht erkennbar. +- source_paragraph: Absatz-Referenz aus dem Text (z.B. "Absatz 5", "Nr. 2"). Leer lassen wenn nicht erkennbar. 
+{APPLICABILITY_PROMPT} + +Text: {chunk.text[:2000]} +Quelle: {chunk.regulation_name} ({chunk.regulation_code}), {chunk.article}""" + + raw = await _llm_chat(prompt, STRUCTURE_SYSTEM_PROMPT) + data = _parse_llm_json(raw) + if not data: + return self._fallback_control(chunk) + + domain = _detect_domain(chunk.text) + control = self._build_control_from_json(data, domain) + llm_article = str(data.get("source_article", "")).strip() + llm_paragraph = str(data.get("source_paragraph", "")).strip() + effective_article = llm_article or chunk.article or "" + effective_paragraph = llm_paragraph or chunk.paragraph or "" + control.license_rule = 1 + control.source_original_text = chunk.text + # Use canonical name from REGULATION_LICENSE_MAP, not Qdrant's regulation_name + canonical_source = license_info.get("name", chunk.regulation_name) + control.source_citation = { + "source": canonical_source, + "article": effective_article, + "paragraph": effective_paragraph, + "license": license_info.get("license", ""), + "source_type": license_info.get("source_type", "law"), + "url": chunk.source_url or "", + } + control.customer_visible = True + control.verification_method = _detect_verification_method(chunk.text) + if not control.category: + control.category = _detect_category(chunk.text) + control.generation_metadata = { + "processing_path": "structured", + "license_rule": 1, + "source_regulation": chunk.regulation_code, + "source_article": effective_article, + "source_paragraph": effective_paragraph, + } + return control + + # ── Stage 3b: Structure with Citation (Rule 2) ──────────────────── + + async def _structure_with_citation(self, chunk: RAGSearchResult, license_info: dict) -> GeneratedControl: + """Structure text that requires citation.""" + source_name = license_info.get("name", chunk.regulation_name) + attribution = license_info.get("attribution", "") + prompt = f"""Strukturiere den folgenden Text als Security Control. 
+Quelle: {source_name} ({license_info.get('license', '')}) — Zitation erforderlich. + +Du darfst den Text übernehmen oder verständlicher umformulieren. +Die Quelle wird automatisch zitiert — fokussiere dich auf Klarheit. + +Gib JSON zurück mit diesen Feldern: +- title: Kurzer prägnanter Titel (max 100 Zeichen) +- objective: Was soll erreicht werden? (1-3 Sätze) +- rationale: Warum ist das wichtig? (1-2 Sätze) +- requirements: Liste von konkreten Anforderungen (Strings) +- test_procedure: Liste von Prüfschritten (Strings) +- evidence: Liste von Nachweisdokumenten (Strings) +- severity: low/medium/high/critical +- tags: Liste von Tags +- domain: Fachgebiet als Kuerzel (AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden/Verwaltung, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe/Handelsrecht, ENV=Umwelt, HLT=Gesundheit) +- category: Inhaltliche Kategorie — MUSS zum domain passen. Moegliche Werte: {CATEGORY_LIST_STR} +- target_audience: Liste der Zielgruppen (z.B. "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst") +- source_article: Artikel-/Paragraphen-Referenz aus dem Text (z.B. "Artikel 10", "§ 42"). Leer lassen wenn nicht erkennbar. +- source_paragraph: Absatz-Referenz aus dem Text (z.B. "Absatz 5", "Nr. 2"). Leer lassen wenn nicht erkennbar. 
+{APPLICABILITY_PROMPT} + +Text: {chunk.text[:2000]} +Quelle: {chunk.regulation_name}, {chunk.article}""" + + raw = await _llm_chat(prompt, STRUCTURE_SYSTEM_PROMPT) + data = _parse_llm_json(raw) + if not data: + return self._fallback_control(chunk) + + domain = _detect_domain(chunk.text) + control = self._build_control_from_json(data, domain) + llm_article = str(data.get("source_article", "")).strip() + llm_paragraph = str(data.get("source_paragraph", "")).strip() + effective_article = llm_article or chunk.article or "" + effective_paragraph = llm_paragraph or chunk.paragraph or "" + control.license_rule = 2 + control.source_original_text = chunk.text + # Use canonical name from REGULATION_LICENSE_MAP, not Qdrant's regulation_name + canonical_source = license_info.get("name", chunk.regulation_name) + control.source_citation = { + "source": canonical_source, + "article": effective_article, + "paragraph": effective_paragraph, + "license": license_info.get("license", ""), + "license_notice": attribution, + "source_type": license_info.get("source_type", "standard"), + "url": chunk.source_url or "", + } + control.customer_visible = True + control.verification_method = _detect_verification_method(chunk.text) + if not control.category: + control.category = _detect_category(chunk.text) + control.generation_metadata = { + "processing_path": "structured", + "license_rule": 2, + "source_regulation": chunk.regulation_code, + "source_article": effective_article, + "source_paragraph": effective_paragraph, + } + return control + + # ── Stage 3c: LLM Reformulation (Rule 3 — Restricted) ───────────── + + async def _llm_reformulate(self, chunk: RAGSearchResult, config: GeneratorConfig) -> GeneratedControl: + """Fully reformulate — NO original text, NO source names.""" + domain = config.domain or _detect_domain(chunk.text) + prompt = f"""Analysiere den folgenden Prüfaspekt und formuliere ein EIGENSTÄNDIGES Security Control. +KOPIERE KEINE Sätze. Verwende eigene Begriffe und Struktur. 
+NENNE NICHT die Quelle. Keine proprietären Bezeichner (kein O.Auth_*, TR-03161, BSI-TR etc.). + +Aspekt (nur zur Analyse, NICHT kopieren, NICHT referenzieren): +--- +{chunk.text[:1500]} +--- + +Domain: {domain} + +Gib JSON zurück mit diesen Feldern: +- title: Kurzer eigenständiger Titel (max 100 Zeichen) +- objective: Eigenständige Formulierung des Ziels (1-3 Sätze) +- rationale: Eigenständige Begründung (1-2 Sätze) +- requirements: Liste von konkreten Anforderungen (Strings, eigene Worte) +- test_procedure: Liste von Prüfschritten (Strings) +- evidence: Liste von Nachweisdokumenten (Strings) +- severity: low/medium/high/critical +- tags: Liste von Tags (eigene Begriffe) +- domain: Fachgebiet als Kuerzel (AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden/Verwaltung, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe/Handelsrecht, ENV=Umwelt, HLT=Gesundheit) +- category: Inhaltliche Kategorie — MUSS zum domain passen. Moegliche Werte: {CATEGORY_LIST_STR} +- target_audience: Liste der Zielgruppen (z.B. 
"unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "oeffentlicher_dienst") +{APPLICABILITY_PROMPT}""" + + raw = await _llm_chat(prompt, REFORM_SYSTEM_PROMPT) + data = _parse_llm_json(raw) + if not data: + return self._fallback_control(chunk) + + control = self._build_control_from_json(data, domain) + control.license_rule = 3 + control.source_original_text = None # NEVER store original + control.source_citation = None # NEVER cite source + control.customer_visible = False # Only our formulation + control.verification_method = _detect_verification_method(chunk.text) + if not control.category: + control.category = _detect_category(chunk.text) + # generation_metadata: NO source names, NO original texts + control.generation_metadata = { + "processing_path": "llm_reform", + "license_rule": 3, + } + return control + + # ── Stage 3 BATCH: Multiple chunks in one API call ───────────────── + + async def _structure_batch( + self, + chunks: list[RAGSearchResult], + license_infos: list[dict], + ) -> list[Optional[GeneratedControl]]: + """Structure multiple free-use/citation chunks in a single Anthropic call.""" + # Build document context header if chunks share a regulation + regulations_in_batch = set(c.regulation_name for c in chunks) + doc_context = "" + if len(regulations_in_batch) == 1: + reg_name = next(iter(regulations_in_batch)) + articles = sorted(set(c.article or "?" 
for c in chunks)) + doc_context = ( + f"\nDOKUMENTKONTEXT: Alle {len(chunks)} Chunks stammen aus demselben Gesetz: {reg_name}.\n" + f"Betroffene Artikel/Abschnitte: {', '.join(articles)}.\n" + f"Nutze diesen Zusammenhang fuer eine kohaerente, aufeinander abgestimmte Formulierung der Controls.\n" + f"Vermeide Redundanzen zwischen den Controls — jedes soll einen eigenen Aspekt abdecken.\n" + ) + elif len(regulations_in_batch) <= 3: + doc_context = ( + f"\nDOKUMENTKONTEXT: Die Chunks stammen aus {len(regulations_in_batch)} Gesetzen: " + f"{', '.join(regulations_in_batch)}.\n" + ) + + chunk_entries = [] + for idx, (chunk, lic) in enumerate(zip(chunks, license_infos)): + source_name = lic.get("name", chunk.regulation_name) + chunk_entries.append( + f"--- CHUNK {idx + 1} ---\n" + f"Text: {chunk.text[:2000]}\n" + f"Quelle: {chunk.regulation_name} ({chunk.regulation_code}), {chunk.article}\n" + f"Lizenz: {source_name} ({lic.get('license', '')})" + ) + joined = "\n\n".join(chunk_entries) + prompt = f"""Strukturiere die folgenden {len(chunks)} Gesetzestexte jeweils als eigenstaendiges Security/Compliance Control. +Du DARFST den Originaltext verwenden (Quellen sind jeweils angegeben). +{doc_context} +WICHTIG: +- Pruefe JEDEN Chunk: Enthaelt er eine konkrete Pflicht, Anforderung oder Massnahme? +- Wenn JA: Erstelle ein vollstaendiges, eigenstaendiges Control mit praxisorientierter Formulierung. +- Wenn NEIN (reines Inhaltsverzeichnis, Begriffsbestimmung ohne Pflicht, Geltungsbereich ohne Anforderung, reine Verweiskette): Gib null fuer diesen Chunk zurueck. +- BEACHTE: Anhaenge/Annexe enthalten oft KONKRETE technische Anforderungen — diese MUESSEN als Control erfasst werden! +- Jedes Control muss eigenstaendig und vollstaendig sein — nicht auf andere Controls verweisen. +- Qualitaet ist wichtiger als Geschwindigkeit. +- Antworte IMMER auf Deutsch. + +Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Chunks ohne Anforderung gib null zurueck. 
Fuer Chunks mit Anforderung ein Objekt mit diesen Feldern: +- chunk_index: 1-basierter Index des Chunks (1, 2, 3, ...) +- title: Kurzer praegnanter Titel auf Deutsch (max 100 Zeichen) +- objective: Was soll erreicht werden? (1-3 Saetze, Deutsch) +- rationale: Warum ist das wichtig? (1-2 Saetze, Deutsch) +- requirements: Liste von konkreten Anforderungen (Strings, Deutsch) +- test_procedure: Liste von Pruefschritten (Strings, Deutsch) +- evidence: Liste von Nachweisdokumenten (Strings, Deutsch) +- severity: low/medium/high/critical +- tags: Liste von Tags +- domain: Fachgebiet als Kuerzel (AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden/Verwaltung, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe/Handelsrecht, ENV=Umwelt, HLT=Gesundheit) +- category: Inhaltliche Kategorie — MUSS zum domain passen. Moegliche Werte: {CATEGORY_LIST_STR} +- target_audience: Liste der Zielgruppen fuer die dieses Control relevant ist. Moegliche Werte: "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "vertrieb", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst" +- source_article: Artikel-/Paragraphen-Referenz aus dem Text extrahieren (z.B. "Artikel 10", "Art. 5", "§ 42", "Section 3"). Leer lassen wenn nicht erkennbar. +- source_paragraph: Absatz-Referenz aus dem Text extrahieren (z.B. "Absatz 5", "Abs. 3", "Nr. 2", "(1)"). Leer lassen wenn nicht erkennbar. 
+{APPLICABILITY_PROMPT} + +{joined}""" + + raw = await _llm_chat(prompt, STRUCTURE_SYSTEM_PROMPT) + results = _parse_llm_json_array(raw) + logger.info("Batch structure: parsed %d results from API response", len(results)) + + # Map results back to chunks by chunk_index (or by position if no index) + controls: list[Optional[GeneratedControl]] = [None] * len(chunks) + skipped_by_api = 0 + for pos, data in enumerate(results): + # API returns null for chunks without actionable requirements + if data is None: + skipped_by_api += 1 + continue + # Try chunk_index first, fall back to position + idx = data.get("chunk_index") + if idx is not None: + idx = int(idx) - 1 # Convert to 0-based + else: + idx = pos # Use position as fallback + if idx < 0 or idx >= len(chunks): + logger.warning("Batch: chunk_index %d out of range (0-%d), using position %d", idx, len(chunks)-1, pos) + idx = min(pos, len(chunks) - 1) + chunk = chunks[idx] + lic = license_infos[idx] + domain = _detect_domain(chunk.text) + control = self._build_control_from_json(data, domain) + control.license_rule = lic["rule"] + # Use LLM-extracted article/paragraph, fall back to chunk metadata + llm_article = str(data.get("source_article", "")).strip() + llm_paragraph = str(data.get("source_paragraph", "")).strip() + effective_article = llm_article or chunk.article or "" + effective_paragraph = llm_paragraph or chunk.paragraph or "" + if lic["rule"] in (1, 2): + control.source_original_text = chunk.text + # Use canonical name from REGULATION_LICENSE_MAP, not Qdrant's regulation_name + canonical_source = lic.get("name", chunk.regulation_name) + control.source_citation = { + "source": canonical_source, + "article": effective_article, + "paragraph": effective_paragraph, + "license": lic.get("license", ""), + "license_notice": lic.get("attribution", ""), + "source_type": lic.get("source_type", "law"), + "url": chunk.source_url or "", + } + control.customer_visible = True + control.verification_method = 
    async def _reformulate_batch(
        self,
        chunks: list[RAGSearchResult],
        config: GeneratorConfig,
    ) -> list[Optional[GeneratedControl]]:
        """Reformulate multiple restricted (Rule 3) chunks in a single Anthropic call.

        Rule-3 sources may not be quoted, so the LLM is instructed to produce
        fully independent controls: no copied sentences, no source names, no
        proprietary identifiers. The returned list is aligned with *chunks*;
        positions whose aspect had no actionable requirement stay None.

        Args:
            chunks: Restricted source chunks to reformulate.
            config: Generator config; ``config.domain`` (if set) overrides
                keyword-based domain detection.

        Returns:
            List of length ``len(chunks)`` with a GeneratedControl or None
            per input position.
        """
        chunk_entries = []
        for idx, chunk in enumerate(chunks):
            domain = config.domain or _detect_domain(chunk.text)
            chunk_entries.append(
                f"--- ASPEKT {idx + 1} ---\n"
                f"Domain: {domain}\n"
                f"Text (nur zur Analyse, NICHT kopieren, NICHT referenzieren):\n{chunk.text[:1500]}"
            )
        joined = "\n\n".join(chunk_entries)
        prompt = f"""Analysiere die folgenden {len(chunks)} Pruefaspekte und formuliere fuer JEDEN mit konkreter Anforderung ein EIGENSTAENDIGES Security Control.
KOPIERE KEINE Saetze. Verwende eigene Begriffe und Struktur.
NENNE NICHT die Quellen. Keine proprietaeren Bezeichner (kein O.Auth_*, TR-03161, BSI-TR etc.).

WICHTIG:
- Pruefe JEDEN Aspekt: Enthaelt er eine konkrete Pflicht, Anforderung oder Massnahme?
- Wenn JA: Erstelle ein vollstaendiges, eigenstaendiges Control.
- Wenn NEIN (reines Inhaltsverzeichnis, Begriffsbestimmung ohne Pflicht, Geltungsbereich ohne Anforderung): Gib null fuer diesen Aspekt zurueck.
- BEACHTE: Anhaenge/Annexe enthalten oft KONKRETE technische Anforderungen — diese MUESSEN erfasst werden!
- Jedes Control muss eigenstaendig und vollstaendig sein — nicht auf andere Controls verweisen.
- Qualitaet ist wichtiger als Geschwindigkeit.

Gib ein JSON-Array zurueck mit GENAU {len(chunks)} Elementen. Fuer Aspekte ohne Anforderung gib null zurueck. Fuer Aspekte mit Anforderung ein Objekt mit diesen Feldern:
- chunk_index: 1-basierter Index des Aspekts (1, 2, 3, ...)
- title: Kurzer eigenstaendiger Titel (max 100 Zeichen)
- objective: Eigenstaendige Formulierung des Ziels (1-3 Saetze)
- rationale: Eigenstaendige Begruendung (1-2 Saetze)
- requirements: Liste von konkreten Anforderungen (Strings, eigene Worte)
- test_procedure: Liste von Pruefschritten (Strings)
- evidence: Liste von Nachweisdokumenten (Strings)
- severity: low/medium/high/critical
- tags: Liste von Tags (eigene Begriffe)
- domain: Fachgebiet als Kuerzel (AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden/Verwaltung, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe/Handelsrecht, ENV=Umwelt, HLT=Gesundheit)
- category: Inhaltliche Kategorie — MUSS zum domain passen. Moegliche Werte: {CATEGORY_LIST_STR}
- target_audience: Liste der Zielgruppen (z.B. "unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst")
{APPLICABILITY_PROMPT}

{joined}"""

        raw = await _llm_chat(prompt, REFORM_SYSTEM_PROMPT)
        results = _parse_llm_json_array(raw)
        logger.info("Batch reform: parsed %d results from API response", len(results))

        # Map results back to input positions: prefer the model-reported
        # 1-based chunk_index, fall back to array position, clamp out-of-range.
        controls: list[Optional[GeneratedControl]] = [None] * len(chunks)
        for pos, data in enumerate(results):
            # null means the model found no actionable requirement in this aspect
            if data is None:
                continue
            idx = data.get("chunk_index")
            if idx is not None:
                idx = int(idx) - 1  # convert to 0-based
            else:
                idx = pos
            if idx < 0 or idx >= len(chunks):
                logger.warning("Batch reform: chunk_index %d out of range, using position %d", idx, pos)
                idx = min(pos, len(chunks) - 1)
            chunk = chunks[idx]
            domain = config.domain or _detect_domain(chunk.text)
            control = self._build_control_from_json(data, domain)
            # Rule 3: no verbatim source text / citation may be attached,
            # and the control is not shown to customers until reviewed.
            control.license_rule = 3
            control.source_original_text = None
            control.source_citation = None
            control.customer_visible = False
            control.verification_method = _detect_verification_method(chunk.text)
            if not control.category:
                control.category = _detect_category(chunk.text)
            control.generation_metadata = {
                "processing_path": "llm_reform_batch",
                "license_rule": 3,
                "batch_size": len(chunks),
            }
            controls[idx] = control

        return controls
"unternehmen", "behoerden", "entwickler", "datenschutzbeauftragte", "geschaeftsfuehrung", "it-abteilung", "rechtsabteilung", "compliance-officer", "personalwesen", "einkauf", "produktion", "gesundheitswesen", "finanzwesen", "oeffentlicher_dienst") +{APPLICABILITY_PROMPT} + +{joined}""" + + raw = await _llm_chat(prompt, REFORM_SYSTEM_PROMPT) + results = _parse_llm_json_array(raw) + logger.info("Batch reform: parsed %d results from API response", len(results)) + + controls: list[Optional[GeneratedControl]] = [None] * len(chunks) + for pos, data in enumerate(results): + if data is None: + continue + idx = data.get("chunk_index") + if idx is not None: + idx = int(idx) - 1 + else: + idx = pos + if idx < 0 or idx >= len(chunks): + logger.warning("Batch reform: chunk_index %d out of range, using position %d", idx, pos) + idx = min(pos, len(chunks) - 1) + chunk = chunks[idx] + domain = config.domain or _detect_domain(chunk.text) + control = self._build_control_from_json(data, domain) + control.license_rule = 3 + control.source_original_text = None + control.source_citation = None + control.customer_visible = False + control.verification_method = _detect_verification_method(chunk.text) + if not control.category: + control.category = _detect_category(chunk.text) + control.generation_metadata = { + "processing_path": "llm_reform_batch", + "license_rule": 3, + "batch_size": len(chunks), + } + controls[idx] = control + + return controls + + async def _process_batch( + self, + batch_items: list[tuple[RAGSearchResult, dict]], + config: GeneratorConfig, + job_id: str, + ) -> list[Optional[GeneratedControl]]: + """Process a batch of (chunk, license_info) through stages 3-5.""" + # Split by license rule: Rule 1+2 → structure, Rule 3 → reform + structure_items = [(c, l) for c, l in batch_items if l["rule"] in (1, 2)] + reform_items = [(c, l) for c, l in batch_items if l["rule"] == 3] + + all_controls: dict[int, Optional[GeneratedControl]] = {} + + if structure_items: + s_chunks = [c 
for c, _ in structure_items] + s_lics = [l for _, l in structure_items] + s_controls = await self._structure_batch(s_chunks, s_lics) + for (chunk, _), ctrl in zip(structure_items, s_controls): + orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) + all_controls[orig_idx] = ctrl + + if reform_items: + r_chunks = [c for c, _ in reform_items] + r_controls = await self._reformulate_batch(r_chunks, config) + for (chunk, _), ctrl in zip(reform_items, r_controls): + orig_idx = next(i for i, (c, _) in enumerate(batch_items) if c is chunk) + if ctrl: + # Too-Close-Check for Rule 3 + similarity = await check_similarity(chunk.text, f"{ctrl.objective} {ctrl.rationale}") + if similarity.status == "FAIL": + ctrl.release_state = "too_close" + ctrl.generation_metadata["similarity_status"] = "FAIL" + ctrl.generation_metadata["similarity_scores"] = { + "token_overlap": similarity.token_overlap, + "ngram_jaccard": similarity.ngram_jaccard, + "lcs_ratio": similarity.lcs_ratio, + } + all_controls[orig_idx] = ctrl + + # Post-process all controls: harmonization + anchor search + # NOTE: QA validation runs as a separate batch AFTER generation (qa-reclassify endpoint) + # to avoid competing with Ollama prefilter for resources. 
+ qa_fixed_count = 0 + final: list[Optional[GeneratedControl]] = [] + for i in range(len(batch_items)): + control = all_controls.get(i) + if not control or (not control.title and not control.objective): + final.append(None) + continue + + if control.release_state == "too_close": + final.append(control) + continue + + # Harmonization + duplicates = await self._check_harmonization(control) + if duplicates: + control.release_state = "duplicate" + control.generation_metadata["similar_controls"] = duplicates + final.append(control) + continue + + # Anchor search + try: + from .anchor_finder import AnchorFinder + finder = AnchorFinder(self.rag) + anchors = await finder.find_anchors(control, skip_web=config.skip_web_search) + control.open_anchors = [asdict(a) if hasattr(a, '__dataclass_fields__') else a for a in anchors] + except Exception as e: + logger.warning("Anchor search failed: %s", e) + + # Release state + if control.license_rule in (1, 2): + control.release_state = "draft" + elif control.open_anchors: + control.release_state = "draft" + else: + control.release_state = "needs_review" + + # Control ID — prefer QA-corrected or LLM-assigned domain over keyword detection + domain = (control.generation_metadata.get("_effective_domain") + or config.domain + or _detect_domain(control.objective)) + control.control_id = self._generate_control_id(domain, self.db) + control.generation_metadata["job_id"] = job_id + + final.append(control) + + if qa_fixed_count: + logger.info("QA validation: fixed %d/%d controls in batch", qa_fixed_count, len(final)) + return final, qa_fixed_count + + # ── Stage 4: Harmonization ───────────────────────────────────────── + + async def _check_harmonization(self, new_control: GeneratedControl) -> Optional[list]: + """Check if a new control duplicates existing ones via embedding similarity.""" + existing = self._load_existing_controls() + if not existing: + return None + + # Pre-load all existing embeddings in batch (once per pipeline run) + if 
not self._existing_embeddings: + await self._preload_embeddings(existing) + + new_text = f"{new_control.title} {new_control.objective}" + new_emb = await _get_embedding(new_text) + if not new_emb: + return None + + similar = [] + for ex in existing: + ex_key = ex.get("control_id", "") + ex_emb = self._existing_embeddings.get(ex_key, []) + if not ex_emb: + continue + + cosine = _cosine_sim(new_emb, ex_emb) + if cosine > HARMONIZATION_THRESHOLD: + similar.append({ + "control_id": ex.get("control_id", ""), + "title": ex.get("title", ""), + "similarity": round(cosine, 3), + }) + + return similar if similar else None + + async def _preload_embeddings(self, existing: list[dict]): + """Pre-load embeddings for all existing controls in batches.""" + texts = [f"{ex.get('title', '')} {ex.get('objective', '')}" for ex in existing] + keys = [ex.get("control_id", "") for ex in existing] + + logger.info("Pre-loading embeddings for %d existing controls...", len(texts)) + embeddings = await _get_embeddings_batch(texts) + + for key, emb in zip(keys, embeddings): + self._existing_embeddings[key] = emb + + loaded = sum(1 for emb in embeddings if emb) + logger.info("Pre-loaded %d/%d embeddings", loaded, len(texts)) + + # Reset DB session after long-running embedding operation to avoid stale connections + try: + self.db.rollback() + except Exception: + pass + + def _load_existing_controls(self) -> list[dict]: + """Load existing controls from DB (cached per pipeline run).""" + if self._existing_controls is not None: + return self._existing_controls + + try: + result = self.db.execute( + text("SELECT control_id, title, objective FROM canonical_controls WHERE release_state != 'deprecated'") + ) + self._existing_controls = [ + {"control_id": r[0], "title": r[1], "objective": r[2]} + for r in result + ] + except Exception as e: + logger.warning("Error loading existing controls: %s", e) + self._existing_controls = [] + + return self._existing_controls + + # ── Helpers 
    # ── Helpers ────────────────────────────────────────────────────────
scope_conditions = None + + control = GeneratedControl( + title=str(data.get("title", "Untitled Control"))[:255], + objective=str(data.get("objective", "")), + rationale=str(data.get("rationale", "")), + scope=data.get("scope", {}), + requirements=data.get("requirements", []) if isinstance(data.get("requirements"), list) else [], + test_procedure=data.get("test_procedure", []) if isinstance(data.get("test_procedure"), list) else [], + evidence=data.get("evidence", []) if isinstance(data.get("evidence"), list) else [], + severity=severity, + risk_score=min(10.0, max(0.0, float(data.get("risk_score", 5.0)))), + implementation_effort=data.get("implementation_effort", "m") if data.get("implementation_effort") in ("s", "m", "l", "xl") else "m", + tags=tags[:20], + target_audience=target_audience, + category=category, + applicable_industries=applicable_industries, + applicable_company_size=applicable_company_size, + scope_conditions=scope_conditions, + ) + # Store effective domain for later control_id generation + control.generation_metadata["_effective_domain"] = domain + return control + + def _fallback_control(self, chunk: RAGSearchResult) -> GeneratedControl: + """Create a minimal control when LLM parsing fails.""" + domain = _detect_domain(chunk.text) + return GeneratedControl( + title=f"Control from {chunk.regulation_code} {chunk.article or ''}".strip()[:255], + objective=chunk.text[:500] if chunk.text else "Needs manual review", + rationale="Auto-generated — LLM parsing failed, manual review required.", + severity="medium", + release_state="needs_review", + tags=[domain.lower()], + ) + + def _generate_control_id(self, domain: str, db: Session) -> str: + """Generate next sequential control ID like AUTH-011.""" + prefix = domain.upper()[:4] + try: + result = db.execute( + text("SELECT control_id FROM canonical_controls WHERE control_id LIKE :prefix ORDER BY control_id DESC LIMIT 1"), + {"prefix": f"{prefix}-%"}, + ) + row = result.fetchone() + if row: + last_num = 
    # ── Stage QA: Automated Quality Validation ───────────────────────

    async def _qa_validate_control(
        self, control: GeneratedControl, chunk_text: str
    ) -> tuple[GeneratedControl, bool]:
        """Cross-validate category/domain using keyword detection + local LLM.

        Also checks for recital (Erwägungsgrund) contamination in source text.
        Returns (control, was_fixed). Only triggers Ollama QA when the LLM
        classification disagrees with keyword detection — keeps it fast.

        Note: a QA domain fix only updates generation_metadata
        ("_effective_domain"), never an already-assigned control_id.
        """
        # ── Recital detection ────────────────────────────────────────
        # Recitals are non-normative preamble text; a control built from them
        # must not be auto-released.
        source_text = control.source_original_text or ""
        recital_info = _detect_recital(source_text)
        if recital_info:
            control.generation_metadata["recital_suspect"] = True
            control.generation_metadata["recital_detection"] = recital_info
            control.release_state = "needs_review"
            logger.warning(
                "Recital suspect: '%s' — recitals %s detected in source text",
                control.title[:40],
                recital_info.get("recital_numbers", []),
            )

        kw_category = _detect_category(chunk_text) or _detect_category(control.objective)
        kw_domain = _detect_domain(chunk_text)
        llm_domain = control.generation_metadata.get("_effective_domain", "")

        # If keyword and LLM agree → no QA needed
        if control.category == kw_category and llm_domain == kw_domain:
            return control, False

        # Disagreement detected → ask local LLM to arbitrate
        title = control.title[:100]
        objective = control.objective[:200]
        reqs = ", ".join(control.requirements[:3]) if control.requirements else "keine"
        prompt = f"""Pruefe dieses Compliance-Control auf korrekte Klassifizierung.

Titel: {title}
Ziel: {objective}
Anforderungen: {reqs}

Aktuelle Zuordnung: domain={llm_domain}, category={control.category}
Keyword-Erkennung: domain={kw_domain}, category={kw_category}

Welche Zuordnung ist korrekt? Antworte NUR als JSON:
{{"domain": "KUERZEL", "category": "kategorie_name", "reason": "kurze Begruendung"}}

Domains: AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe, ENV=Umwelt, HLT=Gesundheit
Kategorien: {CATEGORY_LIST_STR}"""

        try:
            raw = await _llm_local(prompt)
            data = _parse_llm_json(raw)
            if not data:
                # Unparsable arbiter output → keep the current classification
                return control, False

            fixed = False
            qa_domain = data.get("domain", "").upper()
            qa_category = data.get("category", "")
            reason = data.get("reason", "")

            # Apply the category fix only when it names a known category and
            # actually differs from the current one.
            if qa_category and qa_category in VALID_CATEGORIES and qa_category != control.category:
                old_cat = control.category
                control.category = qa_category
                control.generation_metadata["qa_category_fix"] = {
                    "from": old_cat, "to": qa_category, "reason": reason,
                }
                logger.info("QA fix: '%s' category '%s' -> '%s' (%s)",
                            title[:40], old_cat, qa_category, reason)
                fixed = True

            if qa_domain and qa_domain in VALID_DOMAINS and qa_domain != llm_domain:
                control.generation_metadata["qa_domain_fix"] = {
                    "from": llm_domain, "to": qa_domain, "reason": reason,
                }
                control.generation_metadata["_effective_domain"] = qa_domain
                logger.info("QA fix: '%s' domain '%s' -> '%s' (%s)",
                            title[:40], llm_domain, qa_domain, reason)
                fixed = True

            return control, fixed

        except Exception as e:
            # QA is best effort — never fail the pipeline over it
            logger.warning("QA validation failed for '%s': %s", title[:40], e)
            return control, False
Antworte NUR als JSON: +{{"domain": "KUERZEL", "category": "kategorie_name", "reason": "kurze Begruendung"}} + +Domains: AUTH=Authentifizierung, CRYP=Kryptographie, NET=Netzwerk, DATA=Datenschutz, LOG=Logging, ACC=Zugriffskontrolle, SEC=IT-Sicherheit, INC=Vorfallmanagement, AI=KI, COMP=Compliance, GOV=Behoerden, LAB=Arbeitsrecht, FIN=Finanzregulierung, TRD=Gewerbe, ENV=Umwelt, HLT=Gesundheit +Kategorien: {CATEGORY_LIST_STR}""" + + try: + raw = await _llm_local(prompt) + data = _parse_llm_json(raw) + if not data: + return control, False + + fixed = False + qa_domain = data.get("domain", "").upper() + qa_category = data.get("category", "") + reason = data.get("reason", "") + + if qa_category and qa_category in VALID_CATEGORIES and qa_category != control.category: + old_cat = control.category + control.category = qa_category + control.generation_metadata["qa_category_fix"] = { + "from": old_cat, "to": qa_category, "reason": reason, + } + logger.info("QA fix: '%s' category '%s' -> '%s' (%s)", + title[:40], old_cat, qa_category, reason) + fixed = True + + if qa_domain and qa_domain in VALID_DOMAINS and qa_domain != llm_domain: + control.generation_metadata["qa_domain_fix"] = { + "from": llm_domain, "to": qa_domain, "reason": reason, + } + control.generation_metadata["_effective_domain"] = qa_domain + logger.info("QA fix: '%s' domain '%s' -> '%s' (%s)", + title[:40], llm_domain, qa_domain, reason) + fixed = True + + return control, fixed + + except Exception as e: + logger.warning("QA validation failed for '%s': %s", title[:40], e) + return control, False + + # ── Pipeline Orchestration ───────────────────────────────────────── + + def _create_job(self, config: GeneratorConfig) -> str: + """Create a generation job record.""" + try: + result = self.db.execute( + text(""" + INSERT INTO canonical_generation_jobs (status, config) + VALUES ('running', :config) + RETURNING id + """), + {"config": json.dumps(config.model_dump())}, + ) + self.db.commit() + row = 
    def _update_job(self, job_id: str, result: GeneratorResult):
        """Update job with current stats. Sets completed_at only when status is final.

        Called periodically during the run (progress) and once at the end;
        ``completed_at`` is stamped only for "completed"/"failed" so progress
        updates don't prematurely close the job.
        """
        is_final = result.status in ("completed", "failed")
        try:
            self.db.execute(
                # The f-string only injects a constant SQL fragment (no user
                # data); all variable values go through bound parameters.
                text(f"""
                    UPDATE canonical_generation_jobs
                    SET status = :status,
                        total_chunks_scanned = :scanned,
                        controls_generated = :generated,
                        controls_verified = :verified,
                        controls_needs_review = :needs_review,
                        controls_too_close = :too_close,
                        controls_duplicates_found = :duplicates,
                        errors = :errors
                        {"" if not is_final else ", completed_at = NOW()"}
                    WHERE id = CAST(:job_id AS uuid)
                """),
                {
                    "job_id": job_id,
                    "status": result.status,
                    "scanned": result.total_chunks_scanned,
                    "generated": result.controls_generated,
                    "verified": result.controls_verified,
                    "needs_review": result.controls_needs_review,
                    "too_close": result.controls_too_close,
                    "duplicates": result.controls_duplicates_found,
                    # Keep only the 50 most recent errors to bound row size
                    "errors": json.dumps(result.errors[-50:]),
                },
            )
            self.db.commit()
        except Exception as e:
            logger.error("Failed to update job: %s", e)

    def _store_control(self, control: GeneratedControl, job_id: str) -> Optional[str]:
        """Persist a generated control to DB. Returns the control UUID or None.

        Inserts into canonical_controls under the fixed 'bp_security_v1'
        framework; ON CONFLICT (framework_id, control_id) DO NOTHING means a
        duplicate ID silently yields no row (None is returned). On success an
        LLM audit-trail record is written (best effort).
        """
        try:
            # Get framework UUID
            fw_result = self.db.execute(
                text("SELECT id FROM canonical_control_frameworks WHERE framework_id = 'bp_security_v1' LIMIT 1")
            )
            fw_row = fw_result.fetchone()
            if not fw_row:
                logger.error("Framework bp_security_v1 not found")
                return None
            framework_uuid = fw_row[0]

            # Generate control_id if not set
            if not control.control_id:
                domain = _detect_domain(control.objective) if control.objective else "SEC"
                control.control_id = self._generate_control_id(domain, self.db)

            result = self.db.execute(
                text("""
                    INSERT INTO canonical_controls (
                        framework_id, control_id, title, objective, rationale,
                        scope, requirements, test_procedure, evidence,
                        severity, risk_score, implementation_effort,
                        open_anchors, release_state, tags,
                        license_rule, source_original_text, source_citation,
                        customer_visible, generation_metadata,
                        verification_method, category, generation_strategy,
                        target_audience, pipeline_version,
                        applicable_industries, applicable_company_size, scope_conditions
                    ) VALUES (
                        :framework_id, :control_id, :title, :objective, :rationale,
                        :scope, :requirements, :test_procedure, :evidence,
                        :severity, :risk_score, :implementation_effort,
                        :open_anchors, :release_state, :tags,
                        :license_rule, :source_original_text, :source_citation,
                        :customer_visible, :generation_metadata,
                        :verification_method, :category, :generation_strategy,
                        :target_audience, :pipeline_version,
                        :applicable_industries, :applicable_company_size, :scope_conditions
                    )
                    ON CONFLICT (framework_id, control_id) DO NOTHING
                    RETURNING id
                """),
                {
                    "framework_id": framework_uuid,
                    "control_id": control.control_id,
                    "title": control.title,
                    "objective": control.objective,
                    "rationale": control.rationale,
                    # Collection-valued fields are serialized to JSON strings
                    "scope": json.dumps(control.scope),
                    "requirements": json.dumps(control.requirements),
                    "test_procedure": json.dumps(control.test_procedure),
                    "evidence": json.dumps(control.evidence),
                    "severity": control.severity,
                    "risk_score": control.risk_score,
                    "implementation_effort": control.implementation_effort,
                    "open_anchors": json.dumps(control.open_anchors),
                    "release_state": control.release_state,
                    "tags": json.dumps(control.tags),
                    "license_rule": control.license_rule,
                    "source_original_text": control.source_original_text,
                    "source_citation": json.dumps(control.source_citation) if control.source_citation else None,
                    "customer_visible": control.customer_visible,
                    "generation_metadata": json.dumps(control.generation_metadata) if control.generation_metadata else None,
                    "verification_method": control.verification_method,
                    "category": control.category,
                    "generation_strategy": control.generation_strategy,
                    "target_audience": json.dumps(control.target_audience) if control.target_audience else None,
                    "pipeline_version": PIPELINE_VERSION,
                    "applicable_industries": json.dumps(control.applicable_industries) if control.applicable_industries else None,
                    "applicable_company_size": json.dumps(control.applicable_company_size) if control.applicable_company_size else None,
                    "scope_conditions": json.dumps(control.scope_conditions) if control.scope_conditions else None,
                },
            )
            self.db.commit()
            row = result.fetchone()
            control_uuid = str(row[0]) if row else None

            # Anti-Fake-Evidence: Record LLM audit trail for generated control
            if control_uuid:
                try:
                    self.db.execute(
                        text("""
                            INSERT INTO compliance_llm_generation_audit (
                                entity_type, entity_id, generation_mode,
                                truth_status, may_be_used_as_evidence,
                                llm_model, llm_provider,
                                input_summary, output_summary
                            ) VALUES (
                                'control', :entity_id, 'auto_generation',
                                'generated', FALSE,
                                :llm_model, :llm_provider,
                                :input_summary, :output_summary
                            )
                        """),
                        {
                            "entity_id": control_uuid,
                            # Provider/model mirror the configured generation backend
                            "llm_model": ANTHROPIC_MODEL if ANTHROPIC_API_KEY else OLLAMA_MODEL,
                            "llm_provider": "anthropic" if ANTHROPIC_API_KEY else "ollama",
                            "input_summary": f"Control generation for {control.control_id}",
                            "output_summary": control.title[:500] if control.title else None,
                        },
                    )
                    self.db.commit()
                except Exception as audit_err:
                    # Audit trail is best effort — the control itself is already stored
                    logger.warning("Failed to create LLM audit record: %s", audit_err)

            return control_uuid
        except Exception as e:
            logger.error("Failed to store control %s: %s", control.control_id, e)
            self.db.rollback()
            return None
    def _mark_chunk_processed(
        self,
        chunk: RAGSearchResult,
        license_info: dict,
        processing_path: str,
        control_ids: list[str],
        job_id: str,
    ):
        """Mark a RAG chunk as processed (Stage 7).

        Keyed by SHA-256 of the chunk text so re-runs skip already-handled
        chunks; ON CONFLICT DO NOTHING makes the call idempotent.
        """
        chunk_hash = hashlib.sha256(chunk.text.encode()).hexdigest()
        try:
            self.db.execute(
                text("""
                    INSERT INTO canonical_processed_chunks (
                        chunk_hash, collection, regulation_code,
                        document_version, source_license, license_rule,
                        processing_path, generated_control_ids, job_id,
                        pipeline_version
                    ) VALUES (
                        :hash, :collection, :regulation_code,
                        :doc_version, :license, :rule,
                        :path, :control_ids, CAST(:job_id AS uuid),
                        :pipeline_version
                    )
                    ON CONFLICT (chunk_hash, collection, document_version) DO NOTHING
                """),
                {
                    "hash": chunk_hash,
                    "collection": chunk.collection or "bp_compliance_ce",
                    "regulation_code": chunk.regulation_code,
                    "doc_version": "1.0",
                    "license": license_info.get("license", ""),
                    "rule": license_info.get("rule", 3),
                    "path": processing_path,
                    "control_ids": json.dumps(control_ids),
                    "job_id": job_id,
                    "pipeline_version": PIPELINE_VERSION,
                },
            )
            self.db.commit()
        except Exception as e:
            logger.warning("Failed to mark chunk processed: %s", e)
            self.db.rollback()

    # ── Main Pipeline ────────────────────────────────────────────────

    async def run(self, config: GeneratorConfig) -> GeneratorResult:
        """Execute the full 7-stage pipeline.

        Stages: RAG scan → document grouping → (optional) local-LLM prefilter
        → license classification → batched generation → harmonization/anchors
        → persistence + chunk bookkeeping. All progress is mirrored into the
        job row so it can be observed while running.
        """
        result = GeneratorResult()

        # Create or reuse job
        if config.existing_job_id:
            job_id = config.existing_job_id
        else:
            job_id = self._create_job(config)
        result.job_id = job_id

        try:
            # Stage 1: RAG Scan
            chunks = await self._scan_rag(config)
            result.total_chunks_scanned = len(chunks)

            if not chunks:
                result.status = "completed"
                self._update_job(job_id, result)
                return result

            # ── Group chunks by document (regulation_code) for coherent batching ──
            doc_groups: dict[str, list[RAGSearchResult]] = defaultdict(list)
            for chunk in chunks:
                group_key = chunk.regulation_code or "unknown"
                doc_groups[group_key].append(chunk)

            # Sort chunks within each group by article for sequential context
            for key in doc_groups:
                doc_groups[key].sort(key=lambda c: (c.article or "", c.paragraph or ""))

            logger.info(
                "Grouped %d chunks into %d document groups for coherent batching",
                len(chunks), len(doc_groups),
            )

            # ── Apply max_chunks limit respecting document boundaries ──
            # Process complete documents until we exceed the limit.
            # Never split a document across jobs.
            chunks = []
            if config.max_chunks and config.max_chunks > 0:
                for group_key, group_list in doc_groups.items():
                    if chunks and len(chunks) + len(group_list) > config.max_chunks:
                        # Adding this document would exceed the limit — stop here
                        # (note: the first document is always taken in full)
                        break
                    chunks.extend(group_list)
                logger.info(
                    "max_chunks=%d: selected %d chunks from %d complete documents (of %d total groups)",
                    config.max_chunks, len(chunks),
                    len(set(c.regulation_code for c in chunks)),
                    len(doc_groups),
                )
            else:
                # No limit: flatten all groups
                for group_list in doc_groups.values():
                    chunks.extend(group_list)

            result.total_chunks_scanned = len(chunks)

            # Process chunks — batch mode (N chunks per Anthropic API call)
            BATCH_SIZE = config.batch_size or 5
            controls_count = 0
            chunks_skipped_prefilter = 0
            pending_batch: list[tuple[RAGSearchResult, dict]] = []  # (chunk, license_info)
            current_batch_regulation: Optional[str] = None  # Track regulation for group-aware flushing

            async def _flush_batch():
                """Send pending batch to Anthropic and process results."""
                nonlocal controls_count, current_batch_regulation
                if not pending_batch:
                    return
                # Snapshot and reset the shared batch state before awaiting
                batch = pending_batch.copy()
                pending_batch.clear()
                current_batch_regulation = None

                # Log which document this batch belongs to
                regs_in_batch = set(c.regulation_code for c, _ in batch)
                logger.info(
                    "Processing batch of %d chunks (docs: %s) via single API call...",
                    len(batch), ", ".join(regs_in_batch),
                )
                try:
                    batch_controls, batch_qa_fixes = await self._process_batch(batch, config, job_id)
                    result.controls_qa_fixed += batch_qa_fixes
                except Exception as e:
                    logger.error("Batch processing failed: %s — falling back to single-chunk mode", e)
                    # Fallback: process each chunk individually
                    batch_controls = []
                    for chunk, _lic in batch:
                        try:
                            ctrl = await self._process_single_chunk(chunk, config, job_id)
                            batch_controls.append(ctrl)
                        except Exception as e2:
                            logger.error("Single-chunk fallback also failed: %s", e2)
                            batch_controls.append(None)

                for (chunk, lic_info), control in zip(batch, batch_controls):
                    if control is None:
                        if not config.dry_run:
                            self._mark_chunk_processed(chunk, lic_info, "no_control", [], job_id)
                        continue

                    # Mark as document_grouped strategy
                    control.generation_strategy = "document_grouped"

                    # Count by state
                    if control.release_state == "too_close":
                        result.controls_too_close += 1
                    elif control.release_state == "duplicate":
                        result.controls_duplicates_found += 1
                    elif control.release_state == "needs_review":
                        result.controls_needs_review += 1
                    else:
                        result.controls_verified += 1

                    # Store
                    if not config.dry_run:
                        ctrl_uuid = self._store_control(control, job_id)
                        if ctrl_uuid:
                            path = control.generation_metadata.get("processing_path", "structured_batch")
                            self._mark_chunk_processed(chunk, lic_info, path, [ctrl_uuid], job_id)
                        else:
                            self._mark_chunk_processed(chunk, lic_info, "store_failed", [], job_id)

                    result.controls_generated += 1
                    result.controls.append(asdict(control))
                    controls_count += 1

                    # Keep the in-memory duplicate-check cache current so later
                    # batches in this run can detect duplicates of this control
                    if self._existing_controls is not None:
                        self._existing_controls.append({
                            "control_id": control.control_id,
                            "title": control.title,
                            "objective": control.objective,
                        })

            for i, chunk in enumerate(chunks):
                try:
                    # Progress logging every 50 chunks
                    if i > 0 and i % 50 == 0:
                        logger.info(
                            "Progress: %d/%d chunks processed, %d controls generated, %d skipped by prefilter",
                            i, len(chunks), controls_count, chunks_skipped_prefilter,
                        )
                        self._update_job(job_id, result)

                    # Stage 1.5: Local LLM pre-filter — skip chunks without requirements
                    if not config.dry_run and not config.skip_prefilter:
                        is_relevant, prefilter_reason = await _prefilter_chunk(chunk.text)
                        if not is_relevant:
                            chunks_skipped_prefilter += 1
                            license_info = self._classify_license(chunk)
                            self._mark_chunk_processed(
                                chunk, license_info, "prefilter_skip", [], job_id
                            )
                            continue

                    # Classify license and add to batch
                    license_info = self._classify_license(chunk)
                    chunk_regulation = chunk.regulation_code or "unknown"

                    # Flush when: batch is full OR regulation changes (group boundary)
                    if pending_batch and (
                        len(pending_batch) >= BATCH_SIZE
                        or chunk_regulation != current_batch_regulation
                    ):
                        await _flush_batch()

                    pending_batch.append((chunk, license_info))
                    current_batch_regulation = chunk_regulation

                except Exception as e:
                    error_msg = f"Error processing chunk {chunk.regulation_code}/{chunk.article}: {e}"
                    logger.error(error_msg)
                    result.errors.append(error_msg)
                    try:
                        if not config.dry_run:
                            license_info = self._classify_license(chunk)
                            self._mark_chunk_processed(
                                chunk, license_info, "error", [], job_id
                            )
                    except Exception:
                        pass

            # Flush remaining chunks
            await _flush_batch()

            result.chunks_skipped_prefilter = chunks_skipped_prefilter
            logger.info(
                "Pipeline complete: %d controls generated, %d chunks skipped by prefilter, %d total chunks",
                controls_count, chunks_skipped_prefilter, len(chunks),
            )

            result.status = "completed"

        except Exception as e:
            result.status = "failed"
            result.errors.append(str(e))
            logger.error("Pipeline failed: %s", e)

        self._update_job(job_id, result)
        return result
    async def _process_single_chunk(
        self,
        chunk: RAGSearchResult,
        config: GeneratorConfig,
        job_id: str,
    ) -> Optional[GeneratedControl]:
        """Process a single chunk through stages 2-5.

        Fallback path used when batch processing fails. Returns the generated
        control (possibly flagged too_close/duplicate/needs_review) or None
        when no usable control came back from the LLM.
        """
        # Stage 2: License classification
        license_info = self._classify_license(chunk)

        # Stage 3: Structure or Reform based on rule
        if license_info["rule"] == 1:
            control = await self._structure_free_use(chunk, license_info)
        elif license_info["rule"] == 2:
            control = await self._structure_with_citation(chunk, license_info)
        else:
            control = await self._llm_reformulate(chunk, config)

            # Too-Close-Check for Rule 3
            similarity = await check_similarity(chunk.text, f"{control.objective} {control.rationale}")
            if similarity.status == "FAIL":
                control.release_state = "too_close"
                control.generation_metadata["similarity_status"] = "FAIL"
                control.generation_metadata["similarity_scores"] = {
                    "token_overlap": similarity.token_overlap,
                    "ngram_jaccard": similarity.ngram_jaccard,
                    "lcs_ratio": similarity.lcs_ratio,
                }
                return control

        if not control.title or not control.objective:
            return None

        # NOTE: QA validation runs as a separate batch AFTER generation (qa-reclassify endpoint)

        # Stage 4: Harmonization
        duplicates = await self._check_harmonization(control)
        if duplicates:
            control.release_state = "duplicate"
            control.generation_metadata["similar_controls"] = duplicates
            return control

        # Stage 5: Anchor Search (imported from anchor_finder)
        try:
            from .anchor_finder import AnchorFinder
            finder = AnchorFinder(self.rag)
            anchors = await finder.find_anchors(control, skip_web=config.skip_web_search)
            control.open_anchors = [asdict(a) if hasattr(a, '__dataclass_fields__') else a for a in anchors]
        except Exception as e:
            logger.warning("Anchor search failed: %s", e)

        # Determine release state
        if control.license_rule in (1, 2):
            control.release_state = "draft"
        elif control.open_anchors:
            control.release_state = "draft"
        else:
            control.release_state = "needs_review"

        # Generate control_id — prefer QA-corrected or LLM-assigned domain
        domain = (control.generation_metadata.get("_effective_domain")
                  or config.domain
                  or _detect_domain(control.objective))
        control.control_id = self._generate_control_id(domain, self.db)

        # Store job_id in metadata
        control.generation_metadata["job_id"] = job_id

        return control
diff --git a/control-pipeline/services/control_status_machine.py b/control-pipeline/services/control_status_machine.py
new file mode 100644
index 0000000..4bc3200
--- /dev/null
+++ b/control-pipeline/services/control_status_machine.py
@@ -0,0 +1,154 @@
"""
Control Status Transition State Machine.

Enforces that controls cannot be set to "pass" without sufficient evidence.
Prevents Compliance-Theater where controls claim compliance without real proof.

Transition rules:
    planned → in_progress : always allowed
    in_progress → pass    : requires ≥1 evidence with confidence ≥ E2 and
                            truth_status in (uploaded, observed, validated_internal)
    in_progress → partial : requires ≥1 evidence (any level)
    pass → fail           : always allowed (degradation)
    any → n/a             : requires status_justification
    any → planned         : always allowed (reset)
"""

from typing import Any, List, Optional, Tuple

# EvidenceDB is an ORM model from compliance — we only need duck-typed objects
# with .confidence_level and .truth_status attributes.
# EvidenceDB is an ORM model from compliance — we only need duck-typed objects
# with .confidence_level and .truth_status attributes.
EvidenceDB = Any


# Confidence level ordering for comparisons
CONFIDENCE_ORDER = {"E0": 0, "E1": 1, "E2": 2, "E3": 3, "E4": 4}

# Truth statuses that qualify as "real" evidence for pass transitions
VALID_TRUTH_STATUSES = {"uploaded", "observed", "validated_internal", "accepted_by_auditor", "provided_to_auditor"}


def _has_qualifying_evidence(evidence_list: List[EvidenceDB]) -> bool:
    """Return True if at least one evidence record qualifies for a 'pass'.

    Qualifying means: confidence_level >= E2 AND truth_status in
    VALID_TRUTH_STATUSES. Accepts enum members (via .value) or plain
    strings; missing values default to E1 / "uploaded".
    """
    for e in evidence_list:
        conf = getattr(e, "confidence_level", None)
        truth = getattr(e, "truth_status", None)
        # Get string values from enum or string
        conf_val = conf.value if hasattr(conf, "value") else str(conf) if conf else "E1"
        truth_val = truth.value if hasattr(truth, "value") else str(truth) if truth else "uploaded"
        if CONFIDENCE_ORDER.get(conf_val, 1) >= CONFIDENCE_ORDER["E2"] and truth_val in VALID_TRUTH_STATUSES:
            return True
    return False


def validate_transition(
    current_status: str,
    new_status: str,
    evidence_list: Optional[List[EvidenceDB]] = None,
    status_justification: Optional[str] = None,
    bypass_for_auto_updater: bool = False,
) -> Tuple[bool, List[str]]:
    """
    Validate whether a control status transition is allowed.

    Args:
        current_status: Current control status value (e.g. "planned", "pass")
        new_status: Requested new status
        evidence_list: List of EvidenceDB objects linked to this control
        status_justification: Text justification (required for n/a transitions)
        bypass_for_auto_updater: If True, skip evidence checks (used by CI/CD auto-updater
            which creates evidence atomically with status change)

    Returns:
        Tuple of (allowed: bool, violations: list[str])
    """
    violations: List[str] = []
    evidence_list = evidence_list or []

    # Same status → no-op, always allowed
    if current_status == new_status:
        return True, []

    # Reset to planned is always allowed
    if new_status == "planned":
        return True, []

    # n/a requires justification
    if new_status == "n/a":
        if not status_justification or not status_justification.strip():
            violations.append("Transition to 'n/a' requires a status_justification explaining why this control is not applicable.")
        return len(violations) == 0, violations

    # Degradation: pass → fail is always allowed
    if current_status == "pass" and new_status == "fail":
        return True, []

    # planned → in_progress: always allowed
    if current_status == "planned" and new_status == "in_progress":
        return True, []

    # → partial: needs at least 1 evidence (any level)
    if new_status == "partial":
        if not bypass_for_auto_updater and len(evidence_list) == 0:
            violations.append("Transition to 'partial' requires at least 1 evidence record.")
        return len(violations) == 0, violations

    # → pass: strict requirements
    if new_status == "pass":
        if bypass_for_auto_updater:
            return True, []

        # BUG FIX: this guard previously lived AFTER the generic pass check
        # below and was unreachable — planned/fail → pass slipped through
        # whenever qualifying evidence was attached. Checking it first
        # enforces the documented rule that pass must go via in_progress.
        if current_status in ("planned", "fail"):
            violations.append(
                f"Direct transition from '{current_status}' to 'pass' is not allowed. "
                f"Move to 'in_progress' first, then to 'pass' with qualifying evidence."
            )
            return False, violations

        if len(evidence_list) == 0:
            violations.append("Transition to 'pass' requires at least 1 evidence record.")
            return False, violations

        # Check for at least one qualifying evidence (covers both
        # in_progress → pass and partial → pass, which share the rule).
        if not _has_qualifying_evidence(evidence_list):
            violations.append(
                "Transition to 'pass' requires at least 1 evidence with confidence >= E2 "
                "and truth_status in (uploaded, observed, validated_internal, accepted_by_auditor). "
                "Current evidence does not meet this threshold."
            )

        return len(violations) == 0, violations

    # → fail: always allowed (degradation from any state)
    if new_status == "fail":
        return True, []

    # All other transitions allowed
    return True, []
Don't split at evidence level + 6. Parent link always preserved +""" + +import json +import logging +import os +import re +import uuid +from dataclasses import dataclass, field +from typing import Optional + +import httpx +from sqlalchemy import text +from sqlalchemy.orm import Session + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# LLM Provider Config +# --------------------------------------------------------------------------- + +ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "") +ANTHROPIC_MODEL = os.getenv("DECOMPOSITION_LLM_MODEL", "claude-haiku-4-5-20251001") +DECOMPOSITION_BATCH_SIZE = int(os.getenv("DECOMPOSITION_BATCH_SIZE", "5")) +LLM_TIMEOUT = float(os.getenv("DECOMPOSITION_LLM_TIMEOUT", "120")) +ANTHROPIC_API_URL = "https://api.anthropic.com/v1" + + +# --------------------------------------------------------------------------- +# Normative signal detection — 3-Tier Classification +# --------------------------------------------------------------------------- +# Tier 1: Pflicht (mandatory) — strong normative signals +# Tier 2: Empfehlung (recommendation) — weaker normative signals +# Tier 3: Kann (optional/permissive) — permissive signals +# Nothing is rejected — everything is classified. +# +# Patterns are defined in normative_patterns.py and imported here +# with local aliases for backward compatibility. 
+ +from .normative_patterns import ( + PFLICHT_RE as _PFLICHT_RE, + EMPFEHLUNG_RE as _EMPFEHLUNG_RE, + KANN_RE as _KANN_RE, + NORMATIVE_RE as _NORMATIVE_RE, + RATIONALE_RE as _RATIONALE_RE, +) + +_TEST_SIGNALS = [ + r"\btesten\b", r"\btest\b", r"\bprüfung\b", r"\bprüfen\b", + r"\bgetestet\b", r"\bwirksamkeit\b", r"\baudit\b", + r"\bregelmäßig\b.*\b(prüf|test|kontroll)", + r"\beffectiveness\b", r"\bverif", +] +_TEST_RE = re.compile("|".join(_TEST_SIGNALS), re.IGNORECASE) + +_REPORTING_SIGNALS = [ + r"\bmelden\b", r"\bmeldung\b", r"\bunterricht", + r"\binformieren\b", r"\bbenachricht", r"\bnotif", + r"\breport\b", r"\bbehörd", +] +_REPORTING_RE = re.compile("|".join(_REPORTING_SIGNALS), re.IGNORECASE) + + +# --------------------------------------------------------------------------- +# Merge & Enrichment helpers +# --------------------------------------------------------------------------- + +# Trigger-type detection patterns +_EVENT_TRIGGERS = re.compile( + r"\b(vorfall|incident|breach|verletzung|sicherheitsvorfall|meldung|entdeckung" + r"|feststellung|erkennung|ereignis|eintritt|bei\s+auftreten|im\s+falle" + r"|wenn\s+ein|sobald|unverzüglich|upon|in\s+case\s+of|when\s+a)\b", + re.IGNORECASE, +) +_PERIODIC_TRIGGERS = re.compile( + r"\b(jährlich|monatlich|quartalsweise|regelmäßig|periodisch|annually" + r"|monthly|quarterly|periodic|mindestens\s+(einmal|alle)|turnusmäßig" + r"|wiederkehrend|in\s+regelmäßigen\s+abständen)\b", + re.IGNORECASE, +) + +# Implementation-specific keywords (concrete tools/protocols/formats) +_IMPL_SPECIFIC_PATTERNS = re.compile( + r"\b(TLS|SSL|AES|RSA|SHA-\d|HTTPS|LDAP|SAML|OAuth|OIDC|MFA|2FA" + r"|SIEM|IDS|IPS|WAF|VPN|VLAN|DMZ|HSM|PKI|RBAC|ABAC" + r"|ISO\s*27\d{3}|SOC\s*2|PCI[\s-]DSS|NIST" + r"|Firewall|Antivirus|EDR|XDR|SOAR|DLP" + r"|SMS|E-Mail|Fax|Telefon" + r"|JSON|XML|CSV|PDF|YAML" + r"|PostgreSQL|MySQL|MongoDB|Redis|Kafka" + r"|Docker|Kubernetes|AWS|Azure|GCP" + r"|Active\s*Directory|RADIUS|Kerberos" + 
r"|RSyslog|Splunk|ELK|Grafana|Prometheus" + r"|Git|Jenkins|Terraform|Ansible)\b", + re.IGNORECASE, +) + + +def _classify_trigger_type(obligation_text: str, condition: str) -> str: + """Classify when an obligation is triggered: event/periodic/continuous.""" + combined = f"{obligation_text} {condition}" + if _EVENT_TRIGGERS.search(combined): + return "event" + if _PERIODIC_TRIGGERS.search(combined): + return "periodic" + return "continuous" + + +def _is_implementation_specific_text( + obligation_text: str, action: str, obj: str +) -> bool: + """Check if an obligation references concrete implementation details.""" + combined = f"{obligation_text} {action} {obj}" + matches = _IMPL_SPECIFIC_PATTERNS.findall(combined) + return len(matches) >= 1 + + +def _text_similar(a: str, b: str, threshold: float = 0.75) -> bool: + """Quick token-overlap similarity check (Jaccard on words).""" + if not a or not b: + return False + tokens_a = set(a.split()) + tokens_b = set(b.split()) + if not tokens_a or not tokens_b: + return False + intersection = tokens_a & tokens_b + union = tokens_a | tokens_b + return len(intersection) / len(union) >= threshold + + +def _is_more_implementation_specific(text_a: str, text_b: str) -> bool: + """Return True if text_a is more implementation-specific than text_b.""" + matches_a = len(_IMPL_SPECIFIC_PATTERNS.findall(text_a)) + matches_b = len(_IMPL_SPECIFIC_PATTERNS.findall(text_b)) + if matches_a != matches_b: + return matches_a > matches_b + # Tie-break: longer text is usually more specific + return len(text_a) > len(text_b) + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class ObligationCandidate: + """A single normative obligation extracted from a Rich Control.""" + + candidate_id: str = "" + parent_control_uuid: str = "" + obligation_text: str = "" + action: str = "" + object_: str = "" + condition: 
Optional[str] = None + normative_strength: str = "must" + obligation_type: str = "pflicht" # pflicht | empfehlung | kann + is_test_obligation: bool = False + is_reporting_obligation: bool = False + extraction_confidence: float = 0.0 + quality_flags: dict = field(default_factory=dict) + release_state: str = "extracted" + + def to_dict(self) -> dict: + return { + "candidate_id": self.candidate_id, + "parent_control_uuid": self.parent_control_uuid, + "obligation_text": self.obligation_text, + "action": self.action, + "object": self.object_, + "condition": self.condition, + "normative_strength": self.normative_strength, + "obligation_type": self.obligation_type, + "is_test_obligation": self.is_test_obligation, + "is_reporting_obligation": self.is_reporting_obligation, + "extraction_confidence": self.extraction_confidence, + "quality_flags": self.quality_flags, + "release_state": self.release_state, + } + + +@dataclass +class AtomicControlCandidate: + """An atomic control composed from a single ObligationCandidate.""" + + candidate_id: str = "" + parent_control_uuid: str = "" + obligation_candidate_id: str = "" + title: str = "" + objective: str = "" + requirements: list = field(default_factory=list) + test_procedure: list = field(default_factory=list) + evidence: list = field(default_factory=list) + severity: str = "medium" + category: str = "" + domain: str = "" + source_regulation: str = "" + source_article: str = "" + + def to_dict(self) -> dict: + return { + "candidate_id": self.candidate_id, + "parent_control_uuid": self.parent_control_uuid, + "obligation_candidate_id": self.obligation_candidate_id, + "title": self.title, + "objective": self.objective, + "requirements": self.requirements, + "test_procedure": self.test_procedure, + "evidence": self.evidence, + "severity": self.severity, + "category": self.category, + "domain": self.domain, + } + + +# --------------------------------------------------------------------------- +# Quality Gate +# 
def classify_obligation_type(txt: str) -> str:
    """Classify obligation text into pflicht/empfehlung/kann.

    Priority: pflicht > empfehlung > kann. Texts with no normative
    signal at all default to 'empfehlung' — nothing is rejected.
    """
    for regex, label in (
        (_PFLICHT_RE, "pflicht"),
        (_EMPFEHLUNG_RE, "empfehlung"),
        (_KANN_RE, "kann"),
    ):
        if regex.search(txt):
            return label
    # The LLM proposed this as an obligation but no signal matched:
    # treat it as a recommendation so the user can still work with it.
    return "empfehlung"


def quality_gate(candidate: ObligationCandidate) -> dict:
    """Validate an obligation candidate and return its quality-flags dict.

    Flags:
        has_normative_signal: text contains normative language (informational)
        obligation_type: pflicht | empfehlung | kann (classified, never rejected)
        single_action: only one main action (heuristic)
        not_rationale: not just a justification/reasoning
        not_evidence_only: not just an evidence requirement
        min_length: text is long enough to be meaningful
        has_parent_link: references back to parent control
    """
    text = candidate.obligation_text
    flags: dict = {}

    # 1. Normative signal (informational — no longer used for rejection)
    flags["has_normative_signal"] = _NORMATIVE_RE.search(text) is not None

    # 1b. Obligation type classification
    flags["obligation_type"] = classify_obligation_type(text)

    # 2. Single-action heuristic: a conjunction followed by another
    # normative verb suggests two actions glued together (imperfect).
    conjunction_verb = re.compile(
        r"\b(und|sowie|als auch)\b.*\b(müssen|sicherstellen|implementieren"
        r"|dokumentieren|melden|testen|prüfen|überwachen|gewährleisten)\b",
        re.IGNORECASE,
    )
    flags["single_action"] = conjunction_verb.search(text) is None

    # 3. Not rationale: normative signals must not be outnumbered by
    # justification language.
    flags["not_rationale"] = (
        len(_NORMATIVE_RE.findall(text)) >= len(_RATIONALE_RE.findall(text))
    )

    # 4. Not evidence-only (evidence fragments are typically short noun phrases)
    evidence_prefix = re.compile(
        r"^(Nachweis|Dokumentation|Screenshot|Protokoll|Bericht|Zertifikat)",
        re.IGNORECASE,
    )
    flags["not_evidence_only"] = evidence_prefix.match(text.strip()) is None

    # 5. Min length
    flags["min_length"] = len(text.strip()) >= 20

    # 6. Parent link
    flags["has_parent_link"] = bool(candidate.parent_control_uuid)

    return flags


def passes_quality_gate(flags: dict) -> bool:
    """Check whether all critical quality flags pass.

    Note: has_normative_signal is NOT critical — obligations without a
    normative signal are classified as 'empfehlung' instead of rejected.
    """
    return all(
        flags.get(key, False)
        for key in ("not_evidence_only", "min_length", "has_parent_link")
    )
Testpflichten SEPARAT von operativen Pflichten (is_test_obligation=true). +4. Meldepflichten SEPARAT (is_reporting_obligation=true). +5. NICHT auf Evidence-Ebene zerlegen (z.B. "DR-Plan vorhanden" ist KEIN \ +eigenes Control, sondern Evidence). +6. Begründungen, Erläuterungen und Erwägungsgründe sind KEINE Pflichten \ +— NICHT extrahieren. + +Antworte NUR mit einem JSON-Array. Keine Erklärungen.""" + + +def _build_pass0a_prompt( + title: str, objective: str, requirements: str, + test_procedure: str, source_ref: str +) -> str: + return f"""\ +Analysiere das folgende Control und extrahiere alle einzelnen normativen \ +Pflichten als JSON-Array. + +CONTROL: +Titel: {title} +Ziel: {objective} +Anforderungen: {requirements} +Prüfverfahren: {test_procedure} +Quellreferenz: {source_ref} + +Antworte als JSON-Array: +[ + {{ + "obligation_text": "Kurze, präzise Formulierung der Pflicht", + "action": "Hauptverb/Handlung", + "object": "Gegenstand der Pflicht", + "condition": "Auslöser/Bedingung oder null", + "normative_strength": "must", + "is_test_obligation": false, + "is_reporting_obligation": false + }} +]""" + + +_PASS0B_SYSTEM_PROMPT = """\ +Du bist ein Security-Compliance-Experte. Du erstellst aus einer einzelnen \ +normativen Pflicht ein praxisorientiertes, atomares Security Control. + +ANALYSE-SCHRITTE (intern durchfuehren, NICHT im Output!): +1. Identifiziere die konkrete Anforderung aus der Pflicht +2. Leite eine umsetzbare technische/organisatorische Massnahme ab +3. Definiere ein Pruefverfahren (wie wird Umsetzung verifiziert?) +4. Bestimme den Nachweis (welches Dokument/Artefakt belegt Compliance?) + +Das Control muss UMSETZBAR sein — keine Gesetzesparaphrase. +Antworte NUR als JSON. 
Keine Erklärungen.""" + + +# --------------------------------------------------------------------------- +# Deterministic Atomic Control Composition Engine v2 +# --------------------------------------------------------------------------- +# Transforms obligation candidates into atomic controls WITHOUT LLM. +# +# Pipeline: +# 1. split_compound_action() — split "erstellen und implementieren" → 2 +# 2. classify_action() — 18 fine-grained action types +# 3. classify_object() — policy / technical / process / register / … +# 4. trigger_qualifier() — periodic / event / continuous → timing text +# 5. template lookup — (action_type, object_class) → test + evidence +# 6. compose — assemble AtomicControlCandidate +# --------------------------------------------------------------------------- + +# ── 1. Compound Action Splitter ────────────────────────────────────────── + +_COMPOUND_SPLIT_RE = re.compile( + r"\s+(?:und|sowie|als\s+auch|,\s*(?:und|sowie))\s+", re.IGNORECASE +) + +# Phrases that should never be split (stylistic variants, not separate actions) +_NO_SPLIT_PHRASES: set[str] = { + "pflegen und aufrechterhalten", + "aufrechterhalten und pflegen", + "erkennen und verhindern", + "verhindern und erkennen", + "sichern und schützen", + "schützen und sichern", + "schützen und absichern", + "absichern und schützen", + "ermitteln und bewerten", + "bewerten und ermitteln", + "prüfen und überwachen", + "überwachen und prüfen", +} + + +def _split_compound_action(action: str) -> list[str]: + """Split compound actions into individual sub-actions. + + Only splits if: + - the parts map to *different* action types + - the phrase is not in the no-split list + + Keeps phrases like 'aufrechterhalten und pflegen' together + because both map to 'maintain'. 
+ """ + if not action: + return [action] + + # Check no-split list first + if action.lower().strip() in _NO_SPLIT_PHRASES: + return [action] + + parts = _COMPOUND_SPLIT_RE.split(action.strip()) + if len(parts) <= 1: + return [action] + + # Classify each part — only split if types differ + types = [_classify_action(p.strip()) for p in parts] + if len(set(types)) > 1: + return [p.strip() for p in parts if p.strip()] + + return [action] + + +# ── 2. Action Type Classification (18 types) ──────────────────────────── + +_ACTION_PRIORITY = [ + "prevent", "exclude", "forbid", + "implement", "configure", "encrypt", "restrict_access", + "enforce", "invalidate", "issue", "rotate", + "monitor", "review", "assess", "audit", + "test", "verify", "validate", + "report", "notify", "train", + "delete", "retain", "ensure", + "define", "document", "maintain", + "approve", "remediate", + "perform", "obtain", +] + +_ACTION_KEYWORDS: list[tuple[str, str]] = [ + # ── Negative / prohibitive actions (highest priority) ──── + ("dürfen keine", "prevent"), + ("dürfen nicht", "prevent"), + ("darf keine", "prevent"), + ("darf nicht", "prevent"), + ("nicht zulässig", "forbid"), + ("nicht erlaubt", "forbid"), + ("nicht gestattet", "forbid"), + ("untersagt", "forbid"), + ("verboten", "forbid"), + ("nicht enthalten", "exclude"), + ("nicht übertragen", "prevent"), + ("nicht übermittelt", "prevent"), + ("nicht wiederverwendet", "prevent"), + ("nicht gespeichert", "prevent"), + ("verhindern", "prevent"), + ("unterbinden", "prevent"), + ("ausschließen", "exclude"), + ("vermeiden", "prevent"), + ("ablehnen", "exclude"), + ("zurückweisen", "exclude"), + # ── Session / lifecycle actions ────────────────────────── + ("ungültig machen", "invalidate"), + ("invalidieren", "invalidate"), + ("widerrufen", "invalidate"), + ("session beenden", "invalidate"), + ("vergeben", "issue"), + ("ausstellen", "issue"), + ("erzeugen", "issue"), + ("generieren", "issue"), + ("rotieren", "rotate"), + ("erneuern", "rotate"), + 
("durchsetzen", "enforce"), + ("erzwingen", "enforce"), + # ── Multi-word patterns (longest match wins) ───────────── + ("aktuell halten", "maintain"), + ("aufrechterhalten", "maintain"), + ("sicherstellen", "ensure"), + ("gewährleisten", "ensure"), + ("benachrichtigen", "notify"), + ("sensibilisieren", "train"), + ("authentifizieren", "restrict_access"), + ("verschlüsseln", "encrypt"), + ("implementieren", "implement"), + ("konfigurieren", "configure"), + ("bereitstellen", "implement"), + ("protokollieren", "document"), + ("dokumentieren", "document"), + ("kontrollieren", "monitor"), + ("installieren", "implement"), + ("autorisieren", "restrict_access"), + ("beschränken", "restrict_access"), + ("berechtigen", "restrict_access"), + ("aufbewahren", "retain"), + ("archivieren", "retain"), + ("überwachen", "monitor"), + ("überprüfen", "review"), + ("auditieren", "audit"), + ("informieren", "notify"), + ("analysieren", "assess"), + ("verifizieren", "verify"), + ("validieren", "validate"), + ("evaluieren", "assess"), + ("integrieren", "implement"), + ("aktivieren", "configure"), + ("einrichten", "configure"), + ("einführen", "implement"), + ("unterweisen", "train"), + ("durchführen", "perform"), + ("verarbeiten", "perform"), + ("vernichten", "delete"), + ("entfernen", "delete"), + ("absichern", "implement"), + ("schützen", "implement"), + ("bewerten", "assess"), + ("umsetzen", "implement"), + ("aufbauen", "implement"), + ("erstellen", "document"), + ("definieren", "define"), + ("festlegen", "define"), + ("vorgeben", "define"), + ("verfassen", "document"), + ("einholen", "obtain"), + ("genehmigen", "approve"), + ("freigeben", "approve"), + ("zulassen", "approve"), + ("beheben", "remediate"), + ("korrigieren", "remediate"), + ("beseitigen", "remediate"), + ("nachbessern", "remediate"), + ("speichern", "retain"), + ("mitteilen", "notify"), + ("berichten", "report"), + ("schulen", "train"), + ("melden", "report"), + ("prüfen", "review"), + ("testen", "test"), + ("führen", 
"document"), + ("pflegen", "maintain"), + ("wahren", "maintain"), + ("löschen", "delete"), + ("angeben", "document"), + ("beifügen", "document"), + # English fallbacks + ("implement", "implement"), + ("configure", "configure"), + ("establish", "define"), + ("document", "document"), + ("maintain", "maintain"), + ("monitor", "monitor"), + ("review", "review"), + ("assess", "assess"), + ("audit", "audit"), + ("encrypt", "encrypt"), + ("restrict", "restrict_access"), + ("authorize", "restrict_access"), + ("verify", "verify"), + ("validate", "validate"), + ("report", "report"), + ("notify", "notify"), + ("train", "train"), + ("test", "test"), + ("delete", "delete"), + ("retain", "retain"), + ("ensure", "ensure"), + ("approve", "approve"), + ("remediate", "remediate"), + ("perform", "perform"), + ("obtain", "obtain"), + ("prevent", "prevent"), + ("forbid", "forbid"), + ("exclude", "exclude"), + ("invalidate", "invalidate"), + ("revoke", "invalidate"), + ("issue", "issue"), + ("generate", "issue"), + ("rotate", "rotate"), + ("enforce", "enforce"), +] + + +def _classify_action(action: str) -> str: + """Classify an obligation action string into one of 18 action types. + + For compound actions, returns the highest-priority matching type. + """ + if not action: + return "default" + action_lower = action.lower().strip() + + matches: set[str] = set() + for keyword, atype in _ACTION_KEYWORDS: + if keyword in action_lower: + matches.add(atype) + + if not matches: + return "default" + + for prio in _ACTION_PRIORITY: + if prio in matches: + return prio + + return next(iter(matches)) + + +# ── 3. 
Object Class Classification ────────────────────────────────────── + +_OBJECT_CLASS_KEYWORDS: dict[str, list[str]] = { + # ── Governance / Documentation ──────────────────────────── + "policy": [ + "richtlinie", "policy", "konzept", "strategie", "leitlinie", + "vorgabe", "regelung", "ordnung", "anweisung", "standard", + "rahmenwerk", "sicherheitskonzept", "datenschutzkonzept", + ], + "procedure": [ + "verfahren", "workflow", "ablauf", + "vorgehensweise", "methodik", "prozedur", "handlungsanweisung", + ], + "register": [ + "verzeichnis", "register", "inventar", "liste", "katalog", + "übersicht", "bestandsaufnahme", + ], + "record": [ + "protokoll", "log", "aufzeichnung", "nachweis", + "evidenz", "artefakt", "dokumentation", + ], + "report": [ + "meldung", "bericht", "report", "benachrichtigung", + "mitteilung", "anzeige", "statusbericht", + ], + # ── Technical / Security ────────────────────────────────── + "technical_control": [ + "mfa", "firewall", "verschlüsselung", "backup", "antivirus", + "ids", "ips", "waf", "vpn", "tls", "ssl", + "patch", "update", "härtung", "segmentierung", + "alarmierung", "monitoring", + ], + "access_control": [ + "authentifizierung", "autorisierung", "zugriff", + "berechtigung", "passwort", "kennwort", "anmeldung", + "sso", "rbac", + ], + "session": [ + "session", "sitzung", "sitzungsverwaltung", "session management", + "session-id", "session-token", "idle timeout", + "inaktivitäts-timeout", "inaktivitätszeitraum", + "logout", "abmeldung", + ], + "cookie": [ + "cookie", "session-cookie", "secure-flag", "httponly", + "samesite", "cookie-attribut", + ], + "jwt": [ + "jwt", "json web token", "bearer token", + "jwt-algorithmus", "jwt-signatur", + ], + "federated_assertion": [ + "assertion", "saml", "oidc", "openid", + "föderiert", "federated", "identity provider", + ], + "cryptographic_control": [ + "schlüssel", "zertifikat", "signatur", "kryptographi", + "cipher", "hash", "token", "entropie", + ], + "configuration": [ + "konfiguration", 
"einstellung", "parameter", + "baseline", "hardening", "härtungsprofil", + ], + "account": [ + "account", "konto", "benutzer", "privilegiert", + "admin", "root", "dienstkonto", "servicekonto", + ], + # ── Data / Systems ──────────────────────────────────────── + "system": [ + "system", "plattform", "dienst", "service", "anwendung", + "software", "komponente", "infrastruktur", "netzwerk", + "server", "datenbank", "produkt", "gerät", "endgerät", + ], + "data": [ + "daten", "information", "personenbezogen", "datensatz", + "datei", "inhalt", "verarbeitungstätigkeit", + ], + "interface": [ + "schnittstelle", "interface", "api", "integration", + "datenfluss", "datenübermittlung", + ], + # ── People / Org ────────────────────────────────────────── + "role": [ + "mitarbeiter", "personal", "rolle", "beauftragter", + "verantwortlicher", "team", "abteilung", "beschäftigte", + ], + "training": [ + "schulung", "training", "sensibilisierung", "awareness", + "unterweisung", "fortbildung", "qualifikation", + ], + # ── Incident / Risk ─────────────────────────────────────── + "incident": [ + "vorfall", "incident", "sicherheitsvorfall", "störung", + "notfall", "krise", "bedrohung", + ], + "risk_artifact": [ + "risiko", "schwachstelle", "vulnerability", "gefährdung", + "risikoanalyse", "risikobewertung", "schutzbedarfsfeststellung", + ], + # ── Process / Consent ──────────────────────────────────── + "process": [ + "prozess", "geschäftsprozess", "betriebsprozess", + "managementprozess", "steuerungsprozess", + ], + "consent": [ + "einwilligung", "consent", "einverständnis", + "zustimmung", "opt-in", "opt-out", + ], +} + + +def _classify_object(object_: str) -> str: + """Classify the obligation object into a domain class.""" + if not object_: + return "general" + obj_lower = object_.lower() + for obj_class, keywords in _OBJECT_CLASS_KEYWORDS.items(): + if any(k in obj_lower for k in keywords): + return obj_class + return "general" + + +# ── 4. 
Trigger / Timing Qualifier ─────────────────────────────────────── + +_FREQUENCY_PATTERNS: list[tuple[str, str]] = [ + (r"jährl", "jährlich"), + (r"quartal", "quartalsweise"), + (r"halbjähr", "halbjährlich"), + (r"monatl", "monatlich"), + (r"wöchentl", "wöchentlich"), + (r"regelmäßig", "regelmässig"), + (r"72\s*Stunden", "innerhalb von 72 Stunden"), + (r"unverzüglich", "unverzüglich"), + (r"ohne\s+Verzögerung", "ohne unangemessene Verzögerung"), + (r"vor\s+Inbetriebnahme", "vor Inbetriebnahme"), + (r"vor\s+Markteinführung", "vor Markteinführung"), + (r"vor\s+Freigabe", "vor Freigabe"), +] + + +def _extract_trigger_qualifier( + trigger_type: Optional[str], obligation_text: str, +) -> str: + """Extract timing/trigger context for test procedures.""" + # Try to find specific frequency in obligation text + for pattern, qualifier in _FREQUENCY_PATTERNS: + if re.search(pattern, obligation_text, re.IGNORECASE): + return qualifier + + if trigger_type == "event": + return "bei Eintreten des auslösenden Ereignisses" + if trigger_type == "periodic": + return "periodisch" + return "" # continuous → no special qualifier + + +# ── 4b. 
Structured Timing Extraction ──────────────────────────────────── + +_STRUCTURED_FREQUENCY_MAP: list[tuple[str, Optional[int], Optional[str]]] = [ + # (regex_pattern, deadline_hours, frequency) + (r"72\s*Stunden", 72, None), + (r"48\s*Stunden", 48, None), + (r"24\s*Stunden", 24, None), + (r"unverzüglich", 0, None), + (r"ohne\s+Verzögerung", 0, None), + (r"sofort", 0, None), + (r"jährl", None, "yearly"), + (r"halbjähr", None, "semi_annually"), + (r"quartal", None, "quarterly"), + (r"monatl", None, "monthly"), + (r"wöchentl", None, "weekly"), + (r"täglich", None, "daily"), + (r"regelmäßig", None, "periodic"), + (r"periodisch", None, "periodic"), + (r"vor\s+Inbetriebnahme", None, "before_deployment"), + (r"vor\s+Markteinführung", None, "before_launch"), + (r"vor\s+Freigabe", None, "before_release"), +] + + +def _extract_structured_timing( + obligation_text: str, +) -> tuple[Optional[int], Optional[str]]: + """Extract deadline_hours and frequency from obligation text. + + Returns (deadline_hours, frequency). Both may be None. + """ + for pattern, deadline, freq in _STRUCTURED_FREQUENCY_MAP: + if re.search(pattern, obligation_text, re.IGNORECASE): + return (deadline, freq) + return (None, None) + + +# ── 5. Template Matrix: (action_type, object_class) → templates ───────── +# +# Specific combos override base templates. Lookup order: +# 1. _SPECIFIC_TEMPLATES[(action_type, object_class)] +# 2. _ACTION_TEMPLATES[action_type] +# 3. 
#   3. _DEFAULT_ACTION_TEMPLATE

# Base templates keyed by action type. Each value provides German
# "test_procedure" and "evidence" lists; "{object}" is a placeholder
# filled in later by the composer.
_ACTION_TEMPLATES: dict[str, dict[str, list[str]]] = {
    # ── Create / Define / Document ─────────────────────────────
    "define": {
        "test_procedure": [
            "Prüfung, ob {object} definiert und formal freigegeben ist",
            "Review der Inhalte auf Vollständigkeit und Angemessenheit",
            "Verifizierung, dass {object} den Betroffenen kommuniziert wurde",
        ],
        "evidence": [
            "Freigegebenes Dokument mit Geltungsbereich",
            "Kommunikationsnachweis (E-Mail, Intranet, Schulung)",
        ],
    },
    "document": {
        "test_procedure": [
            "Prüfung, ob {object} dokumentiert und aktuell ist",
            "Sichtung der Dokumentation auf Vollständigkeit",
            "Verifizierung der Versionierung und des Review-Zyklus",
        ],
        "evidence": [
            "Dokument mit Versionshistorie",
            "Freigabenachweis (Unterschrift/Approval)",
        ],
    },
    "maintain": {
        "test_procedure": [
            "Prüfung, ob {object} aktuell gehalten wird",
            "Vergleich der letzten Aktualisierung mit dem Review-Zyklus",
            "Stichprobe: Änderungen nach relevanten Ereignissen nachvollzogen",
        ],
        "evidence": [
            "Änderungshistorie mit Datum und Verantwortlichem",
            "Nachweis des letzten Reviews",
        ],
    },
    # ── Implement / Configure ──────────────────────────────────
    "implement": {
        "test_procedure": [
            "Prüfung der technischen Umsetzung von {object}",
            "Funktionstest der implementierten Massnahme",
            "Review der Konfiguration gegen die Anforderungsspezifikation",
        ],
        "evidence": [
            "Konfigurationsnachweis (Screenshot/Export)",
            "Implementierungsdokumentation",
        ],
    },
    "configure": {
        "test_procedure": [
            "Prüfung der Konfiguration von {object} gegen Soll-Vorgaben",
            "Vergleich mit Hardening-Baseline oder Best Practice",
            "Automatisierter Konfigurationsscan (falls verfügbar)",
        ],
        "evidence": [
            "Konfigurationsexport mit Soll-/Ist-Vergleich",
            "Scan-Bericht oder Compliance-Check-Ergebnis",
        ],
    },
    # ── Monitor / Review / Assess / Audit ──────────────────────
    "monitor": {
        "test_procedure": [
            "Prüfung der laufenden Überwachung von {object}",
            "Stichprobe der Protokolle/Logs der letzten 3 Monate",
            "Verifizierung der Alarmierungs- und Eskalationsprozesse",
        ],
        "evidence": [
            "Monitoring-Dashboard-Export oder Log-Auszüge",
            "Alarmierungsregeln und Eskalationsmatrix",
        ],
    },
    "review": {
        "test_procedure": [
            "Prüfung, ob {object} im vorgesehenen Zyklus überprüft wurde",
            "Sichtung des Review-Protokolls auf Massnahmenableitung",
            "Verifizierung der Umsetzung identifizierter Massnahmen",
        ],
        "evidence": [
            "Review-Protokoll mit Datum und Teilnehmern",
            "Massnahmenplan mit Umsetzungsstatus",
        ],
    },
    "assess": {
        "test_procedure": [
            "Prüfung der Bewertungsmethodik für {object}",
            "Sichtung der letzten Bewertungsergebnisse",
            "Verifizierung, dass Ergebnisse in Massnahmen überführt wurden",
        ],
        "evidence": [
            "Bewertungsbericht mit Methodik und Ergebnissen",
            "Abgeleiteter Massnahmenplan",
        ],
    },
    "audit": {
        "test_procedure": [
            "Prüfung des Audit-Plans und der Audit-Durchführung für {object}",
            "Sichtung der Audit-Berichte und Findings",
            "Verifizierung der Nachverfolgung offener Findings",
        ],
        "evidence": [
            "Audit-Bericht mit Findings und Empfehlungen",
            "Finding-Tracker mit Umsetzungsstatus",
        ],
    },
    # ── Test / Verify / Validate ───────────────────────────────
    "test": {
        "test_procedure": [
            "Review der Testpläne und -methodik für {object}",
            "Stichprobe der Testergebnisse und Massnahmenableitung",
            "Prüfung der Testabdeckung und -häufigkeit",
        ],
        "evidence": [
            "Testprotokoll mit Ergebnissen",
            "Testplan und Abdeckungsbericht",
        ],
    },
    "verify": {
        "test_procedure": [
            "Prüfung der Verifikationsmethodik für {object}",
            "Nachvollzug der Verifikationsergebnisse gegen Spezifikation",
            "Prüfung, ob alle Anforderungen abgedeckt sind",
        ],
        "evidence": [
            "Verifikationsbericht mit Soll-/Ist-Abgleich",
            "Anforderungs-Traceability-Matrix",
        ],
    },
    "validate": {
        "test_procedure": [
            "Prüfung der Validierungsmethodik für {object}",
            "Bewertung, ob die Massnahme den Zweck im Praxiseinsatz erfüllt",
            "Auswertung von Nutzerfeedback oder Betriebsdaten",
        ],
        "evidence": [
            "Validierungsbericht",
            "Praxisnachweis (Betriebsdaten, Nutzerfeedback)",
        ],
    },
    # ── Report / Notify ────────────────────────────────────────
    "report": {
        "test_procedure": [
            "Prüfung des Meldeprozesses für {object}",
            "Stichprobe gemeldeter Vorfälle auf Vollständigkeit und Fristeneinhaltung",
            "Verifizierung der Meldekanäle und Zuständigkeiten",
        ],
        "evidence": [
            "Meldeprozess-Dokumentation",
            "Nachweise über erfolgte Meldungen mit Zeitstempeln",
        ],
    },
    "notify": {
        "test_procedure": [
            "Prüfung des Benachrichtigungsprozesses für {object}",
            "Verifizierung der Empfänger und Kommunikationskanäle",
            "Stichprobe: Benachrichtigungen fristgerecht versendet",
        ],
        "evidence": [
            "Benachrichtigungsvorlagen und Verteiler",
            "Versandnachweise mit Zeitstempeln",
        ],
    },
    # ── Train ──────────────────────────────────────────────────
    "train": {
        "test_procedure": [
            "Prüfung der Schulungsinhalte und -unterlagen zu {object}",
            "Verifizierung der Teilnehmerlisten und Schulungsfrequenz",
            "Stichprobe: Wissensstand durch Befragung oder Kurztest",
        ],
        "evidence": [
            "Schulungsunterlagen und Schulungsplan",
            "Teilnehmerlisten mit Datum und Unterschrift",
            "Ergebnisse von Wissenstests (falls durchgeführt)",
        ],
    },
    # ── Access / Encrypt ───────────────────────────────────────
    "restrict_access": {
        "test_procedure": [
            "Review der Zugriffsberechtigungen für {object}",
            "Prüfung der Berechtigungsmatrix auf Aktualität und Least-Privilege",
            "Stichprobe: Entzug von Rechten bei Rollenwechsel/Austritt",
        ],
        "evidence": [
            "Aktuelle Berechtigungsmatrix",
            "Zugriffsprotokolle der letzten 3 Monate",
            "Nachweis des letzten Berechtigungs-Reviews",
        ],
    },
    "encrypt": {
        "test_procedure": [
            "Prüfung der Verschlüsselungskonfiguration für {object}",
            "Verifizierung der Algorithmen und Schlüssellängen gegen BSI-Empfehlungen",
            "Prüfung des Schlüsselmanagement-Prozesses (Rotation, Speicherung)",
        ],
        "evidence": [
            "Kryptographie-Konzept",
            "Zertifikats- und Schlüsselinventar",
            "Schlüsselrotations-Nachweis",
        ],
    },
    # ── Delete / Retain ────────────────────────────────────────
    "delete": {
        "test_procedure": [
            "Prüfung des Löschkonzepts für {object}",
            "Verifizierung der Löschfristen und automatisierten Löschmechanismen",
            "Stichprobe: Löschung nach Fristablauf tatsächlich durchgeführt",
        ],
        "evidence": [
            "Löschkonzept mit definierten Fristen",
            "Löschprotokolle oder -nachweise",
        ],
    },
    "retain": {
        "test_procedure": [
            "Prüfung der Aufbewahrungsfristen und -orte für {object}",
            "Verifizierung der Zugriffskontrollen auf archivierte Daten",
            "Prüfung der automatischen Löschung nach Ablauf der Aufbewahrungsfrist",
        ],
        "evidence": [
            "Aufbewahrungsrichtlinie mit Fristen",
            "Speicherort-Dokumentation mit Zugriffskonzept",
        ],
    },
    # ── Ensure (catch-all for sicherstellen/gewährleisten) ─────
    "ensure": {
        "test_procedure": [
            "Prüfung, ob Massnahmen für {object} wirksam umgesetzt sind",
            "Stichprobenprüfung der Einhaltung im operativen Betrieb",
            "Review der zugehörigen Prozessdokumentation",
        ],
        "evidence": [
            "Nachweis der Umsetzung (Konfiguration/Prozess)",
            "Prüfprotokoll der letzten Überprüfung",
        ],
    },
    # ── Perform / Obtain ───────────────────────────────────────
    "perform": {
        "test_procedure": [
            "Prüfung der Durchführung von {object}",
            "Verifizierung der Zuständigkeiten und Freigabeschritte",
            "Stichprobe der Durchführung anhand aktueller Fälle",
        ],
        "evidence": [
            "Durchführungsnachweise (Tickets, Protokolle)",
            "Prozessdokumentation mit Verantwortlichkeiten",
        ],
    },
    "obtain": {
        "test_procedure": [
            "Prüfung des Einholungsprozesses für {object}",
            "Verifizierung der Vollständigkeit und Gültigkeit",
            "Stichprobe: Einholung vor Beginn der Verarbeitung nachgewiesen",
        ],
        "evidence": [
            "Nachweise der Einholung (Einwilligungen, Freigaben)",
            "Gültigkeitsprüfung mit Zeitstempeln",
        ],
    },
    # ── Prevent / Exclude / Forbid (negative norms) ────────────
    "prevent": {
        "test_procedure": [
            "Prüfung, dass {object} technisch verhindert wird",
            "Stichprobe: Versuch der verbotenen Aktion schlägt fehl",
            "Review der Konfiguration und Zugriffskontrollen",
        ],
        "evidence": [
            "Konfigurationsnachweis der Präventionsmassnahme",
            "Testprotokoll der Negativtests",
        ],
    },
    "exclude": {
        "test_procedure": [
            "Prüfung, dass {object} ausgeschlossen ist",
            "Stichprobe: Verbotene Inhalte/Aktionen sind nicht vorhanden",
            "Automatisierter Scan oder manuelle Prüfung",
        ],
        "evidence": [
            "Scan-Ergebnis oder Prüfprotokoll",
            "Konfigurationsnachweis",
        ],
    },
    "forbid": {
        "test_procedure": [
            "Prüfung, dass {object} untersagt und technisch blockiert ist",
            "Verifizierung der Richtlinie und technischen Durchsetzung",
            "Stichprobe: Versuch der untersagten Aktion wird abgelehnt",
        ],
        "evidence": [
            "Richtlinie mit explizitem Verbot",
            "Technischer Nachweis der Blockierung",
        ],
    },
    # ── Enforce / Invalidate / Issue / Rotate ────────────────
    "enforce": {
        "test_procedure": [
            "Prüfung der technischen Durchsetzung von {object}",
            "Stichprobe: Nicht-konforme Konfigurationen werden automatisch korrigiert oder abgelehnt",
            "Review der Enforcement-Regeln und Ausnahmen",
        ],
        "evidence": [
            "Enforcement-Policy mit technischer Umsetzung",
            "Protokoll erzwungener Korrekturen oder Ablehnungen",
        ],
    },
    "invalidate": {
        "test_procedure": [
            "Prüfung, dass {object} korrekt ungültig gemacht wird",
            "Stichprobe: Nach Invalidierung kein Zugriff mehr möglich",
            "Verifizierung der serverseitigen Bereinigung",
        ],
        "evidence": [
            "Protokoll der Invalidierungsaktionen",
            "Testnachweis der Zugriffsverweigerung nach Invalidierung",
        ],
    },
    "issue": {
        "test_procedure": [
            "Prüfung des Vergabeprozesses für {object}",
            "Verifizierung der kryptographischen Sicherheit und Entropie",
            "Stichprobe: Korrekte Vergabe unter definierten Bedingungen",
        ],
        "evidence": [
            "Prozessdokumentation der Vergabe",
            "Nachweis der Entropie-/Sicherheitseigenschaften",
        ],
    },
    "rotate": {
        "test_procedure": [
            "Prüfung des Rotationsprozesses für {object}",
            "Verifizierung der Rotationsfrequenz und automatischen Auslöser",
            "Stichprobe: Alte Artefakte nach Rotation ungültig",
        ],
        "evidence": [
            "Rotationsrichtlinie mit Frequenz",
            "Rotationsprotokoll mit Zeitstempeln",
        ],
    },
    # ── Approve / Remediate ───────────────────────────────────
    "approve": {
        "test_procedure": [
            "Prüfung des Genehmigungsprozesses für {object}",
            "Verifizierung der Freigabeberechtigungen und Eskalationswege",
            "Stichprobe: Genehmigung vor Umsetzung/Nutzung nachgewiesen",
        ],
        "evidence": [
            "Freigabenachweis (Signatur, Ticket, Workflow)",
            "Genehmigungsmatrix mit Zuständigkeiten",
        ],
    },
    "remediate": {
        "test_procedure": [
            "Prüfung des Behebungsprozesses für {object}",
            "Verifizierung der Korrekturmassnahmen und Wirksamkeit",
            "Stichprobe: Abweichungen fristgerecht behoben",
        ],
        "evidence": [
            "Korrekturmassnahmen-Dokumentation",
            "Nachprüfprotokoll der Wirksamkeit",
        ],
    },
}

# ── Specific (action_type, object_class) overrides ───────────────────────

# Looked up BEFORE _ACTION_TEMPLATES; keyed by (action_type, object_class).
_SPECIFIC_TEMPLATES: dict[tuple[str, str], dict[str, list[str]]] = {
    ("implement", "policy"): {
        "test_procedure": [
            "Prüfung, ob {object} dokumentiert, freigegeben und in relevanten Prozessen umgesetzt ist",
            "Interview mit Prozessverantwortlichen zur tatsächlichen Umsetzung",
            "Stichprobe: Nachweis der Umsetzung in der Praxis",
        ],
        "evidence": [
            "Freigegebenes Richtliniendokument",
            "Nachweis der Kommunikation an Betroffene",
            "Stichproben der operativen Umsetzung",
        ],
    },
    ("implement", "technical_control"): {
        "test_procedure": [
            "Prüfung der technischen Konfiguration von {object}",
            "Funktionstest: Wirksamkeit der Massnahme verifizieren",
            "Vulnerability-Scan oder Penetrationstest (falls anwendbar)",
        ],
        "evidence": [
            "Konfigurationsnachweis (Screenshot/Export)",
            "Testprotokoll mit Ergebnissen",
            "Scan-Bericht (falls durchgeführt)",
        ],
    },
    ("implement", "process"): {
        "test_procedure": [
            "Prüfung der Prozessdokumentation für {object}",
            "Verifizierung, dass der Prozess operativ gelebt wird",
            "Stichprobe: Prozessdurchführung anhand aktueller Fälle",
        ],
        "evidence": [
            "Prozessdokumentation mit RACI-Matrix",
            "Durchführungsnachweise der letzten 3 Monate",
        ],
    },
    ("define", "policy"): {
        "test_procedure": [
            "Prüfung, ob {object} formal definiert und durch Management freigegeben ist",
            "Review des Geltungsbereichs und der Adressaten",
            "Verifizierung der regelmässigen Aktualisierung",
        ],
        "evidence": [
            "Freigegebene Policy mit Unterschrift der Geschäftsleitung",
            "Geltungsbereich und Verteiler",
            "Letzte Aktualisierung mit Änderungshistorie",
        ],
    },
    ("monitor", "system"): {
        "test_procedure": [
            "Prüfung der Monitoring-Konfiguration für {object}",
            "Stichprobe der System-Logs und Alerts der letzten 3 Monate",
            "Verifizierung: Alerts führen zu dokumentierten Reaktionen",
        ],
        "evidence": [
            "Monitoring-Dashboard-Export",
            "Alert-Konfiguration und Eskalationsregeln",
            "Incident-Tickets aus Alert-Eskalation",
        ],
    },
    ("monitor", "incident"): {
        "test_procedure": [
            "Prüfung des Incident-Monitoring-Prozesses für {object}",
            "Stichprobe erkannter Vorfälle auf Reaktionszeit",
            "Verifizierung der Eskalationswege",
        ],
        "evidence": [
            "Incident-Log mit Erkennungs- und Reaktionszeiten",
            "Eskalationsmatrix",
        ],
    },
    ("review", "policy"): {
        "test_procedure": [
            "Prüfung, ob {object} im vorgesehenen Zyklus durch Management reviewed wurde",
            "Sichtung des Review-Protokolls auf Aktualisierungsbedarf",
            "Verifizierung, dass Änderungen umgesetzt wurden",
        ],
        "evidence": [
            "Review-Protokoll mit Datum und Teilnehmern",
            "Aktualisierte Version der Richtlinie (falls geändert)",
        ],
    },
    ("assess", "incident"): {
        "test_procedure": [
            "Prüfung der Risikobewertung für {object}",
            "Nachvollzug der Bewertungskriterien (Schwere, Auswirkung, Wahrscheinlichkeit)",
            "Verifizierung der abgeleiteten Massnahmen",
        ],
        "evidence": [
            "Risikobewertungs-Matrix",
            "Massnahmenplan mit Priorisierung",
        ],
    },
    ("report", "incident"): {
        "test_procedure": [
            "Prüfung des Meldeprozesses für {object} an zuständige Behörden",
            "Verifizierung der Meldefristen (z.B. 72h DSGVO, 24h NIS2)",
            "Stichprobe: Meldungen fristgerecht und vollständig",
        ],
        "evidence": [
            "Meldeprozess mit Fristen und Zuständigkeiten",
            "Kopien erfolgter Behördenmeldungen",
            "Zeitstempel-Nachweis der Fristwahrung",
        ],
    },
    ("train", "role"): {
        "test_procedure": [
            "Prüfung der Schulungspflicht für {object}",
            "Verifizierung der Teilnahme aller betroffenen Personen",
            "Stichprobe: Wissensstand durch Kurztest oder Befragung",
        ],
        "evidence": [
            "Schulungsplan mit Zielgruppen und Frequenz",
            "Teilnehmerlisten mit Datum und Unterschrift",
            "Testergebnisse oder Teilnahmebestätigungen",
        ],
    },
    ("restrict_access", "data"): {
        "test_procedure": [
            "Prüfung der Zugriffskontrollen für {object}",
            "Review der Berechtigungen nach Need-to-Know-Prinzip",
            "Stichprobe: Keine überprivilegierten Zugänge",
        ],
        "evidence": [
            "Berechtigungsmatrix mit Rollen und Datenklassen",
            "Zugriffsprotokolle",
            "Ergebnis des letzten Access Reviews",
        ],
    },
    ("restrict_access", "system"): {
        "test_procedure": [
            "Prüfung der Zugriffskontrollen für {object}",
            "Review der Admin-/Privileged-Zugänge",
            "Stichprobe: MFA aktiv, Passwort-Policy eingehalten",
        ],
        "evidence": [
            "Berechtigungsmatrix",
            "Audit-Log privilegierter Zugriffe",
            "MFA-Konfigurationsnachweis",
        ],
    },
    ("encrypt", "data"): {
        "test_procedure": [
            "Prüfung der Verschlüsselung von {object} at Rest und in Transit",
            "Verifizierung der Algorithmen gegen BSI TR-02102",
            "Prüfung des Key-Management-Prozesses",
        ],
        "evidence": [
            "Kryptographie-Konzept mit Algorithmen und Schlüssellängen",
            "TLS-Konfigurationsnachweis",
            "Key-Rotation-Protokoll",
        ],
    },
    ("delete", "data"): {
        "test_procedure": [
            "Prüfung des Löschkonzepts für {object}",
            "Verifizierung der automatisierten Löschmechanismen",
            "Stichprobe: Löschung personenbezogener Daten nach Fristablauf",
        ],
        "evidence": [
            "Löschkonzept mit Datenklassen und Fristen",
            "Löschprotokolle",
            "Nachweis der Vernichtung (bei physischen Medien)",
        ],
    },
    ("retain", "data"): {
        "test_procedure": [
            "Prüfung der Aufbewahrungsfristen für {object}",
            "Verifizierung der Speicherorte und Zugriffskontrollen",
            "Prüfung: Keine Aufbewahrung über gesetzliche Frist hinaus",
        ],
        "evidence": [
            "Aufbewahrungsrichtlinie mit gesetzlichen Grundlagen",
            "Speicherort-Inventar mit Zugriffskonzept",
        ],
    },
    ("obtain", "data"): {
        "test_procedure": [
            "Prüfung des Einwilligungsprozesses für {object}",
            "Verifizierung der Einwilligungstexte auf Rechtskonformität",
            "Stichprobe: Einwilligung vor Verarbeitungsbeginn eingeholt",
        ],
        "evidence": [
            "Einwilligungsformulare/-dialoge",
            "Consent-Log mit Zeitstempeln",
            "Widerrufsprozess-Dokumentation",
        ],
    },
}

# Fallback when neither a specific nor an action-level template matches.
_DEFAULT_ACTION_TEMPLATE: dict[str, list[str]] = {
    "test_procedure": [
        "Prüfung der Umsetzung von {object}",
        "Verifizierung der zugehörigen Dokumentation und Nachweisführung",
    ],
    "evidence": [
        "Umsetzungsnachweis",
        "Zugehörige Dokumentation",
    ],
}


# ── 6.
# Title Suffix (action_type → past participle / state) ──────────────────

# German past-participle phrase appended to generated control titles,
# keyed by action type.
_ACTION_STATE_SUFFIX: dict[str, str] = {
    "define": "definiert und freigegeben",
    "document": "dokumentiert",
    "maintain": "aktuell gehalten",
    "implement": "umgesetzt",
    "configure": "konfiguriert",
    "monitor": "überwacht",
    "review": "überprüft",
    "assess": "bewertet",
    "audit": "auditiert",
    "test": "getestet",
    "verify": "verifiziert",
    "validate": "validiert",
    "report": "gemeldet",
    "notify": "benachrichtigt",
    "train": "geschult",
    "restrict_access": "zugriffsbeschränkt",
    "encrypt": "verschlüsselt",
    "delete": "gelöscht",
    "retain": "aufbewahrt",
    "ensure": "sichergestellt",
    "approve": "genehmigt",
    "remediate": "behoben",
    "perform": "durchgeführt",
    "obtain": "eingeholt",
}


# ── 6b. Pattern Candidates ──────────────────────────────────────────────

# Pattern-ID candidates per (action_type, object_class) combination;
# more specific than the by-action fallback below.
_PATTERN_CANDIDATES_MAP: dict[tuple[str, str], list[str]] = {
    ("define", "policy"): ["policy_documented", "policy_approved"],
    ("document", "policy"): ["policy_documented"],
    ("implement", "technical_control"): ["technical_safeguard_enabled", "security_control_tested"],
    ("implement", "policy"): ["policy_implemented", "policy_communicated"],
    ("implement", "process"): ["process_established", "process_operational"],
    ("monitor", "system"): ["continuous_monitoring_active"],
    ("monitor", "incident"): ["incident_detection_active"],
    ("review", "policy"): ["policy_review_completed"],
    ("review", "risk_artifact"): ["risk_review_completed"],
    ("assess", "risk_artifact"): ["risk_assessment_completed"],
    ("restrict_access", "data"): ["access_control_enforced"],
    ("restrict_access", "system"): ["access_control_enforced", "privilege_management_active"],
    ("restrict_access", "access_control"): ["access_control_enforced"],
    ("encrypt", "data"): ["encryption_at_rest", "encryption_in_transit"],
    ("encrypt", "cryptographic_control"): ["encryption_at_rest", "key_management_active"],
    ("train", "role"): ["awareness_training_completed"],
    ("train", "training"): ["awareness_training_completed"],
    ("report", "incident"): ["incident_reported_timely"],
    ("notify", "incident"): ["notification_sent_timely"],
    ("delete", "data"): ["data_deletion_completed"],
    ("retain", "data"): ["data_retention_enforced"],
    ("audit", "system"): ["audit_completed"],
    ("test", "technical_control"): ["security_control_tested"],
    ("obtain", "consent"): ["consent_obtained"],
    ("obtain", "data"): ["consent_obtained"],
    ("approve", "policy"): ["policy_approved"],
}

# Fallback pattern-ID candidates keyed by action type alone.
_PATTERN_CANDIDATES_BY_ACTION: dict[str, list[str]] = {
    "define": ["policy_documented"],
    "document": ["policy_documented"],
    "implement": ["control_implemented"],
    "monitor": ["continuous_monitoring_active"],
    "review": ["review_completed"],
    "assess": ["assessment_completed"],
    "audit": ["audit_completed"],
    "test": ["security_control_tested"],
    "report": ["incident_reported_timely"],
    "notify": ["notification_sent_timely"],
    "train": ["awareness_training_completed"],
    "restrict_access": ["access_control_enforced"],
    "encrypt": ["encryption_at_rest"],
    "delete": ["data_deletion_completed"],
    "retain": ["data_retention_enforced"],
    "ensure": ["control_implemented"],
    "approve": ["policy_approved"],
    "remediate": ["remediation_completed"],
    "perform": ["activity_performed"],
    "obtain": ["consent_obtained"],
    "configure": ["technical_safeguard_enabled"],
    "verify": ["verification_completed"],
    "validate": ["validation_completed"],
    "maintain": ["control_maintained"],
}


# ── 6c.
# Raw Infinitives (for validator Negativregeln) ──────────────────────────

# German verb infinitives that indicate a leaked action verb when they
# appear in generated test procedures (checked by _validate_atomic_control).
_RAW_INFINITIVES: set[str] = {
    "implementieren", "dokumentieren", "definieren", "konfigurieren",
    "überwachen", "überprüfen", "auditieren", "testen", "verifizieren",
    "validieren", "melden", "benachrichtigen", "schulen", "verschlüsseln",
    "löschen", "aufbewahren", "sicherstellen", "gewährleisten",
    "genehmigen", "beheben", "durchführen", "einholen", "erstellen",
    "festlegen", "bereitstellen", "installieren", "einrichten",
    "bewerten", "analysieren", "kontrollieren", "protokollieren",
}


# ── 7. Object Normalization (with synonym mapping) ──────────────────────

# German term → canonical object key. _normalize_object picks the LONGEST
# synonym contained in the lowercased object text and replaces its first
# occurrence, so multi-word entries win over their substrings.
_OBJECT_SYNONYMS: dict[str, str] = {
    "verzeichnis": "register",
    "inventar": "register",
    "katalog": "register",
    "bestandsaufnahme": "register",
    "richtlinie": "policy",
    "konzept": "policy",
    "strategie": "policy",
    "leitlinie": "policy",
    "vorgabe": "policy",
    "regelung": "policy",
    "anweisung": "policy",
    "rahmenwerk": "policy",
    "sicherheitskonzept": "policy",
    "datenschutzkonzept": "policy",
    "verfahren": "procedure",
    "ablauf": "procedure",
    "vorgehensweise": "procedure",
    "methodik": "procedure",
    "prozedur": "procedure",
    "handlungsanweisung": "procedure",
    "protokoll": "record",
    "aufzeichnung": "record",
    "nachweis": "record",
    "evidenz": "record",
    "vorfall": "incident",
    "störung": "incident",
    "sicherheitsvorfall": "incident",
    "notfall": "incident",
    "krise": "incident",
    "schwachstelle": "risk_artifact",
    "gefährdung": "risk_artifact",
    "risikoanalyse": "risk_artifact",
    "risikobewertung": "risk_artifact",
    "mitarbeiter": "role",
    "personal": "role",
    "beauftragter": "role",
    "verantwortlicher": "role",
    "schulung": "training",
    "sensibilisierung": "training",
    "unterweisung": "training",
    "verschlüsselung": "technical_control",
    "firewall": "technical_control",
    "backup": "technical_control",
    "meldung": "report",
    "bericht": "report",
    "benachrichtigung": "report",
    "berechtigung": "access_control",
    "authentifizierung": "access_control",
    "zugriff": "access_control",
    "einwilligung": "consent",
    "zustimmung": "consent",
    # Near-synonym expansions found via heavy-control analysis (2026-03-28)
    "erkennung": "detection",
    "früherkennung": "detection",
    "frühzeitige erkennung": "detection",
    "frühzeitigen erkennung": "detection",
    "detektion": "detection",
    "eskalation": "escalation",
    "eskalationsprozess": "escalation",
    "eskalationsverfahren": "escalation",
    "benachrichtigungsprozess": "notification",
    "benachrichtigungsverfahren": "notification",
    "meldeprozess": "notification",
    "meldeverfahren": "notification",
    "meldesystem": "notification",
    "benachrichtigungssystem": "notification",
    "überwachung": "monitoring",
    "monitoring": "monitoring",
    "kontinuierliche überwachung": "monitoring",
    "laufende überwachung": "monitoring",
    "prüfung": "audit",
    "überprüfung": "audit",
    "kontrolle": "control_check",
    "sicherheitskontrolle": "control_check",
    "dokumentation": "documentation",
    "aufzeichnungspflicht": "documentation",
    "protokollierung": "logging",
    "logführung": "logging",
    "logmanagement": "logging",
    "wiederherstellung": "recovery",
    "notfallwiederherstellung": "recovery",
    "disaster recovery": "recovery",
    "notfallplan": "contingency_plan",
    "notfallplanung": "contingency_plan",
    "wiederanlaufplan": "contingency_plan",
    "klassifizierung": "classification",
    "kategorisierung": "classification",
    "einstufung": "classification",
    "segmentierung": "segmentation",
    "netzwerksegmentierung": "segmentation",
    "netzwerk-segmentierung": "segmentation",
    "trennung": "segmentation",
    "isolierung": "isolation",
    "patch": "patch_mgmt",
    "patchmanagement": "patch_mgmt",
    "patch-management": "patch_mgmt",
    "aktualisierung": "patch_mgmt",
    "softwareaktualisierung": "patch_mgmt",
    "härtung": "hardening",
    "systemhärtung": "hardening",
    "härtungsmaßnahme": "hardening",
    "löschung": "deletion",
    "datenlöschung": "deletion",
    "löschkonzept": "deletion",
    "anonymisierung": "anonymization",
    "pseudonymisierung": "pseudonymization",
    "zugangssteuerung": "access_control",
    "zugangskontrolle": "access_control",
    "zugriffssteuerung": "access_control",
    "zugriffskontrolle": "access_control",
    "schlüsselmanagement": "key_mgmt",
    "schlüsselverwaltung": "key_mgmt",
    "key management": "key_mgmt",
    "zertifikatsverwaltung": "cert_mgmt",
    "zertifikatsmanagement": "cert_mgmt",
    "lieferant": "vendor",
    "dienstleister": "vendor",
    "auftragsverarbeiter": "vendor",
    "drittanbieter": "vendor",
    # Session management synonyms (2026-03-28)
    "sitzung": "session",
    "sitzungsverwaltung": "session_mgmt",
    "session management": "session_mgmt",
    "session-id": "session_token",
    "sitzungstoken": "session_token",
    "session-token": "session_token",
    "idle timeout": "session_timeout",
    "inaktivitäts-timeout": "session_timeout",
    "inaktivitätszeitraum": "session_timeout",
    "abmeldung": "logout",
    "cookie-attribut": "cookie_security",
    "secure-flag": "cookie_security",
    "httponly": "cookie_security",
    "samesite": "cookie_security",
    "json web token": "jwt",
    "bearer token": "jwt",
    "föderierte assertion": "federated_assertion",
    "saml assertion": "federated_assertion",
}


def _truncate_title(title: str, max_len: int = 80) -> str:
    """Truncate title at word boundary to avoid mid-word cuts.

    Returns the title unchanged when it fits within max_len.
    The boundary cut is only taken when it preserves at least half of
    max_len; otherwise a hard cut at max_len is returned.
    """
    if len(title) <= max_len:
        return title
    truncated = title[:max_len]
    # Cut at last space to avoid mid-word truncation
    last_space = truncated.rfind(" ")
    if last_space > max_len // 2:
        return truncated[:last_space]
    return truncated


def _normalize_object(object_raw: str) -> str:
    """Normalize object text to a snake_case key for merge hints.

    Applies synonym mapping to collapse German terms to canonical forms
    (e.g., 'Richtlinie' -> 'policy', 'Verzeichnis' -> 'register').
    Then strips qualifying prepositional phrases that would create
    near-duplicate keys (e.g., 'bei Schwellenwertüberschreitung').
    Truncates to 40 chars to collapse overly specific variants.
    """
    # Empty/None-ish input gets a stable sentinel key.
    if not object_raw:
        return "unknown"

    obj_lower = object_raw.strip().lower()

    # Strip qualifying prepositional phrases that don't change core identity.
    # These create near-duplicate keys like "eskalationsprozess" vs
    # "eskalationsprozess bei schwellenwertüberschreitung".
    obj_lower = _QUALIFYING_PHRASE_RE.sub("", obj_lower).strip()

    # Synonym mapping — find the longest matching synonym
    best_match = ""
    best_canonical = ""
    for synonym, canonical in _OBJECT_SYNONYMS.items():
        if synonym in obj_lower and len(synonym) > len(best_match):
            best_match = synonym
            best_canonical = canonical

    # Replace only the first occurrence; later occurrences stay verbatim.
    if best_canonical:
        obj_lower = obj_lower.replace(best_match, best_canonical, 1)

    # Snake-case, transliterate umlauts, then drop any remaining
    # non [a-z0-9_] characters. NOTE: order matters — transliteration
    # happens after the synonym replacement above.
    obj = re.sub(r"\s+", "_", obj_lower.strip())
    for src, dst in [("ä", "ae"), ("ö", "oe"), ("ü", "ue"), ("ß", "ss")]:
        obj = obj.replace(src, dst)
    obj = re.sub(r"[^a-z0-9_]", "", obj)

    # Strip trailing noise tokens (articles/prepositions stuck at the end)
    obj = re.sub(r"(_(?:der|die|das|des|dem|den|fuer|bei|von|zur|zum|mit|auf|in|und|oder|aus|an|ueber|nach|gegen|unter|vor|zwischen|als|durch|ohne|wie))+$", "", obj)

    # Truncate at 40 chars (at underscore boundary) to collapse
    # overly specific suffixes that create near-duplicate keys.
    obj = _truncate_at_boundary(obj, 40)

    return obj or "unknown"


# Regex to strip German qualifying prepositional phrases from object text.
# Matches patterns like "bei schwellenwertüberschreitung",
# "für kritische systeme", "gemäß artikel 32" etc.
+_QUALIFYING_PHRASE_RE = re.compile( + r"\s+(?:" + r"bei\s+\w+" + r"|für\s+(?:die\s+|den\s+|das\s+|kritische\s+)?\w+" + r"|gemäß\s+\w+" + r"|nach\s+\w+" + r"|von\s+\w+" + r"|im\s+(?:falle?\s+|rahmen\s+)?\w+" + r"|mit\s+(?:den\s+|der\s+|dem\s+)?\w+" + r"|auf\s+(?:basis|grundlage)\s+\w+" + r"|zur\s+(?:einhaltung|sicherstellung|gewährleistung|vermeidung|erfüllung)\s*\w*" + r"|durch\s+(?:den\s+|die\s+|das\s+)?\w+" + r"|über\s+(?:den\s+|die\s+|das\s+)?\w+" + r"|unter\s+\w+" + r"|zwischen\s+\w+" + r"|innerhalb\s+\w+" + r"|gegenüber\s+\w+" + r"|hinsichtlich\s+\w+" + r"|bezüglich\s+\w+" + r"|einschließlich\s+\w+" + r").*$", + re.IGNORECASE, +) + + +def _truncate_at_boundary(text: str, max_len: int) -> str: + """Truncate text at the last underscore boundary within max_len.""" + if len(text) <= max_len: + return text + truncated = text[:max_len] + last_sep = truncated.rfind("_") + if last_sep > max_len // 2: + return truncated[:last_sep] + return truncated + + +# ── 7b. Framework / Composite Detection ────────────────────────────────── + +_FRAMEWORK_KEYWORDS: list[str] = [ + "praktiken", "kontrollen gemäß", "maßnahmen gemäß", "anforderungen aus", + "anforderungen gemäß", "gemäß .+ umzusetzen", "framework", "standard", + "controls for", "practices for", "requirements from", +] + +_COMPOSITE_OBJECT_KEYWORDS: list[str] = [ + "ccm", "nist", "iso 27001", "iso 27002", "owasp", "bsi", + "cis controls", "cobit", "sox", "pci dss", "hitrust", + "soc 2", "soc2", "enisa", "kritis", +] + +# Container objects that are too broad for atomic controls. +# These produce titles like "Sichere Sitzungsverwaltung umgesetzt" which +# are not auditable — they encompass multiple sub-requirements. 
# Overly broad "container" objects; matching objects are flagged as
# non-atomic by _is_container_object.
_CONTAINER_OBJECT_KEYWORDS: list[str] = [
    "sitzungsverwaltung", "session management", "session-management",
    "token-schutz", "tokenschutz",
    "authentifizierungsmechanismen", "authentifizierungsmechanismus",
    "sicherheitsmaßnahmen", "sicherheitsmassnahmen",
    "schutzmaßnahmen", "schutzmassnahmen",
    "zugriffskontrollmechanismen",
    "sicherheitsarchitektur",
    "sicherheitskontrollen",
    "datenschutzmaßnahmen", "datenschutzmassnahmen",
    "compliance-anforderungen",
    "risikomanagementprozess",
]

# Keyword lists are joined unescaped, so entries may contain regex syntax
# (e.g. "gemäß .+ umzusetzen" in _FRAMEWORK_KEYWORDS).
_COMPOSITE_RE = re.compile(
    "|".join(_FRAMEWORK_KEYWORDS + _COMPOSITE_OBJECT_KEYWORDS),
    re.IGNORECASE,
)

_CONTAINER_RE = re.compile(
    "|".join(_CONTAINER_OBJECT_KEYWORDS),
    re.IGNORECASE,
)


def _is_composite_obligation(obligation_text: str, object_: str) -> bool:
    """Detect framework-level / composite obligations that are NOT atomic.

    Returns True if the obligation references a framework domain, standard,
    or set of practices rather than a single auditable requirement.
    """
    # Text and object are searched together; a hit in either suffices.
    combined = f"{obligation_text} {object_}"
    return bool(_COMPOSITE_RE.search(combined))


def _is_container_object(object_: str) -> bool:
    """Detect overly broad container objects that should not be atomic.

    Objects like 'Sitzungsverwaltung' or 'Token-Schutz' encompass multiple
    sub-requirements and produce non-auditable controls.
    """
    if not object_:
        return False
    return bool(_CONTAINER_RE.search(object_))


# ── 7c. Output Validator (Negativregeln) ─────────────────────────────────

def _validate_atomic_control(
    atomic: "AtomicControlCandidate",
    action_type: str,
    object_class: str,
) -> list[str]:
    """Validate an atomic control against Pflichtfelder + Negativregeln.

    Returns a list of issue strings (ERROR: / WARN:).
    Logs warnings but never rejects the control.
    """
    issues: list[str] = []

    # ── Pflichtfelder ──────────────────────────────────────
    if not atomic.title.strip():
        issues.append("ERROR: title is empty")
    if not atomic.objective.strip():
        issues.append("ERROR: objective is empty")
    if not atomic.test_procedure:
        issues.append("ERROR: test_procedure is empty")
    if not atomic.evidence:
        issues.append("ERROR: evidence is empty")

    # ── Negativregeln ──────────────────────────────────────
    if len(atomic.title) > 80:
        issues.append(f"ERROR: title exceeds 80 chars ({len(atomic.title)})")

    # Detect garbage pattern: "Prüfung der {raw_infinitive}" (leaked action)
    # The inner break reports at most one leaked infinitive per procedure.
    for i, tp in enumerate(atomic.test_procedure):
        for inf in _RAW_INFINITIVES:
            if re.search(
                rf"\b(?:der|des|die)\s+{re.escape(inf)}\b", tp, re.IGNORECASE,
            ):
                issues.append(
                    f"ERROR: test_procedure[{i}] contains raw infinitive '{inf}'"
                )
                break

    for i, ev in enumerate(atomic.evidence):
        if not ev.strip():
            issues.append(f"ERROR: evidence[{i}] is empty string")

    # ── Warnregeln ─────────────────────────────────────────
    # Optional private attributes are read defensively via getattr.
    confidence = getattr(atomic, "_decomposition_confidence", None)
    if confidence is not None and confidence < 0.5:
        issues.append(f"WARN: low confidence ({confidence})")

    if object_class == "general":
        issues.append("WARN: object_class is 'general' (unclassified)")

    if getattr(atomic, "_is_composite", False):
        issues.append("WARN: composite/framework obligation — requires further decomposition")

    # ERRORs are logged at warning level, WARNs only at debug level.
    for issue in issues:
        if issue.startswith("ERROR:"):
            logger.warning("Validation: %s — title=%s", issue, atomic.title[:60])
        else:
            logger.debug("Validation: %s — title=%s", issue, atomic.title[:60])

    return issues


# ── 8.
# Confidence Scoring ────────────────────────────────────────────────────

def _score_pass0b_confidence(
    action_type: str,
    object_class: str,
    trigger_q: str,
    has_specific_template: bool,
) -> float:
    """Score decomposition confidence for a Pass 0b candidate.

    Additive scheme: base 0.3, plus fixed weights for each resolved
    signal (specific action type, specific object class, trigger
    qualifier, specific template).  Capped at 1.0, rounded to 2 digits.
    """
    score = 0.3  # base
    if action_type != "default":
        score += 0.25
    if object_class != "general":
        score += 0.20
    if trigger_q:
        score += 0.10
    if has_specific_template:
        score += 0.15
    return round(min(score, 1.0), 2)


# ── 9. Compose Function ─────────────────────────────────────────────────


def _compose_deterministic(
    obligation_text: str,
    action: str,
    object_: str,
    parent_title: str,
    parent_severity: str,
    parent_category: str,
    is_test: bool,
    is_reporting: bool,
    trigger_type: Optional[str] = None,
    condition: Optional[str] = None,
) -> "AtomicControlCandidate":
    """Compose an atomic control deterministically from obligation data.

    No LLM required. Uses action-type classification, object-class
    matching, and trigger-aware templates. Generates:
    - Title as '{Object} {state suffix}'
    - Statement as '{condition_prefix} {object} ist {trigger} {action}'
    - Evidence/test bundles from (action_type, object_class) matrix
    - Pattern candidates for downstream categorization
    - Merge hint for downstream dedup
    - Structured timing (deadline_hours, frequency)
    - Confidence score
    - Validation issues (negative rules)
    """
    # Override action type for flagged obligations (test/reporting flags
    # win over the verb classifier).
    if is_test:
        action_type = "test"
    elif is_reporting:
        action_type = "report"
    else:
        action_type = _classify_action(action)

    object_class = _classify_object(object_)

    # Template lookup: specific combo → action base → default
    has_specific = (action_type, object_class) in _SPECIFIC_TEMPLATES
    template = (
        _SPECIFIC_TEMPLATES.get((action_type, object_class))
        or _ACTION_TEMPLATES.get(action_type)
        or _DEFAULT_ACTION_TEMPLATE
    )

    # Object for template substitution (fallback to parent title)
    obj_display = object_.strip() if object_ else parent_title

    # ── Title: "{Object} {state}" ─────────────────────────────
    state = _ACTION_STATE_SUFFIX.get(action_type, "umgesetzt")
    if object_:
        title = _truncate_title(f"{object_.strip()} {state}")
    elif action:
        title = _truncate_title(f"{action.strip().capitalize()} {state}")
    else:
        title = _truncate_title(f"{parent_title} {state}")

    # ── Objective = obligation text (the normative statement) ─
    objective = obligation_text.strip()[:2000]

    # ── Requirements = obligation as concrete requirement ─────
    requirements = [obligation_text.strip()] if obligation_text else []

    # ── Test procedure from templates with object substitution
    test_procedure = [
        tp.replace("{object}", obj_display)
        for tp in template["test_procedure"]
    ]

    # ── Trigger qualifier → add timing test step ──────────────
    trigger_q = _extract_trigger_qualifier(trigger_type, obligation_text)
    if trigger_q and test_procedure:
        test_procedure.append(
            f"Prüfung der Frist-/Trigger-Einhaltung: {trigger_q}"
        )

    # ── Evidence from templates ───────────────────────────────
    evidence = list(template["evidence"])

    # ── Merge hint for downstream dedup ───────────────────────
    # Key shape: "{action_type}:{normalized object}:{trigger or 'none'}".
    norm_obj = _normalize_object(object_)
    trigger_key = trigger_type or "none"
    merge_hint = f"{action_type}:{norm_obj}:{trigger_key}"

    # ── Statement: structured normative sentence ──────────────
    condition_prefix = ""
    if condition and condition.strip():
        condition_prefix = condition.strip().rstrip(",") + ","
    trigger_clause = trigger_q if trigger_q else ""
    obj_for_stmt = object_.strip() if object_ else parent_title
    if obj_for_stmt:
        # Empty parts are dropped before joining.
        parts = [p for p in [condition_prefix, obj_for_stmt, "ist", trigger_clause, state] if p]
        statement = " ".join(parts)
    else:
        statement = ""

    # ── Pattern candidates ────────────────────────────────────
    pattern_candidates = _PATTERN_CANDIDATES_MAP.get(
        (action_type, object_class),
        _PATTERN_CANDIDATES_BY_ACTION.get(action_type, []),
    )

    # ── Structured timing ─────────────────────────────────────
    deadline_hours, frequency = _extract_structured_timing(obligation_text)

    # ── Confidence score ──────────────────────────────────────
    confidence = _score_pass0b_confidence(
        action_type, object_class, trigger_q, has_specific,
    )

    atomic = AtomicControlCandidate(
        title=title,
        objective=objective,
        requirements=requirements,
        test_procedure=test_procedure,
        evidence=evidence,
        severity=_calibrate_severity(parent_severity, action_type),
        category=parent_category or "governance",
    )
    # Attach extra metadata (stored in generation_metadata); the private
    # underscore attributes are set dynamically, hence the type: ignore.
    atomic.domain = f"{action_type}:{object_class}"
    atomic.source_regulation = merge_hint
    atomic._decomposition_confidence = confidence  # type: ignore[attr-defined]
    atomic._statement = statement  # type: ignore[attr-defined]
    atomic._pattern_candidates = list(pattern_candidates)  # type: ignore[attr-defined]
    atomic._deadline_hours = deadline_hours  # type: ignore[attr-defined]
    atomic._frequency = frequency  # type: ignore[attr-defined]

    # ── Composite / Framework / Container detection ────────────
    is_composite = _is_composite_obligation(obligation_text, object_)
    is_container = _is_container_object(object_)
    atomic._is_composite = is_composite or is_container  # type: ignore[attr-defined]
    # Composite (framework-level) takes precedence over container.
    if is_composite:
        atomic._atomicity = "composite"  # type: ignore[attr-defined]
    elif is_container:
        atomic._atomicity = "container"  # type: ignore[attr-defined]
    else:
        atomic._atomicity = "atomic"  # type: ignore[attr-defined]
    atomic._requires_decomposition = is_composite or is_container  # type: ignore[attr-defined]

    # ── Validate (log issues, never reject) ───────────────────
    validation_issues = _validate_atomic_control(atomic, action_type, object_class)
    atomic._validation_issues = validation_issues  # type: ignore[attr-defined]

    return atomic


def _build_pass0b_prompt(
    obligation_text: str, action: str, object_: str,
    parent_title: str, parent_category: str, source_ref: str,
) -> str:
    """Build the single-obligation Pass 0b prompt (German, JSON answer)."""
    return f"""\
Erstelle aus der folgenden Pflicht ein atomares Control.

PFLICHT: {obligation_text}
HANDLUNG: {action}
GEGENSTAND: {object_}

KONTEXT (Ursprungs-Control):
Titel: {parent_title}
Kategorie: {parent_category}
Quellreferenz: {source_ref}

Antworte als JSON:
{{
  "title": "Kurzer Titel (max 80 Zeichen, deutsch)",
  "objective": "Was muss erreicht werden? 
(1-2 Sätze)",
  "requirements": ["Konkrete Anforderung 1", "Anforderung 2"],
  "test_procedure": ["Prüfschritt 1", "Prüfschritt 2"],
  "evidence": ["Nachweis 1", "Nachweis 2"],
  "severity": "critical|high|medium|low",
  "category": "security|privacy|governance|operations|finance|reporting"
}}"""


# ---------------------------------------------------------------------------
# Batch Prompts (multiple controls/obligations per API call)
# ---------------------------------------------------------------------------


def _build_pass0a_batch_prompt(controls: list[dict]) -> str:
    """Build a prompt for extracting obligations from multiple controls.

    Each control dict needs: control_id, title, objective, requirements,
    test_procedure, source_ref.  The model is asked to answer with one
    JSON key per control ID, each holding an array of obligations.
    """
    parts = []
    for i, ctrl in enumerate(controls, 1):
        parts.append(
            f"--- CONTROL {i} (ID: {ctrl['control_id']}) ---\n"
            f"Titel: {ctrl['title']}\n"
            f"Ziel: {ctrl['objective']}\n"
            f"Anforderungen: {ctrl['requirements']}\n"
            f"Prüfverfahren: {ctrl['test_procedure']}\n"
            f"Quellreferenz: {ctrl['source_ref']}"
        )

    controls_text = "\n\n".join(parts)
    # Answer-shape example shows at most the first two IDs.
    ids_example = ", ".join(f'"{c["control_id"]}": [...]' for c in controls[:2])

    return f"""\
Analysiere die folgenden {len(controls)} Controls und extrahiere aus JEDEM \
alle einzelnen normativen Pflichten.

{controls_text}

Antworte als JSON-Objekt. Fuer JEDES Control ein Key (die Control-ID) mit \
einem Array von Pflichten:
{{
  {ids_example}
}}

Jede Pflicht hat dieses Format:
{{
  "obligation_text": "Kurze, präzise Formulierung der Pflicht",
  "action": "Hauptverb/Handlung",
  "object": "Gegenstand der Pflicht",
  "condition": null,
  "normative_strength": "must",
  "is_test_obligation": false,
  "is_reporting_obligation": false
}}"""


def _build_pass0b_batch_prompt(obligations: list[dict]) -> str:
    """Build a prompt for composing multiple atomic controls.

    Each obligation dict needs: candidate_id, obligation_text, action,
    object, parent_title, parent_category, source_ref.  The model answers
    with one JSON key per candidate_id.
    """
    parts = []
    for i, obl in enumerate(obligations, 1):
        parts.append(
            f"--- PFLICHT {i} (ID: {obl['candidate_id']}) ---\n"
            f"PFLICHT: {obl['obligation_text']}\n"
            f"HANDLUNG: {obl['action']}\n"
            f"GEGENSTAND: {obl['object']}\n"
            f"KONTEXT: {obl['parent_title']} | {obl['parent_category']}\n"
            f"Quellreferenz: {obl['source_ref']}"
        )

    obligations_text = "\n\n".join(parts)
    # Answer-shape example shows at most the first two IDs.
    ids_example = ", ".join(f'"{o["candidate_id"]}": {{...}}' for o in obligations[:2])

    return f"""\
Erstelle aus den folgenden {len(obligations)} Pflichten je ein atomares Control.

{obligations_text}

Antworte als JSON-Objekt. Fuer JEDE Pflicht ein Key (die Pflicht-ID):
{{
  {ids_example}
}}

Jedes Control hat dieses Format:
{{
  "title": "Kurzer Titel (max 80 Zeichen, deutsch)",
  "objective": "Was muss erreicht werden? (1-2 Sätze)",
  "requirements": ["Konkrete Anforderung 1", "Anforderung 2"],
  "test_procedure": ["Prüfschritt 1", "Prüfschritt 2"],
  "evidence": ["Nachweis 1", "Nachweis 2"],
  "severity": "critical|high|medium|low",
  "category": "security|privacy|governance|operations|finance|reporting"
}}"""


# ---------------------------------------------------------------------------
# Anthropic API (with prompt caching)
# ---------------------------------------------------------------------------


async def _llm_anthropic(
    prompt: str,
    system_prompt: str,
    max_tokens: int = 8192,
) -> str:
    """Call Anthropic Messages API with prompt caching for the system prompt.

    Returns the first content block's text, or "" on any failure
    (non-200 response, transport error, unexpected payload shape).

    Raises:
        RuntimeError: if ANTHROPIC_API_KEY is not configured.
    """
    if not ANTHROPIC_API_KEY:
        raise RuntimeError("ANTHROPIC_API_KEY not set")

    headers = {
        "x-api-key": ANTHROPIC_API_KEY,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": ANTHROPIC_MODEL,
        "max_tokens": max_tokens,
        # System prompt is sent as a cacheable block so repeated calls
        # with the same system prompt hit the prompt cache.
        "system": [
            {
                "type": "text",
                "text": system_prompt,
                "cache_control": 
{"type": "ephemeral"}, + } + ], + "messages": [{"role": "user", "content": prompt}], + } + + try: + async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: + resp = await client.post( + f"{ANTHROPIC_API_URL}/messages", + headers=headers, + json=payload, + ) + if resp.status_code != 200: + logger.error( + "Anthropic API %d: %s", resp.status_code, resp.text[:300] + ) + return "" + data = resp.json() + # Log cache performance + usage = data.get("usage", {}) + cached = usage.get("cache_read_input_tokens", 0) + if cached > 0: + logger.debug( + "Prompt cache hit: %d cached tokens", cached + ) + content = data.get("content", []) + if content and isinstance(content, list): + return content[0].get("text", "") + return "" + except Exception as e: + logger.error("Anthropic request failed: %s", e) + return "" + + +# --------------------------------------------------------------------------- +# Anthropic Batch API (50% cost reduction, async processing) +# --------------------------------------------------------------------------- + + +async def create_anthropic_batch( + requests: list[dict], +) -> dict: + """Submit a batch of requests to Anthropic Batch API. + + Each request: {"custom_id": "...", "params": {model, max_tokens, system, messages}} + Returns batch metadata including batch_id. 
+ """ + if not ANTHROPIC_API_KEY: + raise RuntimeError("ANTHROPIC_API_KEY not set") + + headers = { + "x-api-key": ANTHROPIC_API_KEY, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } + + async with httpx.AsyncClient(timeout=60) as client: + resp = await client.post( + f"{ANTHROPIC_API_URL}/messages/batches", + headers=headers, + json={"requests": requests}, + ) + if resp.status_code not in (200, 201): + raise RuntimeError( + f"Batch API failed {resp.status_code}: {resp.text[:500]}" + ) + return resp.json() + + +async def check_batch_status(batch_id: str) -> dict: + """Check the processing status of a batch.""" + headers = { + "x-api-key": ANTHROPIC_API_KEY, + "anthropic-version": "2023-06-01", + } + + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.get( + f"{ANTHROPIC_API_URL}/messages/batches/{batch_id}", + headers=headers, + ) + resp.raise_for_status() + return resp.json() + + +async def fetch_batch_results(batch_id: str) -> list[dict]: + """Fetch results of a completed batch. 
Returns list of result objects.""" + headers = { + "x-api-key": ANTHROPIC_API_KEY, + "anthropic-version": "2023-06-01", + } + + async with httpx.AsyncClient(timeout=120) as client: + resp = await client.get( + f"{ANTHROPIC_API_URL}/messages/batches/{batch_id}/results", + headers=headers, + ) + resp.raise_for_status() + # Response is JSONL (one JSON object per line) + results = [] + for line in resp.text.strip().split("\n"): + if line.strip(): + results.append(json.loads(line)) + return results + + +# --------------------------------------------------------------------------- +# Parse helpers +# --------------------------------------------------------------------------- + + +def _parse_json_array(text: str) -> list[dict]: + """Extract a JSON array from LLM response text.""" + # Try direct parse + try: + result = json.loads(text) + if isinstance(result, list): + return result + if isinstance(result, dict): + return [result] + except json.JSONDecodeError: + pass + + # Try extracting JSON array block + match = re.search(r"\[[\s\S]*\]", text) + if match: + try: + result = json.loads(match.group()) + if isinstance(result, list): + return result + except json.JSONDecodeError: + pass + + return [] + + +def _parse_json_object(text: str) -> dict: + """Extract a JSON object from LLM response text.""" + try: + result = json.loads(text) + if isinstance(result, dict): + return result + except json.JSONDecodeError: + pass + + match = re.search(r"\{[\s\S]*\}", text) + if match: + try: + result = json.loads(match.group()) + if isinstance(result, dict): + return result + except json.JSONDecodeError: + pass + + return {} + + +def _ensure_list(val) -> list: + """Ensure value is a list.""" + if isinstance(val, list): + return val + if isinstance(val, str): + return [val] if val else [] + return [] + + +# --------------------------------------------------------------------------- +# Decomposition Pass +# --------------------------------------------------------------------------- + + 
class DecompositionPass:
    """Pass 0: Decompose Rich Controls into atomic candidates.

    Usage::

        decomp = DecompositionPass(db=session)
        stats_0a = await decomp.run_pass0a(limit=100)
        stats_0b = await decomp.run_pass0b(limit=100)
    """

    def __init__(self, db: Session, dedup_enabled: bool = False):
        self.db = db
        # Lazily constructed dedup checker; stays None unless both the
        # caller opts in AND the control_dedup module's feature flag is on.
        self._dedup = None
        if dedup_enabled:
            from services.control_dedup import (
                ControlDedupChecker, DEDUP_ENABLED,
            )
            if DEDUP_ENABLED:
                self._dedup = ControlDedupChecker(db)

    # -------------------------------------------------------------------
    # Pass 0a: Obligation Extraction
    # -------------------------------------------------------------------

    async def run_pass0a(
        self,
        limit: int = 0,
        batch_size: int = 0,
        use_anthropic: bool = False,
        category_filter: Optional[str] = None,
        source_filter: Optional[str] = None,
    ) -> dict:
        """Extract obligation candidates from rich controls.

        Args:
            limit: Max controls to process (0 = no limit).
            batch_size: Controls per LLM call (0 = use DECOMPOSITION_BATCH_SIZE
                env var, or 1 for single mode). Only >1 with Anthropic.
            use_anthropic: Use Anthropic API (True) or Ollama (False).
            category_filter: Only process controls matching this category
                (comma-separated, e.g. "security,privacy").
            source_filter: Only process controls from these source regulations
                (comma-separated, e.g. "Maschinenverordnung,Cyber Resilience Act").
                Matches against source_citation->>'source' using ILIKE.

        Returns:
            Stats dict (controls processed/skipped, obligations extracted/
            validated/rejected, llm_calls, errors, provider, batch_size).
        """
        if batch_size <= 0:
            batch_size = DECOMPOSITION_BATCH_SIZE if use_anthropic else 1

        # Find rich controls not yet decomposed (no obligation_candidates
        # row pointing back at them), excluding deprecated and child controls.
        query = """
            SELECT cc.id, cc.control_id, cc.title, cc.objective,
                   cc.requirements, cc.test_procedure,
                   cc.source_citation, cc.category
            FROM canonical_controls cc
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.parent_control_uuid IS NULL
              AND NOT EXISTS (
                  SELECT 1 FROM obligation_candidates oc
                  WHERE oc.parent_control_uuid = cc.id
              )
        """
        params = {}
        if category_filter:
            cats = [c.strip() for c in category_filter.split(",") if c.strip()]
            if cats:
                # NOTE(review): "IN :cats" with a plain tuple param on a
                # text() query usually requires bindparam(expanding=True)
                # in SQLAlchemy — verify this executes against the target
                # driver.
                query += " AND cc.category IN :cats"
                params["cats"] = tuple(cats)

        if source_filter:
            sources = [s.strip() for s in source_filter.split(",") if s.strip()]
            if sources:
                # One ILIKE clause per source, OR-ed together.
                clauses = []
                for idx, src in enumerate(sources):
                    key = f"src_{idx}"
                    clauses.append(f"cc.source_citation::text ILIKE :{key}")
                    params[key] = f"%{src}%"
                query += " AND (" + " OR ".join(clauses) + ")"

        query += " ORDER BY cc.created_at"
        if limit > 0:
            # f-string interpolation is acceptable here: limit is an int
            # parameter, not user-supplied text.
            query += f" LIMIT {limit}"

        rows = self.db.execute(text(query), params).fetchall()

        stats = {
            "controls_processed": 0,
            "obligations_extracted": 0,
            "obligations_validated": 0,
            "obligations_rejected": 0,
            "controls_skipped_empty": 0,
            "llm_calls": 0,
            "errors": 0,
            # NOTE(review): the non-Anthropic path below calls Ollama, so
            # the "deterministic" label looks stale — confirm.
            "provider": "anthropic" if use_anthropic else "deterministic",
            "batch_size": batch_size,
        }

        # Prepare control data
        prepared = []
        for row in rows:
            title = row[2] or ""
            objective = row[3] or ""
            req_str = _format_field(row[4] or "")
            test_str = _format_field(row[5] or "")
            source_str = _format_citation(row[6] or "")

            # Skip controls with no usable text at all.
            if not title and not objective and not req_str:
                stats["controls_skipped_empty"] += 1
                continue

            prepared.append({
                "uuid": str(row[0]),
                "control_id": row[1] or "",
                "title": title,
                "objective": objective,
                "requirements": req_str,
                "test_procedure": test_str,
                "source_ref": source_str,
                "category": row[7] or "",
            })

        # Process in batches
        for i in range(0, len(prepared), batch_size):
            batch = prepared[i : i + batch_size]
            try:
                if use_anthropic and len(batch) > 1:
                    # Batched Anthropic call: one prompt for the whole
                    # sub-batch, answer keyed by control_id.
                    prompt = _build_pass0a_batch_prompt(batch)
                    llm_response = await _llm_anthropic(
                        prompt=prompt,
                        system_prompt=_PASS0A_SYSTEM_PROMPT,
                        max_tokens=max(8192, len(batch) * 2000),
                    )
                    stats["llm_calls"] += 1
                    results_by_id = _parse_json_object(llm_response)
                    for ctrl in batch:
                        raw_obls = results_by_id.get(ctrl["control_id"], [])
                        if not isinstance(raw_obls, list):
                            raw_obls = [raw_obls] if raw_obls else []
                        if not raw_obls:
                            # Missing/empty answer → synthesize a fallback
                            # obligation so the control is not lost.
                            raw_obls = [_fallback_obligation(ctrl)]
                        self._process_pass0a_obligations(
                            raw_obls, ctrl["control_id"], ctrl["uuid"], stats
                        )
                        stats["controls_processed"] += 1
                elif use_anthropic:
                    # Single Anthropic call
                    ctrl = batch[0]
                    prompt = _build_pass0a_prompt(
                        title=ctrl["title"], objective=ctrl["objective"],
                        requirements=ctrl["requirements"],
                        test_procedure=ctrl["test_procedure"],
                        source_ref=ctrl["source_ref"],
                    )
                    llm_response = await _llm_anthropic(
                        prompt=prompt,
                        system_prompt=_PASS0A_SYSTEM_PROMPT,
                    )
                    stats["llm_calls"] += 1
                    raw_obls = _parse_json_array(llm_response)
                    if not raw_obls:
                        raw_obls = [_fallback_obligation(ctrl)]
                    self._process_pass0a_obligations(
                        raw_obls, ctrl["control_id"], ctrl["uuid"], stats
                    )
                    stats["controls_processed"] += 1
                else:
                    # Ollama (single only)
                    from services.obligation_extractor import _llm_ollama
                    ctrl = batch[0]
                    prompt = _build_pass0a_prompt(
                        title=ctrl["title"], objective=ctrl["objective"],
                        requirements=ctrl["requirements"],
                        test_procedure=ctrl["test_procedure"],
                        source_ref=ctrl["source_ref"],
                    )
                    llm_response = await _llm_ollama(
                        prompt=prompt,
                        system_prompt=_PASS0A_SYSTEM_PROMPT,
                    )
                    stats["llm_calls"] += 1
                    raw_obls = _parse_json_array(llm_response)
                    if not raw_obls:
                        raw_obls = [_fallback_obligation(ctrl)]
                    self._process_pass0a_obligations(
                        raw_obls, ctrl["control_id"], ctrl["uuid"], stats
                    )
                    stats["controls_processed"] += 1

                # Commit after each successful sub-batch to avoid losing work
                self.db.commit()

            except Exception as e:
                # A failing sub-batch is rolled back and counted, but does
                # not abort the whole run.
                ids = ", ".join(c["control_id"] for c in batch)
                logger.error("Pass 0a failed for [%s]: %s", ids, e)
                stats["errors"] += 1
                try:
                    self.db.rollback()
                except Exception:
                    pass
        logger.info("Pass 0a: %s", stats)
        return stats

    # Maps German/English strength tokens to the canonical enum; unknown
    # values fall back to "must" in _process_pass0a_obligations.
    _NORMATIVE_STRENGTH_MAP = {
        "muss": "must", "must": "must",
        "soll": "should", "should": "should",
        "kann": "may", "may": "may",
    }

    def _process_pass0a_obligations(
        self,
        raw_obligations: list[dict],
        control_id: str,
        control_uuid: str,
        stats: dict,
    ) -> None:
        """Validate and write obligation candidates from LLM output."""
        for idx, raw in enumerate(raw_obligations):
            # NOTE(review): assumes normative_strength is a string when
            # present — an explicit null from the LLM would raise here.
            raw_strength = raw.get("normative_strength", "must").lower().strip()
            normative_strength = self._NORMATIVE_STRENGTH_MAP.get(
                raw_strength, "must"
            )
            cand = ObligationCandidate(
                candidate_id=f"OC-{control_id}-{idx + 1:02d}",
                parent_control_uuid=control_uuid,
                obligation_text=raw.get("obligation_text", ""),
                action=raw.get("action", ""),
                object_=raw.get("object", ""),
                condition=raw.get("condition"),
                normative_strength=normative_strength,
                is_test_obligation=bool(raw.get("is_test_obligation", False)),
                is_reporting_obligation=bool(raw.get("is_reporting_obligation", False)),
            )

            # Auto-detect test/reporting if LLM missed it
            if not cand.is_test_obligation and _TEST_RE.search(cand.obligation_text):
                cand.is_test_obligation = True
            if not cand.is_reporting_obligation and _REPORTING_RE.search(cand.obligation_text):
                cand.is_reporting_obligation = True

            # Quality gate + obligation type classification
            flags = quality_gate(cand)
            cand.quality_flags = flags
            cand.extraction_confidence = _compute_extraction_confidence(flags)
            cand.obligation_type = flags.get("obligation_type", "empfehlung")

            if passes_quality_gate(flags):
                cand.release_state = "validated"
                stats["obligations_validated"] += 1
            else:
                cand.release_state = "rejected"
                stats["obligations_rejected"] += 1

            self._write_obligation_candidate(cand)
            stats["obligations_extracted"] += 1

    # -------------------------------------------------------------------
    # Pass 0b: Atomic Control Composition
    # -------------------------------------------------------------------

    async def run_pass0b(
        self,
        limit: int = 0,
        batch_size: int = 0,
        use_anthropic: bool = False,
    ) -> dict:
        """Compose atomic controls from validated obligation candidates.

        Args:
            limit: Max candidates to process (0 = no limit).
            batch_size: Commit interval (0 = auto). For LLM: API batch size.
            use_anthropic: Use Anthropic API (True) or deterministic engine (False).

        Returns:
            Stats dict (candidates processed, controls created, llm stats,
            dedup counters, errors).
        """
        if batch_size <= 0:
            batch_size = DECOMPOSITION_BATCH_SIZE if use_anthropic else 50

        # Select validated, un-merged candidates that do not already have
        # a live pass0b child.  The title-LIKE-on-action clause is a
        # heuristic re-run guard (first 20 chars of the action must appear
        # in the existing child's title).
        query = """
            SELECT oc.id, oc.candidate_id, oc.parent_control_uuid,
                   oc.obligation_text, oc.action, oc.object,
                   oc.condition,
                   oc.is_test_obligation, oc.is_reporting_obligation,
                   cc.title AS parent_title,
                   cc.category AS parent_category,
                   cc.source_citation AS parent_citation,
                   cc.severity AS parent_severity,
                   cc.control_id AS parent_control_id,
                   oc.trigger_type,
                   oc.is_implementation_specific
            FROM obligation_candidates oc
            JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid
            WHERE oc.release_state = 'validated'
              AND oc.merged_into_id IS NULL
              AND NOT EXISTS (
                  SELECT 1 FROM canonical_controls ac
                  WHERE ac.parent_control_uuid = oc.parent_control_uuid
                    AND ac.decomposition_method = 'pass0b'
                    AND ac.release_state NOT IN ('deprecated', 'duplicate')
                    AND ac.title LIKE '%' || LEFT(oc.action, 20) || '%'
              )
        """
        if limit > 0:
            query += f" LIMIT {limit}"

        rows = self.db.execute(text(query)).fetchall()

        stats = {
            "candidates_processed": 0,
            "controls_created": 0,
            "llm_failures": 0,
            "llm_calls": 0,
            "errors": 0,
            "provider": "anthropic" if use_anthropic else "deterministic",
            "batch_size": batch_size,
            "dedup_enabled": self._dedup is not None,
            "dedup_linked": 0,
            "dedup_review": 0,
            "skipped_merged": 0,
        }

        # Prepare obligation data (positional row → named dict)
        prepared = []
        for row in rows:
            prepared.append({
                "oc_id": str(row[0]),
                "candidate_id": row[1] or "",
                "parent_uuid": str(row[2]),
                "obligation_text": row[3] or "",
                "action": row[4] or "",
                "object": row[5] or "",
                "condition": row[6] or "",
                "is_test": row[7],
                "is_reporting": row[8],
                "parent_title": row[9] or "",
                "parent_category": row[10] or "",
                "parent_citation": row[11] or "",
                "parent_severity": row[12] or "medium",
                "parent_control_id": row[13] or "",
                "source_ref": _format_citation(row[11] or ""),
                "trigger_type": row[14] or "continuous",
                "is_implementation_specific": row[15] or False,
            })

        # Process in batches
        for i in range(0, len(prepared), batch_size):
            batch = prepared[i : i + batch_size]
            try:
                if use_anthropic and len(batch) > 1:
                    # Batched Anthropic call, answer keyed by candidate_id.
                    prompt = _build_pass0b_batch_prompt(batch)
                    llm_response = await _llm_anthropic(
                        prompt=prompt,
                        system_prompt=_PASS0B_SYSTEM_PROMPT,
                        max_tokens=min(16384, max(4096, len(batch) * 500)),
                    )
                    stats["llm_calls"] += 1
                    results_by_id = _parse_json_object(llm_response)
                    for obl in batch:
                        parsed = results_by_id.get(obl["candidate_id"], {})
                        await self._process_pass0b_control(obl, parsed, stats)
                elif use_anthropic:
                    obl = batch[0]
                    prompt = _build_pass0b_prompt(
                        obligation_text=obl["obligation_text"],
                        action=obl["action"], object_=obl["object"],
                        parent_title=obl["parent_title"],
                        parent_category=obl["parent_category"],
                        source_ref=obl["source_ref"],
                    )
                    llm_response = await _llm_anthropic(
                        prompt=prompt,
                        system_prompt=_PASS0B_SYSTEM_PROMPT,
                    )
                    stats["llm_calls"] += 1
                    parsed = _parse_json_object(llm_response)
                    await self._process_pass0b_control(obl, parsed, stats)
                else:
                    # Deterministic engine — no LLM required
                    for obl in batch:
                        await self._route_and_compose(obl, stats)

                # Commit after each successful sub-batch
                self.db.commit()

            except Exception as e:
                ids = ", ".join(o["candidate_id"] for o in batch)
                logger.error("Pass 0b failed for [%s]: %s", ids, e)
                stats["errors"] += 1
                try:
                    self.db.rollback()
                except Exception:
                    pass
        logger.info("Pass 0b: %s", stats)
        return stats

    async def _route_and_compose(
        self, obl: dict, stats: dict,
    ) -> None:
        """Route an obligation through the framework detection layer,
        then compose atomic controls.

        Routing types:
        - atomic: compose directly via _compose_deterministic
        - compound: split compound verbs, compose each
        - framework_container: decompose via framework registry,
          then compose each sub-obligation
        """
        from services.framework_decomposition import (
            classify_routing,
            decompose_framework_container,
        )

        routing = classify_routing(
            obligation_text=obl["obligation_text"],
            action_raw=obl["action"],
            object_raw=obl["object"],
            condition_raw=obl.get("condition"),
        )

        if routing.routing_type == "framework_container" and routing.framework_ref:
            # Decompose framework container into sub-obligations
            result = decompose_framework_container(
                obligation_candidate_id=obl["candidate_id"],
                parent_control_id=obl["parent_control_id"],
                obligation_text=obl["obligation_text"],
                framework_ref=routing.framework_ref,
                framework_domain=routing.framework_domain,
            )
            # Counters are created lazily — they only appear in stats when
            # the framework path is hit at least once.
            stats.setdefault("framework_decomposed", 0)
            stats.setdefault("framework_sub_obligations", 0)

            if result.release_state == "decomposed" and result.decomposed_obligations:
                stats["framework_decomposed"] += 1
                stats["framework_sub_obligations"] += len(result.decomposed_obligations)
                logger.info(
                    "Framework decomposition: %s → %s/%s → %d sub-obligations",
                    obl["candidate_id"], routing.framework_ref,
                    routing.framework_domain, len(result.decomposed_obligations),
                )
                # Compose each sub-obligation; sub-text/action/object
                # replace the originals, all other fields are inherited.
                for d_obl in result.decomposed_obligations:
                    sub_obl = {
                        **obl,
                        "obligation_text": d_obl.obligation_text,
                        "action": d_obl.action_raw,
                        "object": d_obl.object_raw,
                    }
                    sub_actions = _split_compound_action(sub_obl["action"])
                    for sub_action in sub_actions:
                        atomic = _compose_deterministic(
                            obligation_text=sub_obl["obligation_text"],
                            action=sub_action,
                            object_=sub_obl["object"],
                            parent_title=obl["parent_title"],
                            parent_severity=obl["parent_severity"],
                            parent_category=obl["parent_category"],
                            is_test=obl["is_test"],
                            is_reporting=obl["is_reporting"],
                            trigger_type=obl.get("trigger_type"),
                            condition=obl.get("condition"),
                        )
                        # Enrich gen_meta with framework info
                        atomic._framework_ref = routing.framework_ref  # type: ignore[attr-defined]
                        atomic._framework_domain = routing.framework_domain  # type: ignore[attr-defined]
                        atomic._framework_subcontrol_id = d_obl.subcontrol_id  # type: ignore[attr-defined]
                        atomic._decomposition_source = "framework_decomposition"  # type: ignore[attr-defined]
                        await self._process_pass0b_control(
                            obl, {}, stats, atomic=atomic,
                        )
                return
            else:
                # Unmatched framework — fall through to normal composition
                logger.warning(
                    "Framework decomposition unmatched: %s — %s",
                    obl["candidate_id"], result.issues,
                )

        # Atomic or compound or unmatched framework: normal composition
        sub_actions = _split_compound_action(obl["action"])
        for sub_action in sub_actions:
            atomic = _compose_deterministic(
                obligation_text=obl["obligation_text"],
                action=sub_action,
                object_=obl["object"],
                parent_title=obl["parent_title"],
                parent_severity=obl["parent_severity"],
                parent_category=obl["parent_category"],
                is_test=obl["is_test"],
                is_reporting=obl["is_reporting"],
                trigger_type=obl.get("trigger_type"),
                condition=obl.get("condition"),
            )
            await self._process_pass0b_control(
                obl, {}, stats, atomic=atomic,
            )

    async def _process_pass0b_control(
        self, obl: dict, parsed: dict, stats: dict,
        atomic: Optional[AtomicControlCandidate] = None,
    ) -> None:
        """Create atomic control from deterministic engine, LLM output, or fallback.

        If dedup is enabled, checks for duplicates before insertion:
        - LINK: adds parent link to existing control instead of creating new
        - REVIEW: queues for human review, does not create control
        - NEW: creates new control and indexes in Qdrant
        """
        if atomic is not None:
            # Deterministic engine — atomic already composed
            pass
        elif not parsed or not parsed.get("title"):
            # LLM failed → use deterministic engine as fallback
            atomic = _compose_deterministic(
                obligation_text=obl["obligation_text"],
                action=obl["action"], object_=obl["object"],
                parent_title=obl["parent_title"],
                parent_severity=obl["parent_severity"],
                parent_category=obl["parent_category"],
                is_test=obl["is_test"],
                is_reporting=obl["is_reporting"],
                condition=obl.get("condition"),
            )
            stats["llm_failures"] += 1
        else:
            # LLM output: clamp lengths, coerce list fields, normalize
            # severity, fall back to parent values where missing.
            atomic = AtomicControlCandidate(
                title=parsed.get("title", "")[:200],
                objective=parsed.get("objective", "")[:2000],
                requirements=_ensure_list(parsed.get("requirements", [])),
                test_procedure=_ensure_list(parsed.get("test_procedure", [])),
                evidence=_ensure_list(parsed.get("evidence", [])),
                severity=_normalize_severity(
                    parsed.get("severity", obl["parent_severity"])
                ),
                category=parsed.get("category", obl["parent_category"]),
            )

        atomic.parent_control_uuid = obl["parent_uuid"]
        atomic.obligation_candidate_id = obl["candidate_id"]

        # Cap severity for implementation-specific obligations
        if obl.get("is_implementation_specific") and atomic.severity in (
            "critical", "high"
        ):
            atomic.severity = "medium"

        # Override category for test obligations
        if obl.get("is_test"):
            atomic.category = "testing"

        # ── Dedup check (if enabled) ──────────────────────────
        if self._dedup:
            pattern_id = None
            # Try to get pattern_id from parent control
            pid_row = self.db.execute(text(
                "SELECT pattern_id FROM canonical_controls WHERE id = CAST(:uid AS uuid)"
            ), {"uid": obl["parent_uuid"]}).fetchone()
            if pid_row:
                pattern_id = pid_row[0]

            result = await self._dedup.check_duplicate(
                action=obl.get("action", ""),
                obj=obl.get("object", ""),
                title=atomic.title,
                pattern_id=pattern_id,
            )

            if result.verdict == "link":
                # Existing control covers this obligation: link it to the
                # parent instead of creating a new control.
                self._dedup.add_parent_link(
                    control_uuid=result.matched_control_uuid,
                    parent_control_uuid=obl["parent_uuid"],
                    link_type="dedup_merge",
                    confidence=result.similarity_score,
                )
                stats.setdefault("dedup_linked", 0)
                stats["dedup_linked"] += 1
                stats["candidates_processed"] += 1
                logger.info("Dedup LINK: %s → %s (%.3f, %s)",
                            atomic.title[:60], result.matched_control_id,
                            result.similarity_score, result.stage)
                return

            if result.verdict == "review":
                # Ambiguous match: queue for human review, create nothing.
                self._dedup.write_review(
                    candidate_control_id=atomic.candidate_id or "",
                    candidate_title=atomic.title,
                    candidate_objective=atomic.objective,
                    result=result,
                    parent_control_uuid=obl["parent_uuid"],
                    obligation_candidate_id=obl.get("oc_id"),
                )
                stats.setdefault("dedup_review", 0)
                stats["dedup_review"] += 1
                stats["candidates_processed"] += 1
                logger.info("Dedup REVIEW: %s ↔ %s (%.3f, %s)",
                            atomic.title[:60], result.matched_control_id,
                            result.similarity_score, result.stage)
                return

        # ── Create new atomic control ─────────────────────────
        # Candidate ID pattern: "{parent_control_id}-A{NN}".
        seq = self._next_atomic_seq(obl["parent_control_id"])
        atomic.candidate_id = f"{obl['parent_control_id']}-A{seq:02d}"

        new_uuid = self._write_atomic_control(atomic, obl)

        self.db.execute(
            text("""
                UPDATE obligation_candidates
                SET release_state = 'composed'
                WHERE id = CAST(:oc_id AS uuid)
            """),
            {"oc_id": obl["oc_id"]},
        )

        # Index in Qdrant for future dedup checks
        if self._dedup and new_uuid:
            pattern_id_val = None
            pid_row2 = self.db.execute(text(
                "SELECT pattern_id FROM canonical_controls WHERE id = CAST(:uid AS uuid)"
            ), {"uid": obl["parent_uuid"]}).fetchone()
+ if pid_row2: + pattern_id_val = pid_row2[0] + + if pattern_id_val: + await self._dedup.index_control( + control_uuid=new_uuid, + control_id=atomic.candidate_id, + title=atomic.title, + action=obl.get("action", ""), + obj=obl.get("object", ""), + pattern_id=pattern_id_val, + ) + + stats["controls_created"] += 1 + stats["candidates_processed"] += 1 + + # ------------------------------------------------------------------- + # Merge Pass: Deduplicate implementation-level obligations + # ------------------------------------------------------------------- + + def run_merge_pass(self) -> dict: + """Merge implementation-level duplicate obligations within each parent. + + When the same parent control has multiple obligations with nearly + identical action+object (e.g. "SMS-Verbot" + "Policy-as-Code" both + implementing a communication restriction), keep the more abstract one + and mark the concrete one as merged. + + No LLM calls — purely rule-based using text similarity. + """ + stats = { + "parents_checked": 0, + "obligations_merged": 0, + "obligations_kept": 0, + } + + # Get all parents that have >1 validated obligation + parents = self.db.execute(text(""" + SELECT parent_control_uuid, count(*) AS cnt + FROM obligation_candidates + WHERE release_state = 'validated' + AND merged_into_id IS NULL + GROUP BY parent_control_uuid + HAVING count(*) > 1 + """)).fetchall() + + for parent_uuid, cnt in parents: + stats["parents_checked"] += 1 + obligs = self.db.execute(text(""" + SELECT id, candidate_id, obligation_text, action, object + FROM obligation_candidates + WHERE parent_control_uuid = CAST(:pid AS uuid) + AND release_state = 'validated' + AND merged_into_id IS NULL + ORDER BY created_at + """), {"pid": str(parent_uuid)}).fetchall() + + merged_ids = set() + oblig_list = list(obligs) + + for i in range(len(oblig_list)): + if str(oblig_list[i][0]) in merged_ids: + continue + for j in range(i + 1, len(oblig_list)): + if str(oblig_list[j][0]) in merged_ids: + continue + + 
action_i = (oblig_list[i][3] or "").lower().strip() + action_j = (oblig_list[j][3] or "").lower().strip() + obj_i = (oblig_list[i][4] or "").lower().strip() + obj_j = (oblig_list[j][4] or "").lower().strip() + + # Check if actions are similar enough to be duplicates + if not _text_similar(action_i, action_j, threshold=0.75): + continue + if not _text_similar(obj_i, obj_j, threshold=0.60): + continue + + # Keep the more abstract one (shorter text = less specific) + text_i = oblig_list[i][2] or "" + text_j = oblig_list[j][2] or "" + if _is_more_implementation_specific(text_j, text_i): + survivor_id = str(oblig_list[i][0]) + merged_id = str(oblig_list[j][0]) + else: + survivor_id = str(oblig_list[j][0]) + merged_id = str(oblig_list[i][0]) + + self.db.execute(text(""" + UPDATE obligation_candidates + SET release_state = 'merged', + merged_into_id = CAST(:survivor AS uuid) + WHERE id = CAST(:merged AS uuid) + """), {"survivor": survivor_id, "merged": merged_id}) + + merged_ids.add(merged_id) + stats["obligations_merged"] += 1 + + # Commit per parent to avoid large transactions + self.db.commit() + + stats["obligations_kept"] = self.db.execute(text(""" + SELECT count(*) FROM obligation_candidates + WHERE release_state = 'validated' AND merged_into_id IS NULL + """)).fetchone()[0] + + logger.info("Merge pass: %s", stats) + return stats + + # ------------------------------------------------------------------- + # Enrich Pass: Add metadata to obligations + # ------------------------------------------------------------------- + + def enrich_obligations(self) -> dict: + """Add trigger_type and is_implementation_specific to obligations. + + Rule-based enrichment — no LLM calls. 
+ """ + stats = { + "enriched": 0, + "trigger_event": 0, + "trigger_periodic": 0, + "trigger_continuous": 0, + "implementation_specific": 0, + } + + obligs = self.db.execute(text(""" + SELECT id, obligation_text, condition, action, object + FROM obligation_candidates + WHERE release_state = 'validated' + AND merged_into_id IS NULL + AND trigger_type IS NULL + """)).fetchall() + + for row in obligs: + oc_id = str(row[0]) + obl_text = row[1] or "" + condition = row[2] or "" + action = row[3] or "" + obj = row[4] or "" + + trigger = _classify_trigger_type(obl_text, condition) + impl = _is_implementation_specific_text(obl_text, action, obj) + + self.db.execute(text(""" + UPDATE obligation_candidates + SET trigger_type = :trigger, + is_implementation_specific = :impl + WHERE id = CAST(:oid AS uuid) + """), {"trigger": trigger, "impl": impl, "oid": oc_id}) + + stats["enriched"] += 1 + stats[f"trigger_{trigger}"] += 1 + if impl: + stats["implementation_specific"] += 1 + + self.db.commit() + logger.info("Enrich pass: %s", stats) + return stats + + # ------------------------------------------------------------------- + # Decomposition Status + # ------------------------------------------------------------------- + + def decomposition_status(self) -> dict: + """Return decomposition progress.""" + row = self.db.execute(text(""" + SELECT + (SELECT count(*) FROM canonical_controls + WHERE parent_control_uuid IS NULL + AND release_state NOT IN ('deprecated')) AS rich_controls, + (SELECT count(DISTINCT parent_control_uuid) FROM obligation_candidates) AS decomposed_controls, + (SELECT count(*) FROM obligation_candidates) AS total_candidates, + (SELECT count(*) FROM obligation_candidates WHERE release_state = 'validated') AS validated, + (SELECT count(*) FROM obligation_candidates WHERE release_state = 'rejected') AS rejected, + (SELECT count(*) FROM obligation_candidates WHERE release_state = 'composed') AS composed, + (SELECT count(*) FROM canonical_controls WHERE 
parent_control_uuid IS NOT NULL) AS atomic_controls, + (SELECT count(*) FROM obligation_candidates WHERE release_state = 'merged') AS merged, + (SELECT count(*) FROM obligation_candidates WHERE trigger_type IS NOT NULL) AS enriched + """)).fetchone() + + validated_for_0b = row[3] - (row[7] or 0) # validated minus merged + + return { + "rich_controls": row[0], + "decomposed_controls": row[1], + "total_candidates": row[2], + "validated": row[3], + "rejected": row[4], + "composed": row[5], + "atomic_controls": row[6], + "merged": row[7] or 0, + "enriched": row[8] or 0, + "ready_for_pass0b": validated_for_0b, + "decomposition_pct": round(row[1] / max(row[0], 1) * 100, 1), + "composition_pct": round(row[5] / max(validated_for_0b, 1) * 100, 1), + } + + # ------------------------------------------------------------------- + # DB Writers + # ------------------------------------------------------------------- + + def _write_obligation_candidate(self, cand: ObligationCandidate) -> None: + """Insert an obligation candidate into the DB.""" + self.db.execute( + text(""" + INSERT INTO obligation_candidates ( + parent_control_uuid, candidate_id, + obligation_text, action, object, condition, + normative_strength, is_test_obligation, + is_reporting_obligation, extraction_confidence, + quality_flags, release_state + ) VALUES ( + CAST(:parent_uuid AS uuid), :candidate_id, + :obligation_text, :action, :object, :condition, + :normative_strength, :is_test, :is_reporting, + :confidence, :quality_flags, :release_state + ) + """), + { + "parent_uuid": cand.parent_control_uuid, + "candidate_id": cand.candidate_id, + "obligation_text": cand.obligation_text, + "action": cand.action, + "object": cand.object_, + "condition": cand.condition, + "normative_strength": cand.normative_strength, + "is_test": cand.is_test_obligation, + "is_reporting": cand.is_reporting_obligation, + "confidence": cand.extraction_confidence, + "quality_flags": json.dumps(cand.quality_flags), + "release_state": 
cand.release_state, + }, + ) + + def _write_atomic_control( + self, atomic: AtomicControlCandidate, obl: dict, + ) -> Optional[str]: + """Insert an atomic control and create parent link. + + Returns the UUID of the newly created control, or None on failure. + Checks merge_hint to prevent duplicate controls under the same parent. + """ + parent_uuid = obl["parent_uuid"] + candidate_id = obl["candidate_id"] + + # ── Duplicate Guard: skip if same merge_hint already exists ── + merge_hint = getattr(atomic, "source_regulation", "") or "" + if merge_hint: + existing = self.db.execute( + text(""" + SELECT id::text FROM canonical_controls + WHERE parent_control_uuid = CAST(:parent AS uuid) + AND generation_metadata->>'merge_group_hint' = :hint + AND release_state NOT IN ('rejected', 'deprecated', 'duplicate') + LIMIT 1 + """), + {"parent": parent_uuid, "hint": merge_hint}, + ).fetchone() + if existing: + logger.debug( + "Duplicate guard: skipping %s — merge_hint %s already exists as %s", + candidate_id, merge_hint, existing[0], + ) + return existing[0] + + result = self.db.execute( + text(""" + INSERT INTO canonical_controls ( + control_id, title, objective, rationale, + scope, requirements, + test_procedure, evidence, severity, + open_anchors, category, + release_state, parent_control_uuid, + decomposition_method, + generation_metadata, + framework_id, + generation_strategy, pipeline_version + ) VALUES ( + :control_id, :title, :objective, :rationale, + :scope, :requirements, + :test_procedure, :evidence, + :severity, :open_anchors, :category, + 'draft', + CAST(:parent_uuid AS uuid), 'pass0b', + :gen_meta, + CAST(:framework_id AS uuid), + 'pass0b', 2 + ) + RETURNING id::text + """), + { + "control_id": atomic.candidate_id, + "title": atomic.title, + "objective": atomic.objective, + "rationale": getattr(atomic, "rationale", None) or "Aus Obligation abgeleitet.", + "scope": json.dumps({}), + "requirements": json.dumps(atomic.requirements), + "test_procedure": 
json.dumps(atomic.test_procedure), + "evidence": json.dumps(atomic.evidence), + "severity": atomic.severity, + "open_anchors": json.dumps([]), + "category": atomic.category, + "parent_uuid": parent_uuid, + "gen_meta": json.dumps({ + "decomposition_source": candidate_id, + "decomposition_method": "pass0b", + "engine_version": "v2", + "action_object_class": getattr(atomic, "domain", ""), + "merge_group_hint": atomic.source_regulation or "", + "decomposition_confidence": getattr( + atomic, "_decomposition_confidence", None + ), + "statement": getattr(atomic, "_statement", ""), + "pattern_candidates": getattr(atomic, "_pattern_candidates", []), + "deadline_hours": getattr(atomic, "_deadline_hours", None), + "frequency": getattr(atomic, "_frequency", None), + "validation_issues": getattr(atomic, "_validation_issues", []), + "is_composite": getattr(atomic, "_is_composite", False), + "atomicity": getattr(atomic, "_atomicity", "atomic"), + "requires_decomposition": getattr(atomic, "_requires_decomposition", False), + "framework_ref": getattr(atomic, "_framework_ref", None), + "framework_domain": getattr(atomic, "_framework_domain", None), + "framework_subcontrol_id": getattr(atomic, "_framework_subcontrol_id", None), + "decomposition_source": getattr(atomic, "_decomposition_source", "direct"), + }), + "framework_id": "14b1bdd2-abc7-4a43-adae-14471ee5c7cf", + }, + ) + + row = result.fetchone() + new_uuid = row[0] if row else None + + # Create M:N parent link (control_parent_links) + if new_uuid: + citation = _parse_citation(obl.get("parent_citation", "")) + self.db.execute( + text(""" + INSERT INTO control_parent_links + (control_uuid, parent_control_uuid, link_type, confidence, + source_regulation, source_article, obligation_candidate_id) + VALUES + (CAST(:cu AS uuid), CAST(:pu AS uuid), 'decomposition', 1.0, + :sr, :sa, CAST(:oci AS uuid)) + ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING + """), + { + "cu": new_uuid, + "pu": parent_uuid, + "sr": 
citation.get("source", ""), + "sa": citation.get("article", ""), + "oci": obl["oc_id"], + }, + ) + + return new_uuid + + def _next_atomic_seq(self, parent_control_id: str) -> int: + """Get the next sequence number for atomic controls under a parent.""" + result = self.db.execute( + text(""" + SELECT count(*) FROM canonical_controls + WHERE parent_control_uuid = ( + SELECT id FROM canonical_controls + WHERE control_id = :parent_id + LIMIT 1 + ) + """), + {"parent_id": parent_control_id}, + ).fetchone() + return (result[0] if result else 0) + 1 + + + # ------------------------------------------------------------------- + # Anthropic Batch API: Submit all controls as async batch (50% off) + # ------------------------------------------------------------------- + + async def submit_batch_pass0a( + self, + limit: int = 0, + batch_size: int = 5, + category_filter: Optional[str] = None, + source_filter: Optional[str] = None, + ) -> dict: + """Create an Anthropic Batch API request for Pass 0a. + + Groups controls into content-batches of `batch_size`, then submits + all batches as one Anthropic Batch (up to 10,000 requests). + Returns batch metadata for polling. 
+ """ + query = """ + SELECT cc.id, cc.control_id, cc.title, cc.objective, + cc.requirements, cc.test_procedure, + cc.source_citation, cc.category + FROM canonical_controls cc + WHERE cc.release_state NOT IN ('deprecated') + AND cc.parent_control_uuid IS NULL + AND NOT EXISTS ( + SELECT 1 FROM obligation_candidates oc + WHERE oc.parent_control_uuid = cc.id + ) + """ + params = {} + if category_filter: + cats = [c.strip() for c in category_filter.split(",") if c.strip()] + if cats: + query += " AND cc.category IN :cats" + params["cats"] = tuple(cats) + + if source_filter: + sources = [s.strip() for s in source_filter.split(",") if s.strip()] + if sources: + clauses = [] + for idx, src in enumerate(sources): + key = f"src_{idx}" + clauses.append(f"cc.source_citation::text ILIKE :{key}") + params[key] = f"%{src}%" + query += " AND (" + " OR ".join(clauses) + ")" + + query += " ORDER BY cc.created_at" + if limit > 0: + query += f" LIMIT {limit}" + + rows = self.db.execute(text(query), params).fetchall() + + # Prepare control data (skip empty) + prepared = [] + for row in rows: + title = row[2] or "" + objective = row[3] or "" + req_str = _format_field(row[4] or "") + if not title and not objective and not req_str: + continue + prepared.append({ + "uuid": str(row[0]), + "control_id": row[1] or "", + "title": title, + "objective": objective, + "requirements": req_str, + "test_procedure": _format_field(row[5] or ""), + "source_ref": _format_citation(row[6] or ""), + "category": row[7] or "", + }) + + if not prepared: + return {"status": "empty", "total_controls": 0} + + # Build batch requests (each request = batch_size controls) + requests = [] + for i in range(0, len(prepared), batch_size): + batch = prepared[i : i + batch_size] + if len(batch) > 1: + prompt = _build_pass0a_batch_prompt(batch) + else: + ctrl = batch[0] + prompt = _build_pass0a_prompt( + title=ctrl["title"], objective=ctrl["objective"], + requirements=ctrl["requirements"], + 
test_procedure=ctrl["test_procedure"], + source_ref=ctrl["source_ref"], + ) + + # Control IDs in custom_id for result mapping + ids_str = "+".join(c["control_id"] for c in batch) + requests.append({ + "custom_id": f"p0a_{ids_str}", + "params": { + "model": ANTHROPIC_MODEL, + "max_tokens": max(8192, len(batch) * 2000), + "system": [ + { + "type": "text", + "text": _PASS0A_SYSTEM_PROMPT, + "cache_control": {"type": "ephemeral"}, + } + ], + "messages": [{"role": "user", "content": prompt}], + }, + }) + + batch_result = await create_anthropic_batch(requests) + batch_id = batch_result.get("id", "") + + logger.info( + "Batch API submitted: %s — %d requests (%d controls, batch_size=%d)", + batch_id, len(requests), len(prepared), batch_size, + ) + + return { + "status": "submitted", + "batch_id": batch_id, + "total_controls": len(prepared), + "total_requests": len(requests), + "batch_size": batch_size, + "category_filter": category_filter, + "source_filter": source_filter, + } + + async def submit_batch_pass0b( + self, + limit: int = 0, + batch_size: int = 5, + ) -> dict: + """Create an Anthropic Batch API request for Pass 0b.""" + query = """ + SELECT oc.id, oc.candidate_id, oc.parent_control_uuid, + oc.obligation_text, oc.action, oc.object, + oc.is_test_obligation, oc.is_reporting_obligation, + cc.title AS parent_title, + cc.category AS parent_category, + cc.source_citation AS parent_citation, + cc.severity AS parent_severity, + cc.control_id AS parent_control_id + FROM obligation_candidates oc + JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid + WHERE oc.release_state = 'validated' + AND NOT EXISTS ( + SELECT 1 FROM canonical_controls ac + WHERE ac.parent_control_uuid = oc.parent_control_uuid + AND ac.decomposition_method = 'pass0b' + AND ac.release_state NOT IN ('deprecated', 'duplicate') + AND ac.title LIKE '%' || LEFT(oc.action, 20) || '%' + ) + """ + if limit > 0: + query += f" LIMIT {limit}" + + rows = self.db.execute(text(query)).fetchall() + + 
prepared = [] + for row in rows: + prepared.append({ + "oc_id": str(row[0]), + "candidate_id": row[1] or "", + "parent_uuid": str(row[2]), + "obligation_text": row[3] or "", + "action": row[4] or "", + "object": row[5] or "", + "is_test": row[6], + "is_reporting": row[7], + "parent_title": row[8] or "", + "parent_category": row[9] or "", + "parent_citation": row[10] or "", + "parent_severity": row[11] or "medium", + "parent_control_id": row[12] or "", + "source_ref": _format_citation(row[10] or ""), + }) + + if not prepared: + return {"status": "empty", "total_candidates": 0} + + requests = [] + for i in range(0, len(prepared), batch_size): + batch = prepared[i : i + batch_size] + if len(batch) > 1: + prompt = _build_pass0b_batch_prompt(batch) + else: + obl = batch[0] + prompt = _build_pass0b_prompt( + obligation_text=obl["obligation_text"], + action=obl["action"], object_=obl["object"], + parent_title=obl["parent_title"], + parent_category=obl["parent_category"], + source_ref=obl["source_ref"], + ) + + ids_str = "+".join(o["candidate_id"] for o in batch) + requests.append({ + "custom_id": f"p0b_{ids_str}", + "params": { + "model": ANTHROPIC_MODEL, + "max_tokens": max(8192, len(batch) * 1500), + "system": [ + { + "type": "text", + "text": _PASS0B_SYSTEM_PROMPT, + "cache_control": {"type": "ephemeral"}, + } + ], + "messages": [{"role": "user", "content": prompt}], + }, + }) + + batch_result = await create_anthropic_batch(requests) + batch_id = batch_result.get("id", "") + + logger.info( + "Batch API Pass 0b submitted: %s — %d requests (%d candidates)", + batch_id, len(requests), len(prepared), + ) + + return { + "status": "submitted", + "batch_id": batch_id, + "total_candidates": len(prepared), + "total_requests": len(requests), + "batch_size": batch_size, + } + + async def process_batch_results( + self, batch_id: str, pass_type: str = "0a", + ) -> dict: + """Fetch and process results from a completed Anthropic batch. + + Args: + batch_id: Anthropic batch ID. 
+ pass_type: "0a" or "0b". + """ + # Check status first + status = await check_batch_status(batch_id) + if status.get("processing_status") != "ended": + return { + "status": "not_ready", + "processing_status": status.get("processing_status"), + "request_counts": status.get("request_counts", {}), + } + + results = await fetch_batch_results(batch_id) + stats = { + "results_processed": 0, + "results_succeeded": 0, + "results_failed": 0, + "errors": 0, + } + + if pass_type == "0a": + stats.update({ + "controls_processed": 0, + "obligations_extracted": 0, + "obligations_validated": 0, + "obligations_rejected": 0, + }) + else: + stats.update({ + "candidates_processed": 0, + "controls_created": 0, + "llm_failures": 0, + }) + + for result in results: + custom_id = result.get("custom_id", "") + result_data = result.get("result", {}) + stats["results_processed"] += 1 + + if result_data.get("type") != "succeeded": + stats["results_failed"] += 1 + logger.warning("Batch result failed: %s — %s", custom_id, result_data) + continue + + stats["results_succeeded"] += 1 + message = result_data.get("message", {}) + content = message.get("content", []) + text_content = content[0].get("text", "") if content else "" + + try: + if pass_type == "0a": + self._handle_batch_result_0a(custom_id, text_content, stats) + else: + await self._handle_batch_result_0b(custom_id, text_content, stats) + except Exception as e: + logger.error("Processing batch result %s: %s", custom_id, e) + stats["errors"] += 1 + + self.db.commit() + stats["status"] = "completed" + return stats + + def _handle_batch_result_0a( + self, custom_id: str, text_content: str, stats: dict, + ) -> None: + """Process a single Pass 0a batch result.""" + # custom_id format: p0a_CTRL-001+CTRL-002+... 
+ prefix = "p0a_" + control_ids = custom_id[len(prefix):].split("+") if custom_id.startswith(prefix) else [] + + if len(control_ids) == 1: + raw_obls = _parse_json_array(text_content) + control_id = control_ids[0] + uuid_row = self.db.execute( + text("SELECT id FROM canonical_controls WHERE control_id = :cid LIMIT 1"), + {"cid": control_id}, + ).fetchone() + if not uuid_row: + return + control_uuid = str(uuid_row[0]) + if not raw_obls: + raw_obls = [{"obligation_text": control_id, "action": "sicherstellen", + "object": control_id}] + self._process_pass0a_obligations(raw_obls, control_id, control_uuid, stats) + stats["controls_processed"] += 1 + else: + results_by_id = _parse_json_object(text_content) + for control_id in control_ids: + uuid_row = self.db.execute( + text("SELECT id FROM canonical_controls WHERE control_id = :cid LIMIT 1"), + {"cid": control_id}, + ).fetchone() + if not uuid_row: + continue + control_uuid = str(uuid_row[0]) + raw_obls = results_by_id.get(control_id, []) + if not isinstance(raw_obls, list): + raw_obls = [raw_obls] if raw_obls else [] + if not raw_obls: + raw_obls = [{"obligation_text": control_id, "action": "sicherstellen", + "object": control_id}] + self._process_pass0a_obligations(raw_obls, control_id, control_uuid, stats) + stats["controls_processed"] += 1 + + async def _handle_batch_result_0b( + self, custom_id: str, text_content: str, stats: dict, + ) -> None: + """Process a single Pass 0b batch result.""" + prefix = "p0b_" + candidate_ids = custom_id[len(prefix):].split("+") if custom_id.startswith(prefix) else [] + + if len(candidate_ids) == 1: + parsed = _parse_json_object(text_content) + obl = self._load_obligation_for_0b(candidate_ids[0]) + if obl: + await self._process_pass0b_control(obl, parsed, stats) + else: + results_by_id = _parse_json_object(text_content) + for cand_id in candidate_ids: + parsed = results_by_id.get(cand_id, {}) + obl = self._load_obligation_for_0b(cand_id) + if obl: + await 
self._process_pass0b_control(obl, parsed, stats) + + def _load_obligation_for_0b(self, candidate_id: str) -> Optional[dict]: + """Load obligation data needed for Pass 0b processing.""" + row = self.db.execute( + text(""" + SELECT oc.id, oc.candidate_id, oc.parent_control_uuid, + oc.obligation_text, oc.action, oc.object, + oc.is_test_obligation, oc.is_reporting_obligation, + cc.title, cc.category, cc.source_citation, + cc.severity, cc.control_id + FROM obligation_candidates oc + JOIN canonical_controls cc ON cc.id = oc.parent_control_uuid + WHERE oc.candidate_id = :cid + """), + {"cid": candidate_id}, + ).fetchone() + if not row: + return None + return { + "oc_id": str(row[0]), + "candidate_id": row[1] or "", + "parent_uuid": str(row[2]), + "obligation_text": row[3] or "", + "action": row[4] or "", + "object": row[5] or "", + "is_test": row[6], + "is_reporting": row[7], + "parent_title": row[8] or "", + "parent_category": row[9] or "", + "parent_citation": row[10] or "", + "parent_severity": row[11] or "medium", + "parent_control_id": row[12] or "", + "source_ref": _format_citation(row[10] or ""), + } + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _fallback_obligation(ctrl: dict) -> dict: + """Create a single fallback obligation when LLM returns nothing.""" + return { + "obligation_text": ctrl.get("objective") or ctrl.get("title", ""), + "action": "sicherstellen", + "object": ctrl.get("title", ""), + "condition": None, + "normative_strength": "must", + "is_test_obligation": False, + "is_reporting_obligation": False, + } + + +def _format_field(value) -> str: + """Format a requirements/test_procedure field for the LLM prompt.""" + if not value: + return "" + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return "\n".join(f"- {item}" for item in parsed) + return value + except 
(json.JSONDecodeError, TypeError): + return value + if isinstance(value, list): + return "\n".join(f"- {item}" for item in value) + return str(value) + + +def _format_citation(citation) -> str: + """Format source_citation for display.""" + if not citation: + return "" + if isinstance(citation, str): + try: + c = json.loads(citation) + if isinstance(c, dict): + parts = [] + if c.get("source"): + parts.append(c["source"]) + if c.get("article"): + parts.append(c["article"]) + if c.get("paragraph"): + parts.append(c["paragraph"]) + return " ".join(parts) if parts else citation + except (json.JSONDecodeError, TypeError): + return citation + return str(citation) + + +def _parse_citation(citation) -> dict: + """Parse source_citation JSONB into a dict with source/article/paragraph.""" + if not citation: + return {} + if isinstance(citation, dict): + return citation + if isinstance(citation, str): + try: + c = json.loads(citation) + if isinstance(c, dict): + return c + except (json.JSONDecodeError, TypeError): + pass + return {} + + +def _compute_extraction_confidence(flags: dict) -> float: + """Compute confidence score from quality flags.""" + score = 0.0 + weights = { + "has_normative_signal": 0.30, + "single_action": 0.20, + "not_rationale": 0.20, + "not_evidence_only": 0.15, + "min_length": 0.10, + "has_parent_link": 0.05, + } + for flag, weight in weights.items(): + if flags.get(flag, False): + score += weight + return round(score, 2) + + +def _normalize_severity(val: str) -> str: + """Normalize severity value.""" + val = (val or "medium").lower().strip() + if val in ("critical", "high", "medium", "low"): + return val + return "medium" + + +# Action-type-based severity calibration: not every atomic control +# inherits the parent's severity. Definition and review controls are +# typically medium, while implementation controls stay high. 
+_ACTION_SEVERITY_CAP: dict[str, str] = { + "define": "medium", + "review": "medium", + "document": "medium", + "report": "medium", + "test": "medium", + "implement": "high", + "configure": "high", + "monitor": "high", + "enforce": "high", + "prevent": "high", + "exclude": "high", + "forbid": "high", + "invalidate": "high", + "issue": "high", + "rotate": "medium", +} + +# Severity ordering for cap comparison +_SEVERITY_ORDER = {"low": 0, "medium": 1, "high": 2, "critical": 3} + + +def _calibrate_severity(parent_severity: str, action_type: str) -> str: + """Calibrate severity based on action type. + + Implementation/enforcement inherits parent severity. + Definition/review/test/documentation caps at medium. + """ + parent = _normalize_severity(parent_severity) + cap = _ACTION_SEVERITY_CAP.get(action_type) + if not cap: + return parent + # Return the lower of parent severity and action-type cap + if _SEVERITY_ORDER.get(parent, 1) <= _SEVERITY_ORDER.get(cap, 1): + return parent + return cap + + +# _template_fallback removed — replaced by _compose_deterministic engine diff --git a/control-pipeline/services/framework_decomposition.py b/control-pipeline/services/framework_decomposition.py new file mode 100644 index 0000000..40010d2 --- /dev/null +++ b/control-pipeline/services/framework_decomposition.py @@ -0,0 +1,714 @@ +"""Framework Decomposition Engine — decomposes framework-container obligations. + +Sits between Pass 0a (obligation extraction) and Pass 0b (atomic control +composition). Detects obligations that reference a framework domain (e.g. +"CCM-Praktiken fuer AIS") and decomposes them into concrete sub-obligations +using an internal framework registry. + +Three routing types: + atomic → pass through to Pass 0b unchanged + compound → split compound verbs, then Pass 0b + framework_container → decompose via registry, then Pass 0b + +The registry is a set of JSON files under compliance/data/frameworks/. 
"""Framework Decomposition Engine — decomposes framework-container obligations.

Sits between Pass 0a (obligation extraction) and Pass 0b (atomic control
composition). Detects obligations that reference a framework domain (e.g.
"CCM-Praktiken fuer AIS") and decomposes them into concrete sub-obligations
using an internal framework registry.

Three routing types:
    atomic              → pass through to Pass 0b unchanged
    compound            → split compound verbs, then Pass 0b
    framework_container → decompose via registry, then Pass 0b

The registry is a set of JSON files under compliance/data/frameworks/.
"""

import json
import logging
import os
import re
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Registry loading
# ---------------------------------------------------------------------------

# Framework JSON files live next to the package: <pkg>/data/frameworks/
_REGISTRY_DIR = Path(__file__).resolve().parent.parent / "data" / "frameworks"
_REGISTRY: dict[str, dict] = {}  # framework_id → framework dict (module cache)


def _load_registry() -> dict[str, dict]:
    """Read every ``*.json`` framework file from the registry directory."""
    loaded: dict[str, dict] = {}
    if not _REGISTRY_DIR.is_dir():
        logger.warning("Framework registry dir not found: %s", _REGISTRY_DIR)
        return loaded

    for json_path in sorted(_REGISTRY_DIR.glob("*.json")):
        try:
            with open(json_path, encoding="utf-8") as handle:
                framework = json.load(handle)
            framework_id = framework.get("framework_id", json_path.stem)
            loaded[framework_id] = framework
            logger.info(
                "Loaded framework: %s (%d domains)",
                framework_id,
                len(framework.get("domains", [])),
            )
        except Exception:
            # One broken file must not take down the whole registry.
            logger.exception("Failed to load framework file: %s", json_path)
    return loaded


def get_registry() -> dict[str, dict]:
    """Return the module-level framework registry, loading it on first use."""
    global _REGISTRY
    if not _REGISTRY:
        _REGISTRY = _load_registry()
    return _REGISTRY


def reload_registry() -> dict[str, dict]:
    """Discard the cached registry and re-read all framework files from disk."""
    global _REGISTRY
    _REGISTRY = _load_registry()
    return _REGISTRY


# ---------------------------------------------------------------------------
# Framework alias index (built from registry)
# ---------------------------------------------------------------------------

def _build_alias_index(registry: dict[str, dict]) -> dict[str, str]:
    """Map lowercase aliases (ids, display names, id fragments) → framework_id."""
    aliases: dict[str, str] = {}
    for framework_id, framework in registry.items():
        lowered_id = framework_id.lower()
        aliases[lowered_id] = framework_id
        display = framework.get("display_name", "")
        if display:
            aliases[display.lower()] = framework_id
        # Common short forms: individual id tokens of 3+ chars ("nist", "800").
        for token in lowered_id.replace("_", " ").split():
            if len(token) >= 3:
                aliases[token] = framework_id
    return aliases


# ---------------------------------------------------------------------------
# Routing — classify obligation type
# ---------------------------------------------------------------------------

# Extended patterns for framework detection (beyond the simple _COMPOSITE_RE
# in decomposition_pass.py — here we also capture the framework name).
_FRAMEWORK_PATTERN = re.compile(
    r"(?:praktiken|kontrollen|ma(?:ss|ß)nahmen|anforderungen|vorgaben|controls|practices|measures|requirements)"
    r"\s+(?:f(?:ue|ü)r|aus|gem(?:ae|ä)(?:ss|ß)|nach|from|of|for|per)\s+"
    r"(.+?)(?:\s+(?:m(?:ue|ü)ssen|sollen|sind|werden|implementieren|umsetzen|einf(?:ue|ü)hren)|\.|,|$)",
    re.IGNORECASE,
)

# Direct framework name references.
_DIRECT_FRAMEWORK_RE = re.compile(
    r"\b(?:CSA\s*CCM|NIST\s*(?:SP\s*)?800-53|OWASP\s*(?:ASVS|SAMM|Top\s*10)"
    r"|CIS\s*Controls|BSI\s*(?:IT-)?Grundschutz|ENISA|ISO\s*2700[12]"
    r"|COBIT|SOX|PCI\s*DSS|HITRUST|SOC\s*2|KRITIS)\b",
    re.IGNORECASE,
)

# Conjunctions signalling compound verb phrases (multiple main verbs).
_COMPOUND_VERB_RE = re.compile(
    r"\b(?:und|sowie|als\s+auch|or|and)\b",
    re.IGNORECASE,
)

# Fixed verb pairs that look compound but describe a single activity.
_NO_SPLIT_PHRASES = [
    "pflegen und aufrechterhalten",
    "dokumentieren und pflegen",
    "definieren und dokumentieren",
    "erstellen und freigeben",
    "pruefen und genehmigen",
    "identifizieren und bewerten",
    "erkennen und melden",
    "define and maintain",
    "create and maintain",
    "establish and maintain",
    "monitor and review",
    "detect and respond",
]
@dataclass
class RoutingResult:
    """Result of obligation routing classification."""
    # One of: atomic | compound | framework_container | unknown_review
    routing_type: str
    framework_ref: Optional[str] = None
    framework_domain: Optional[str] = None
    domain_title: Optional[str] = None
    confidence: float = 0.0
    reason: str = ""


def classify_routing(
    obligation_text: str,
    action_raw: str,
    object_raw: str,
    condition_raw: Optional[str] = None,
) -> RoutingResult:
    """Classify an obligation as atomic / compound / framework_container.

    Args:
        obligation_text: Full obligation sentence.
        action_raw: Extracted action phrase (may contain several verbs).
        object_raw: Extracted object phrase.
        condition_raw: Currently unused; kept for interface stability.

    Returns:
        RoutingResult with routing_type, optional framework reference and a
        heuristic confidence.
    """
    # (Removed an unused `combined` local — _detect_framework builds its own.)

    # --- Step 1: Framework container detection (strongest signal first) ---
    fw_result = _detect_framework(obligation_text, object_raw)
    if fw_result.routing_type == "framework_container":
        return fw_result

    # --- Step 2: Compound verb detection ---
    if _is_compound_obligation(action_raw, obligation_text):
        return RoutingResult(
            routing_type="compound",
            confidence=0.7,
            reason="multiple_main_verbs",
        )

    # --- Step 3: Default = atomic ---
    return RoutingResult(
        routing_type="atomic",
        confidence=0.9,
        reason="single_action_single_object",
    )


def _detect_framework(
    obligation_text: str, object_raw: str,
) -> RoutingResult:
    """Detect whether the obligation references a framework domain.

    Tries three strategies in decreasing confidence order:
    1. direct framework name match (regex over known names),
    2. "Praktiken fuer X"-style pattern with registry resolution,
    3. framework-indicator keywords in the object phrase.
    """
    combined = f"{obligation_text} {object_raw}"
    registry = get_registry()
    alias_idx = _build_alias_index(registry)

    # Strategy 1: direct framework name match.
    m = _DIRECT_FRAMEWORK_RE.search(combined)
    if m:
        fw_name = m.group(0).strip()
        fw_id = _resolve_framework_id(fw_name, alias_idx, registry)
        if fw_id:
            domain_id, domain_title = _match_domain(
                combined, registry[fw_id],
            )
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=fw_id,
                framework_domain=domain_id,
                domain_title=domain_title,
                confidence=0.95 if domain_id else 0.75,
                reason=f"direct_framework_match:{fw_name}",
            )
        else:
            # Framework name recognized but not in the registry.
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=None,
                framework_domain=None,
                confidence=0.6,
                reason=f"direct_framework_match_no_registry:{fw_name}",
            )

    # Strategy 2: pattern match ("Praktiken fuer X").
    m2 = _FRAMEWORK_PATTERN.search(combined)
    if m2:
        ref_text = m2.group(1).strip()
        fw_id, domain_id, domain_title = _resolve_from_ref_text(
            ref_text, registry, alias_idx,
        )
        if fw_id:
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=fw_id,
                framework_domain=domain_id,
                domain_title=domain_title,
                confidence=0.85 if domain_id else 0.65,
                reason=f"pattern_match:{ref_text}",
            )

    # Strategy 3: keyword-heavy object phrase.
    if _has_framework_keywords(object_raw):
        return RoutingResult(
            routing_type="framework_container",
            framework_ref=None,
            framework_domain=None,
            confidence=0.5,
            reason="framework_keywords_in_object",
        )

    return RoutingResult(routing_type="atomic", confidence=0.0)


def _resolve_framework_id(
    name: str,
    alias_idx: dict[str, str],
    registry: dict[str, dict],
) -> Optional[str]:
    """Resolve a framework name to its registry ID (or None)."""
    normalized = re.sub(r"\s+", " ", name.strip().lower())
    # Direct alias match.
    if normalized in alias_idx:
        return alias_idx[normalized]
    # Compact form: strip spaces, hyphens, underscores on both sides.
    compact = re.sub(r"[\s_\-]+", "", normalized)
    for alias, fw_id in alias_idx.items():
        if re.sub(r"[\s_\-]+", "", alias) == compact:
            return fw_id
    # Substring match against display names.
    for fw_id, fw in registry.items():
        display = fw.get("display_name", "").lower()
        if normalized in display or display in normalized:
            return fw_id
    # Partial match: any alias contained in the reference (multi-word refs).
    for alias, fw_id in alias_idx.items():
        if len(alias) >= 4 and alias in normalized:
            return fw_id
    return None


def _match_domain(
    text: str, framework: dict,
) -> tuple[Optional[str], Optional[str]]:
    """Find the best-scoring domain of *framework* mentioned in *text*.

    Scoring: exact domain-id word match = 10, full title substring = 8,
    alias substring = 6, plus 1 per keyword hit. A minimum score of 3 is
    required; otherwise (None, None) is returned.
    """
    text_lower = text.lower()
    best_id: Optional[str] = None
    best_title: Optional[str] = None
    best_score = 0

    for domain in framework.get("domains", []):
        score = 0
        domain_id = domain["domain_id"]
        title = domain.get("title", "")

        # Exact domain ID match (e.g. "AIS").
        if re.search(rf"\b{re.escape(domain_id)}\b", text, re.IGNORECASE):
            score += 10

        # Full title match.
        if title.lower() in text_lower:
            score += 8

        # Alias match (first hit only).
        for alias in domain.get("aliases", []):
            if alias.lower() in text_lower:
                score += 6
                break

        # Keyword overlap.
        kw_hits = sum(
            1 for kw in domain.get("keywords", [])
            if kw.lower() in text_lower
        )
        score += kw_hits

        if score > best_score:
            best_score = score
            best_id = domain_id
            best_title = title

    if best_score >= 3:
        return best_id, best_title
    return None, None


def _resolve_from_ref_text(
    ref_text: str,
    registry: dict[str, dict],
    alias_idx: dict[str, str],
) -> tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve framework + domain from a reference like 'AIS' or 'Application Security'.

    Returns (framework_id, domain_id, domain_title) or (None, None, None).
    """
    ref_lower = ref_text.lower()

    for fw_id, fw in registry.items():
        for domain in fw.get("domains", []):
            # Check domain ID.
            if domain["domain_id"].lower() in ref_lower:
                return fw_id, domain["domain_id"], domain.get("title")
            # Check title.
            if domain.get("title", "").lower() in ref_lower:
                return fw_id, domain["domain_id"], domain.get("title")
            # Check aliases (either direction of containment).
            for alias in domain.get("aliases", []):
                if alias.lower() in ref_lower or ref_lower in alias.lower():
                    return fw_id, domain["domain_id"], domain.get("title")

    return None, None, None


# Words that indicate a framework/catalogue is being referenced.
_FRAMEWORK_KW_SET = {
    "praktiken", "kontrollen", "massnahmen", "maßnahmen",
    "anforderungen", "vorgaben", "framework", "standard",
    "baseline", "katalog", "domain", "family", "category",
    "practices", "controls", "measures", "requirements",
}


def _has_framework_keywords(text: str) -> bool:
    """Return True if *text* contains at least two framework-indicator words."""
    words = set(re.findall(r"[a-zäöüß]+", text.lower()))
    return len(words & _FRAMEWORK_KW_SET) >= 2


def _is_compound_obligation(action_raw: str, obligation_text: str) -> bool:
    """Detect whether the action phrase carries multiple competing main verbs."""
    if not action_raw:
        return False

    action_lower = action_raw.lower().strip()

    # Known single-activity verb pairs never split.
    for phrase in _NO_SPLIT_PHRASES:
        if phrase in action_lower:
            return False

    # Must contain a conjunction at all.
    if not _COMPOUND_VERB_RE.search(action_lower):
        return False

    # Split by conjunctions and require 2+ meaningful fragments.
    parts = re.split(r"\b(?:und|sowie|als\s+auch|or|and)\b", action_lower)
    meaningful = [p.strip() for p in parts if len(p.strip()) >= 3]
    return len(meaningful) >= 2


# ---------------------------------------------------------------------------
# Framework Decomposition
# ---------------------------------------------------------------------------

@dataclass
class DecomposedObligation:
    """A concrete obligation derived from a framework container."""
    obligation_candidate_id: str
    parent_control_id: str
    parent_framework_container_id: str
    source_ref_law: str
    source_ref_article: str
    obligation_text: str
    actor: str
    action_raw: str
    object_raw: str
    condition_raw: Optional[str] = None
    trigger_raw: Optional[str] = None
    routing_type: str = "atomic"
    release_state: str = "decomposed"
    subcontrol_id: str = ""
    # Metadata carried over from the registry subcontrol entry.
    action_hint: str = ""
    object_hint: str = ""
    object_class: str = ""
    keywords: list[str] = field(default_factory=list)


@dataclass
class FrameworkDecompositionResult:
    """Result of framework decomposition."""
    framework_container_id: str
    source_obligation_candidate_id: str
    framework_ref: Optional[str]
    framework_domain: Optional[str]
    domain_title: Optional[str]
    matched_subcontrols: list[str]
    decomposition_confidence: float
    release_state: str  # decomposed | unmatched | error
    decomposed_obligations: list[DecomposedObligation]
    issues: list[str]
def decompose_framework_container(
    obligation_candidate_id: str,
    parent_control_id: str,
    obligation_text: str,
    framework_ref: Optional[str],
    framework_domain: Optional[str],
    actor: str = "organization",
) -> FrameworkDecompositionResult:
    """Decompose a framework-container obligation into concrete sub-obligations.

    Steps:
        1. Resolve framework from registry
        2. Resolve domain within framework
        3. Select relevant subcontrols (keyword filter or full domain)
        4. Generate decomposed obligations

    Never mutates the shared registry cache: subcontrol dicts are copied
    before being tagged (the original implementation wrote ``_domain_id``
    into the cached registry objects, polluting them across calls).
    """
    container_id = f"FWC-{uuid.uuid4().hex[:8]}"
    registry = get_registry()
    issues: list[str] = []

    # Step 1: Resolve framework.
    fw = None
    if framework_ref and framework_ref in registry:
        fw = registry[framework_ref]
    else:
        # Try to find it by name in the obligation text.
        fw, framework_ref = _find_framework_in_text(obligation_text, registry)

    if not fw:
        issues.append("ERROR: framework_not_matched")
        return FrameworkDecompositionResult(
            framework_container_id=container_id,
            source_obligation_candidate_id=obligation_candidate_id,
            framework_ref=framework_ref,
            framework_domain=framework_domain,
            domain_title=None,
            matched_subcontrols=[],
            decomposition_confidence=0.0,
            release_state="unmatched",
            decomposed_obligations=[],
            issues=issues,
        )

    # Step 2: Resolve domain.
    domain_data = None
    domain_title = None
    if framework_domain:
        for d in fw.get("domains", []):
            if d["domain_id"].lower() == framework_domain.lower():
                domain_data = d
                domain_title = d.get("title")
                break
    if not domain_data:
        # Try matching the domain from the obligation text itself.
        domain_id, domain_title = _match_domain(obligation_text, fw)
        if domain_id:
            for d in fw.get("domains", []):
                if d["domain_id"] == domain_id:
                    domain_data = d
                    framework_domain = domain_id
                    break

    if not domain_data:
        issues.append("WARN: domain_not_matched — using all domains")
        # Fall back to all subcontrols across all domains. BUG FIX: tag a
        # *copy* of each subcontrol; the registry dicts are shared via the
        # module-level cache and must not be modified here.
        all_subcontrols = [
            {**sc, "_domain_id": d["domain_id"]}
            for d in fw.get("domains", [])
            for sc in d.get("subcontrols", [])
        ]
        subcontrols = _select_subcontrols(obligation_text, all_subcontrols)
        if not subcontrols:
            issues.append("ERROR: no_subcontrols_matched")
            return FrameworkDecompositionResult(
                framework_container_id=container_id,
                source_obligation_candidate_id=obligation_candidate_id,
                framework_ref=framework_ref,
                framework_domain=framework_domain,
                domain_title=None,
                matched_subcontrols=[],
                decomposition_confidence=0.0,
                release_state="unmatched",
                decomposed_obligations=[],
                issues=issues,
            )
    else:
        # Step 3: Select subcontrols from the resolved domain.
        raw_subcontrols = domain_data.get("subcontrols", [])
        subcontrols = _select_subcontrols(obligation_text, raw_subcontrols)
        if not subcontrols:
            # No targeted match → decompose the full domain.
            subcontrols = raw_subcontrols

    # Quality check: an overly broad selection is suspicious.
    if len(subcontrols) > 25:
        issues.append(f"WARN: {len(subcontrols)} subcontrols — may be too broad")

    # Step 4: Generate decomposed obligations.
    display_name = fw.get("display_name", framework_ref or "Unknown")
    decomposed: list[DecomposedObligation] = []
    matched_ids: list[str] = []

    for sc in subcontrols:
        sc_id = sc.get("subcontrol_id", "")
        matched_ids.append(sc_id)

        action_hint = sc.get("action_hint", "")
        object_hint = sc.get("object_hint", "")

        # Registry-quality warnings.
        if not action_hint:
            issues.append(f"WARN: {sc_id} missing action_hint")
        if not object_hint:
            issues.append(f"WARN: {sc_id} missing object_hint")

        statement = sc.get("statement", "")
        decomposed.append(DecomposedObligation(
            obligation_candidate_id=f"{obligation_candidate_id}-{sc_id}",
            parent_control_id=parent_control_id,
            parent_framework_container_id=container_id,
            source_ref_law=display_name,
            source_ref_article=sc_id,
            obligation_text=statement,
            actor=actor,
            action_raw=action_hint or _infer_action(statement),
            object_raw=object_hint or _infer_object(statement),
            routing_type="atomic",
            release_state="decomposed",
            subcontrol_id=sc_id,
            action_hint=action_hint,
            object_hint=object_hint,
            object_class=sc.get("object_class", ""),
            keywords=sc.get("keywords", []),
        ))

    # A decomposed obligation identical to the container text adds nothing.
    for item in decomposed:
        if item.obligation_text.strip() == obligation_text.strip():
            issues.append(f"WARN: {item.subcontrol_id} identical to container text")

    confidence = _compute_decomposition_confidence(
        framework_ref, framework_domain, domain_data, len(subcontrols), issues,
    )

    return FrameworkDecompositionResult(
        framework_container_id=container_id,
        source_obligation_candidate_id=obligation_candidate_id,
        framework_ref=framework_ref,
        framework_domain=framework_domain,
        domain_title=domain_title,
        matched_subcontrols=matched_ids,
        decomposition_confidence=confidence,
        release_state="decomposed",
        decomposed_obligations=decomposed,
        issues=issues,
    )


def _find_framework_in_text(
    text: str, registry: dict[str, dict],
) -> tuple[Optional[dict], Optional[str]]:
    """Locate a framework by searching *text* for known framework names.

    Returns (framework_dict, framework_id) or (None, None).
    """
    alias_idx = _build_alias_index(registry)
    m = _DIRECT_FRAMEWORK_RE.search(text)
    if m:
        fw_id = _resolve_framework_id(m.group(0), alias_idx, registry)
        if fw_id and fw_id in registry:
            return registry[fw_id], fw_id
    return None, None
def _select_subcontrols(
    obligation_text: str, subcontrols: list[dict],
) -> list[dict]:
    """Pick subcontrols whose keywords/title/object overlap the obligation text.

    An empty result signals "no targeted match"; the caller then falls back
    to decomposing the full domain.
    """
    haystack = obligation_text.lower()

    def _overlap(sc: dict) -> int:
        # 1 point per keyword hit, 3 for a title hit, 2 for the object hint.
        points = sum(1 for kw in sc.get("keywords", []) if kw.lower() in haystack)
        title = sc.get("title", "").lower()
        if title and title in haystack:
            points += 3
        hint = sc.get("object_hint", "").lower()
        if hint and hint in haystack:
            points += 2
        return points

    ranked = [(pts, sc) for sc in subcontrols if (pts := _overlap(sc)) > 0]
    if not ranked:
        return []

    # Keep only meaningful overlap (score >= 2), best matches first.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return [sc for pts, sc in ranked if pts >= 2]


def _infer_action(statement: str) -> str:
    """Best-effort action verb (German infinitive) for a registry statement."""
    lowered = statement.lower()
    # Checked in this fixed order; first marker hit wins.
    rules = (
        (("definiert", "definieren", "define"), "definieren"),
        (("implementiert", "implementieren", "implement"), "implementieren"),
        (("dokumentiert", "dokumentieren", "document"), "dokumentieren"),
        (("ueberwacht", "ueberwachen", "monitor"), "ueberwachen"),
        (("getestet", "testen", "test"), "testen"),
        (("geschuetzt", "schuetzen", "protect"), "implementieren"),
        (("verwaltet", "verwalten", "manage"), "pflegen"),
        (("gemeldet", "melden", "report"), "melden"),
    )
    for markers, verb in rules:
        if any(marker in lowered for marker in markers):
            return verb
    return "implementieren"


def _infer_object(statement: str) -> str:
    """Infer the primary object of a statement (first noun phrase, max 80 chars).

    Heuristic: the text between the modal verb ("muss"/"muessen") and the
    auxiliary ("werden"/"sein") or the next sentence break.
    """
    match = re.search(
        r"(?:muessen|muss|m(?:ü|ue)ssen)\s+(.+?)(?:\s+werden|\s+sein|\.|,|$)",
        statement,
        re.IGNORECASE,
    )
    if match:
        return match.group(1).strip()[:80]
    # Fallback: first 80 chars of the statement.
    return statement[:80] if statement else ""
def _compute_decomposition_confidence(
    framework_ref: Optional[str],
    domain: Optional[str],
    domain_data: Optional[dict],
    num_subcontrols: int,
    issues: list[str],
) -> float:
    """Aggregate a 0..1 confidence score for a framework decomposition.

    Base 0.3, plus bonuses for a resolved framework (+0.25), domain id
    (+0.20), domain data (+0.10) and a sane subcontrol count (+0.10 for
    1..15, +0.05 above), minus 0.15 per ERROR issue. Clamped and rounded
    to two decimals.
    """
    confidence = 0.3
    if framework_ref:
        confidence += 0.25
    if domain:
        confidence += 0.20
    if domain_data:
        confidence += 0.10
    if num_subcontrols > 15:
        confidence += 0.05  # less confident with too many
    elif num_subcontrols >= 1:
        confidence += 0.10

    # Penalize hard errors.
    error_count = sum(issue.startswith("ERROR:") for issue in issues)
    confidence -= error_count * 0.15
    return round(min(max(confidence, 0.0), 1.0), 2)


# ---------------------------------------------------------------------------
# Registry statistics (for admin/debugging)
# ---------------------------------------------------------------------------

def registry_stats() -> dict:
    """Summarize the loaded registry: framework, domain and subcontrol counts."""
    registry = get_registry()
    details = []
    domain_total = 0
    subcontrol_total = 0
    for framework_id, framework in registry.items():
        domains = framework.get("domains", [])
        subcontrol_count = sum(len(d.get("subcontrols", [])) for d in domains)
        domain_total += len(domains)
        subcontrol_total += subcontrol_count
        details.append({
            "framework_id": framework_id,
            "display_name": framework.get("display_name", ""),
            "domains": len(domains),
            "subcontrols": subcontrol_count,
        })
    return {
        "frameworks": len(registry),
        "details": details,
        "total_domains": domain_total,
        "total_subcontrols": subcontrol_total,
    }
"""
License Gate — checks whether a given source may be used for a specific purpose.

Usage types:
    - analysis: Read + analyse internally (TDM under UrhG 44b)
    - store_excerpt: Store verbatim excerpt in vault
    - ship_embeddings: Ship embeddings in product
    - ship_in_product: Ship text/content in product

Policy is driven by the canonical_control_sources table columns:
    allowed_analysis, allowed_store_excerpt, allowed_ship_embeddings, allowed_ship_in_product
"""

from __future__ import annotations

import logging
from typing import Any

from sqlalchemy import text
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)

# usage_type → boolean column on canonical_control_sources. Only values from
# this fixed whitelist are ever interpolated into SQL (see below).
USAGE_COLUMN_MAP = {
    "analysis": "allowed_analysis",
    "store_excerpt": "allowed_store_excerpt",
    "ship_embeddings": "allowed_ship_embeddings",
    "ship_in_product": "allowed_ship_in_product",
}


def check_source_allowed(db: Session, source_id: str, usage_type: str) -> bool:
    """Return True iff *source_id* is registered and permits *usage_type*.

    Unknown usage types and unknown sources both fail closed (False).
    """
    column = USAGE_COLUMN_MAP.get(usage_type)
    if column is None:
        logger.warning("Unknown usage_type=%s", usage_type)
        return False

    # *column* comes from the whitelist above, never from user input, so the
    # f-string interpolation cannot inject SQL; source_id stays a bind param.
    row = db.execute(
        text(f"SELECT {column} FROM canonical_control_sources WHERE source_id = :sid"),
        {"sid": source_id},
    ).fetchone()

    if row is None:
        logger.warning("Source %s not found in registry", source_id)
        return False

    return bool(row[0])


# Row attributes exposed by the license-matrix query, in output order.
_LICENSE_FIELDS = (
    "license_id", "name", "terms_url", "commercial_use",
    "ai_training_restriction", "tdm_allowed_under_44b",
    "deletion_required", "notes",
)


def get_license_matrix(db: Session) -> list[dict[str, Any]]:
    """Return the full license matrix with allowed usages per license."""
    rows = db.execute(
        text("""
            SELECT license_id, name, terms_url, commercial_use,
                   ai_training_restriction, tdm_allowed_under_44b,
                   deletion_required, notes
            FROM canonical_control_licenses
            ORDER BY license_id
        """)
    ).fetchall()

    return [{name: getattr(row, name) for name in _LICENSE_FIELDS} for row in rows]


# Row attributes exposed by the source-permissions query, in output order.
_SOURCE_FIELDS = (
    "source_id", "title", "publisher", "url", "version_label",
    "language", "license_id", "license_name", "commercial_use",
    "allowed_analysis", "allowed_store_excerpt",
    "allowed_ship_embeddings", "allowed_ship_in_product",
    "vault_retention_days", "vault_access_tier",
)


def get_source_permissions(db: Session) -> list[dict[str, Any]]:
    """Return all sources joined with their license and permission flags."""
    rows = db.execute(
        text("""
            SELECT s.source_id, s.title, s.publisher, s.url, s.version_label,
                   s.language, s.license_id,
                   s.allowed_analysis, s.allowed_store_excerpt,
                   s.allowed_ship_embeddings, s.allowed_ship_in_product,
                   s.vault_retention_days, s.vault_access_tier,
                   l.name AS license_name, l.commercial_use
            FROM canonical_control_sources s
            JOIN canonical_control_licenses l ON l.license_id = s.license_id
            ORDER BY s.source_id
        """)
    ).fetchall()

    return [{name: getattr(row, name) for name in _SOURCE_FIELDS} for row in rows]
"""
LLM Provider Abstraction for Compliance AI Features.

Supports:
- Anthropic Claude API (default)
- Self-Hosted LLMs (Ollama, vLLM, LocalAI, etc.)
- HashiCorp Vault integration for secure API key storage

Configuration via environment variables:
- COMPLIANCE_LLM_PROVIDER: "anthropic" or "self_hosted"
- ANTHROPIC_API_KEY: API key for Claude (or loaded from Vault)
- ANTHROPIC_MODEL: Model name (default: claude-sonnet-4-20250514)
- SELF_HOSTED_LLM_URL: Base URL for self-hosted LLM
- SELF_HOSTED_LLM_MODEL: Model name for self-hosted
- SELF_HOSTED_LLM_KEY: Optional API key for self-hosted

Vault Configuration:
- VAULT_ADDR: Vault server address (e.g., http://vault:8200)
- VAULT_TOKEN: Vault authentication token
- USE_VAULT_SECRETS: Set to "true" to enable Vault integration
- VAULT_SECRET_PATH: Path to secrets (default: secret/breakpilot/api_keys)
"""

import os
import asyncio
import logging
import time
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from enum import Enum

import httpx

logger = logging.getLogger(__name__)


# =============================================================================
# Vault Integration
# =============================================================================

class VaultClient:
    """
    HashiCorp Vault client for retrieving secrets.

    Supports the KV v2 secrets engine. Secrets are cached in memory for
    ``_cache_ttl`` seconds (the original implementation declared the TTL but
    never enforced it, so rotated secrets were never picked up).
    """

    def __init__(
        self,
        addr: Optional[str] = None,
        token: Optional[str] = None
    ):
        self.addr = addr or os.getenv("VAULT_ADDR", "http://localhost:8200")
        self.token = token or os.getenv("VAULT_TOKEN")
        # cache maps "path:key" -> (value, monotonic expiry deadline)
        self._cache: Dict[str, Any] = {}
        self._cache_ttl = 300  # 5 minutes cache

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with Vault token."""
        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["X-Vault-Token"] = self.token
        return headers

    def _cache_get(self, cache_key: str) -> Optional[str]:
        """Return a cached value if present and not expired, else None."""
        entry = self._cache.get(cache_key)
        if entry is None:
            return None
        value, deadline = entry
        if time.monotonic() >= deadline:
            # Expired — drop it so the next lookup refetches from Vault.
            del self._cache[cache_key]
            return None
        return value

    def _cache_put(self, cache_key: str, value: str) -> None:
        """Store a value with a TTL deadline (monotonic clock)."""
        self._cache[cache_key] = (value, time.monotonic() + self._cache_ttl)

    def get_secret(self, path: str, key: str = "value") -> Optional[str]:
        """
        Get a secret from Vault KV v2.

        Args:
            path: Secret path (e.g., "breakpilot/api_keys/anthropic")
            key: Key within the secret data (default: "value")

        Returns:
            Secret value or None if not found
        """
        cache_key = f"{path}:{key}"

        # Check cache first (entries expire after self._cache_ttl seconds).
        cached = self._cache_get(cache_key)
        if cached is not None:
            return cached

        try:
            # KV v2 uses /data/ in the path.
            full_path = f"{self.addr}/v1/secret/data/{path}"

            response = httpx.get(
                full_path,
                headers=self._get_headers(),
                timeout=10.0
            )

            if response.status_code == 200:
                data = response.json()
                secret_data = data.get("data", {}).get("data", {})
                secret_value = secret_data.get(key)

                if secret_value:
                    self._cache_put(cache_key, secret_value)
                    logger.info(f"Successfully loaded secret from Vault: {path}")
                    return secret_value

            elif response.status_code == 404:
                logger.warning(f"Secret not found in Vault: {path}")
            else:
                logger.error(f"Vault error {response.status_code}: {response.text}")

        except httpx.RequestError as e:
            logger.error(f"Failed to connect to Vault at {self.addr}: {e}")
        except Exception as e:
            logger.error(f"Error retrieving secret from Vault: {e}")

        return None

    def get_anthropic_key(self) -> Optional[str]:
        """Get Anthropic API key from Vault."""
        path = os.getenv("VAULT_ANTHROPIC_PATH", "breakpilot/api_keys/anthropic")
        return self.get_secret(path, "value")

    def is_available(self) -> bool:
        """Check if Vault is available and authenticated."""
        try:
            response = httpx.get(
                f"{self.addr}/v1/sys/health",
                headers=self._get_headers(),
                timeout=5.0
            )
            # Vault health endpoint uses non-200 codes for sealed/standby
            # states that still indicate a reachable server.
            return response.status_code in (200, 429, 472, 473, 501, 503)
        except Exception:
            return False


# Singleton Vault client
_vault_client: Optional[VaultClient] = None


def get_vault_client() -> VaultClient:
    """Get shared Vault client instance (created lazily)."""
    global _vault_client
    if _vault_client is None:
        _vault_client = VaultClient()
    return _vault_client


def get_secret_from_vault_or_env(
    vault_path: str,
    env_var: str,
    vault_key: str = "value"
) -> Optional[str]:
    """
    Get a secret, trying Vault first, then falling back to environment variable.

    Args:
        vault_path: Path in Vault (e.g., "breakpilot/api_keys/anthropic")
        env_var: Environment variable name as fallback
        vault_key: Key within Vault secret data

    Returns:
        Secret value or None
    """
    use_vault = os.getenv("USE_VAULT_SECRETS", "").lower() in ("true", "1", "yes")

    if use_vault:
        vault = get_vault_client()
        secret = vault.get_secret(vault_path, vault_key)
        if secret:
            return secret
        logger.info(f"Vault secret not found, falling back to env: {env_var}")

    return os.getenv(env_var)
class LLMProviderType(str, Enum):
    """Supported LLM provider types."""
    ANTHROPIC = "anthropic"
    SELF_HOSTED = "self_hosted"
    OLLAMA = "ollama"  # Alias for self_hosted (Ollama-specific)
    MOCK = "mock"      # For testing


@dataclass
class LLMResponse:
    """Standard response envelope returned by every provider."""
    content: str                                 # generated text
    model: str                                   # model that produced it
    provider: str                                # provider_name of the source
    usage: Optional[Dict[str, int]] = None       # token accounting, if any
    raw_response: Optional[Dict[str, Any]] = None  # unparsed backend payload


@dataclass
class LLMConfig:
    """Configuration bundle consumed by LLM provider constructors."""
    provider_type: LLMProviderType
    api_key: Optional[str] = None
    model: str = "claude-sonnet-4-20250514"
    base_url: Optional[str] = None   # required for self-hosted providers
    max_tokens: int = 4096
    temperature: float = 0.3
    timeout: float = 60.0
class LLMProvider(ABC):
    """Abstract base class for LLM providers."""

    def __init__(self, config: LLMConfig):
        self.config = config

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Generate a completion for the given prompt."""

    @abstractmethod
    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Generate completions for multiple prompts with rate limiting."""

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name."""


class AnthropicProvider(LLMProvider):
    """Claude API provider using Anthropic's official Messages API."""

    ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        if not config.api_key:
            raise ValueError("Anthropic API key is required")
        self.api_key = config.api_key
        self.model = config.model or "claude-sonnet-4-20250514"

    @property
    def provider_name(self) -> str:
        return "anthropic"

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Call the Messages API once and wrap the first content block."""
        request_headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens or self.config.max_tokens,
            "messages": [{"role": "user", "content": prompt}]
        }
        if system_prompt:
            payload["system"] = system_prompt

        # Per-call temperature wins over the configured default.
        effective_temp = temperature if temperature is not None else self.config.temperature
        if effective_temp is not None:
            payload["temperature"] = effective_temp

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            try:
                response = await client.post(
                    self.ANTHROPIC_API_URL,
                    headers=request_headers,
                    json=payload
                )
                response.raise_for_status()
                body = response.json()

                blocks = body.get("content")
                text_out = blocks[0].get("text", "") if blocks else ""

                return LLMResponse(
                    content=text_out,
                    model=self.model,
                    provider=self.provider_name,
                    usage=body.get("usage"),
                    raw_response=body
                )

            except httpx.HTTPStatusError as e:
                logger.error(f"Anthropic API error: {e.response.status_code} - {e.response.text}")
                raise
            except Exception as e:
                logger.error(f"Anthropic API request failed: {e}")
                raise

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Run prompts sequentially, sleeping *rate_limit* seconds between calls."""
        responses: List[LLMResponse] = []

        for index, single_prompt in enumerate(prompts):
            if index > 0:
                await asyncio.sleep(rate_limit)

            try:
                responses.append(await self.complete(
                    prompt=single_prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens
                ))
            except Exception as e:
                logger.error(f"Failed to process prompt {index}: {e}")
                # Surface the failure in-band so batch positions stay aligned.
                responses.append(LLMResponse(
                    content=f"Error: {str(e)}",
                    model=self.model,
                    provider=self.provider_name
                ))

        return responses
"ollama": + # Ollama API format + endpoint = f"{self.base_url}/api/generate" + full_prompt = prompt + if system_prompt: + full_prompt = f"{system_prompt}\n\n{prompt}" + + payload = { + "model": self.model, + "prompt": full_prompt, + "stream": False, + "think": False, # Disable thinking mode (qwen3.5 etc.) + "options": {} + } + + if max_tokens: + payload["options"]["num_predict"] = max_tokens + if temperature is not None: + payload["options"]["temperature"] = temperature + + else: + # OpenAI-compatible format (vLLM, LocalAI, etc.) + endpoint = f"{self.base_url}/v1/chat/completions" + + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": max_tokens or self.config.max_tokens, + "temperature": temperature if temperature is not None else self.config.temperature + } + + async with httpx.AsyncClient(timeout=self.config.timeout) as client: + try: + response = await client.post(endpoint, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + # Parse response based on format + if api_format == "ollama": + content = data.get("response", "") + else: + # OpenAI format + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + + return LLMResponse( + content=content, + model=self.model, + provider=self.provider_name, + usage=data.get("usage"), + raw_response=data + ) + + except httpx.HTTPStatusError as e: + logger.error(f"Self-hosted LLM error: {e.response.status_code} - {e.response.text}") + raise + except Exception as e: + logger.error(f"Self-hosted LLM request failed: {e}") + raise + + async def batch_complete( + self, + prompts: List[str], + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + rate_limit: float = 0.5 # Self-hosted can be faster + ) -> List[LLMResponse]: + """Process multiple prompts with rate 
limiting.""" + results = [] + + for i, prompt in enumerate(prompts): + if i > 0: + await asyncio.sleep(rate_limit) + + try: + result = await self.complete( + prompt=prompt, + system_prompt=system_prompt, + max_tokens=max_tokens + ) + results.append(result) + except Exception as e: + logger.error(f"Failed to process prompt {i}: {e}") + results.append(LLMResponse( + content=f"Error: {str(e)}", + model=self.model, + provider=self.provider_name + )) + + return results + + +class MockProvider(LLMProvider): + """Mock provider for testing without actual API calls.""" + + def __init__(self, config: LLMConfig): + super().__init__(config) + self.responses: List[str] = [] + self.call_count = 0 + + @property + def provider_name(self) -> str: + return "mock" + + def set_responses(self, responses: List[str]): + """Set predetermined responses for testing.""" + self.responses = responses + self.call_count = 0 + + async def complete( + self, + prompt: str, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None + ) -> LLMResponse: + """Return mock response.""" + if self.responses: + content = self.responses[self.call_count % len(self.responses)] + else: + content = f"Mock response for: {prompt[:50]}..." + + self.call_count += 1 + + return LLMResponse( + content=content, + model="mock-model", + provider=self.provider_name, + usage={"input_tokens": len(prompt), "output_tokens": len(content)} + ) + + async def batch_complete( + self, + prompts: List[str], + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + rate_limit: float = 0.0 + ) -> List[LLMResponse]: + """Return mock responses for batch.""" + return [await self.complete(p, system_prompt, max_tokens) for p in prompts] + + +def get_llm_config() -> LLMConfig: + """ + Create LLM config from environment variables or Vault. + + Priority for API key: + 1. Vault (if USE_VAULT_SECRETS=true and Vault is available) + 2. 
Environment variable (ANTHROPIC_API_KEY) + """ + provider_type_str = os.getenv("COMPLIANCE_LLM_PROVIDER", "anthropic") + + try: + provider_type = LLMProviderType(provider_type_str) + except ValueError: + logger.warning(f"Unknown LLM provider: {provider_type_str}, falling back to mock") + provider_type = LLMProviderType.MOCK + + # Get API key from Vault or environment + api_key = None + if provider_type == LLMProviderType.ANTHROPIC: + api_key = get_secret_from_vault_or_env( + vault_path="breakpilot/api_keys/anthropic", + env_var="ANTHROPIC_API_KEY" + ) + elif provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): + api_key = get_secret_from_vault_or_env( + vault_path="breakpilot/api_keys/self_hosted_llm", + env_var="SELF_HOSTED_LLM_KEY" + ) + + # Select model based on provider type + if provider_type == LLMProviderType.ANTHROPIC: + model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514") + elif provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): + model = os.getenv("SELF_HOSTED_LLM_MODEL", "qwen2.5:14b") + else: + model = "mock-model" + + return LLMConfig( + provider_type=provider_type, + api_key=api_key, + model=model, + base_url=os.getenv("SELF_HOSTED_LLM_URL"), + max_tokens=int(os.getenv("COMPLIANCE_LLM_MAX_TOKENS", "4096")), + temperature=float(os.getenv("COMPLIANCE_LLM_TEMPERATURE", "0.3")), + timeout=float(os.getenv("COMPLIANCE_LLM_TIMEOUT", "60.0")) + ) + + +def get_llm_provider(config: Optional[LLMConfig] = None) -> LLMProvider: + """ + Factory function to get the appropriate LLM provider based on configuration. 
+ + Usage: + provider = get_llm_provider() + response = await provider.complete("Analyze this requirement...") + """ + if config is None: + config = get_llm_config() + + if config.provider_type == LLMProviderType.ANTHROPIC: + if not config.api_key: + logger.warning("No Anthropic API key found, using mock provider") + return MockProvider(config) + return AnthropicProvider(config) + + elif config.provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA): + if not config.base_url: + logger.warning("No self-hosted LLM URL found, using mock provider") + return MockProvider(config) + return SelfHostedProvider(config) + + elif config.provider_type == LLMProviderType.MOCK: + return MockProvider(config) + + else: + raise ValueError(f"Unsupported LLM provider type: {config.provider_type}") + + +# Singleton instance for reuse +_provider_instance: Optional[LLMProvider] = None + + +def get_shared_provider() -> LLMProvider: + """Get a shared LLM provider instance.""" + global _provider_instance + if _provider_instance is None: + _provider_instance = get_llm_provider() + return _provider_instance + + +def reset_shared_provider(): + """Reset the shared provider instance (useful for testing).""" + global _provider_instance + _provider_instance = None diff --git a/control-pipeline/services/normative_patterns.py b/control-pipeline/services/normative_patterns.py new file mode 100644 index 0000000..5adb895 --- /dev/null +++ b/control-pipeline/services/normative_patterns.py @@ -0,0 +1,59 @@ +"""Shared normative language patterns for assertion classification. + +Extracted from decomposition_pass.py for reuse in the assertion engine. 
+""" + +import re + +_PFLICHT_SIGNALS = [ + r"\bmüssen\b", r"\bmuss\b", r"\bhat\s+sicherzustellen\b", + r"\bhaben\s+sicherzustellen\b", r"\bsind\s+verpflichtet\b", + r"\bist\s+verpflichtet\b", + r"\bist\s+zu\s+\w+en\b", r"\bsind\s+zu\s+\w+en\b", + r"\bhat\s+zu\s+\w+en\b", r"\bhaben\s+zu\s+\w+en\b", + r"\bist\s+\w+zu\w+en\b", r"\bsind\s+\w+zu\w+en\b", + r"\bist\s+\w+\s+zu\s+\w+en\b", r"\bsind\s+\w+\s+zu\s+\w+en\b", + r"\bhat\s+\w+\s+zu\s+\w+en\b", r"\bhaben\s+\w+\s+zu\s+\w+en\b", + r"\bshall\b", r"\bmust\b", r"\brequired\b", + r"\b\w+zuteilen\b", r"\b\w+zuwenden\b", r"\b\w+zustellen\b", r"\b\w+zulegen\b", + r"\b\w+zunehmen\b", r"\b\w+zuführen\b", r"\b\w+zuhalten\b", r"\b\w+zusetzen\b", + r"\b\w+zuweisen\b", r"\b\w+zuordnen\b", r"\b\w+zufügen\b", r"\b\w+zugeben\b", + r"\bist\b.{1,80}\bzu\s+\w+en\b", r"\bsind\b.{1,80}\bzu\s+\w+en\b", +] +PFLICHT_RE = re.compile("|".join(_PFLICHT_SIGNALS), re.IGNORECASE) + +_EMPFEHLUNG_SIGNALS = [ + r"\bsoll\b", r"\bsollen\b", r"\bsollte\b", r"\bsollten\b", + r"\bgewährleisten\b", r"\bsicherstellen\b", + r"\bshould\b", r"\bensure\b", r"\brecommend\w*\b", + r"\bnachweisen\b", r"\beinhalten\b", r"\bunterlassen\b", r"\bwahren\b", + r"\bdokumentieren\b", r"\bimplementieren\b", r"\büberprüfen\b", r"\büberwachen\b", + r"\bprüfen,\s+ob\b", r"\bkontrollieren,\s+ob\b", +] +EMPFEHLUNG_RE = re.compile("|".join(_EMPFEHLUNG_SIGNALS), re.IGNORECASE) + +_KANN_SIGNALS = [ + r"\bkann\b", r"\bkönnen\b", r"\bdarf\b", r"\bdürfen\b", + r"\bmay\b", r"\boptional\b", +] +KANN_RE = re.compile("|".join(_KANN_SIGNALS), re.IGNORECASE) + +NORMATIVE_RE = re.compile( + "|".join(_PFLICHT_SIGNALS + _EMPFEHLUNG_SIGNALS + _KANN_SIGNALS), + re.IGNORECASE, +) + +_RATIONALE_SIGNALS = [ + r"\bda\s+", r"\bweil\b", r"\bgrund\b", r"\berwägung", + r"\bbecause\b", r"\breason\b", r"\brationale\b", + r"\bkönnen\s+.*\s+verursachen\b", r"\bführt\s+zu\b", +] +RATIONALE_RE = re.compile("|".join(_RATIONALE_SIGNALS), re.IGNORECASE) + +# Evidence-related keywords (for fact detection) 
+_EVIDENCE_KEYWORDS = [ + r"\bnachweis\b", r"\bzertifikat\b", r"\baudit.report\b", + r"\bprotokoll\b", r"\bdokumentation\b", r"\bbericht\b", + r"\bcertificate\b", r"\bevidence\b", r"\bproof\b", +] +EVIDENCE_RE = re.compile("|".join(_EVIDENCE_KEYWORDS), re.IGNORECASE) diff --git a/control-pipeline/services/obligation_extractor.py b/control-pipeline/services/obligation_extractor.py new file mode 100644 index 0000000..2eecf71 --- /dev/null +++ b/control-pipeline/services/obligation_extractor.py @@ -0,0 +1,563 @@ +"""Obligation Extractor — 3-Tier Chunk-to-Obligation Linking. + +Maps RAG chunks to obligations from the v2 obligation framework using +three tiers (fastest first): + + Tier 1: EXACT MATCH — regulation_code + article → obligation_id (~40%) + Tier 2: EMBEDDING — chunk text vs. obligation descriptions (~30%) + Tier 3: LLM EXTRACT — local Ollama extracts obligation text (~25%) + +Part of the Multi-Layer Control Architecture (Phase 4 of 8). +""" + +import json +import logging +import os +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import httpx + +logger = logging.getLogger(__name__) + +EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087") +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434") +OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b") +LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180")) + +# Embedding similarity thresholds for Tier 2 +EMBEDDING_MATCH_THRESHOLD = 0.80 +EMBEDDING_CANDIDATE_THRESHOLD = 0.60 + +# --------------------------------------------------------------------------- +# Regulation code mapping: RAG chunk codes → obligation file regulation IDs +# --------------------------------------------------------------------------- + +_REGULATION_CODE_TO_ID = { + # DSGVO + "eu_2016_679": "dsgvo", + "dsgvo": "dsgvo", + "gdpr": "dsgvo", + # AI Act + "eu_2024_1689": "ai_act", + "ai_act": "ai_act", + "aiact": 
"ai_act", + # NIS2 + "eu_2022_2555": "nis2", + "nis2": "nis2", + "bsig": "nis2", + # BDSG + "bdsg": "bdsg", + # TTDSG + "ttdsg": "ttdsg", + # DSA + "eu_2022_2065": "dsa", + "dsa": "dsa", + # Data Act + "eu_2023_2854": "data_act", + "data_act": "data_act", + # EU Machinery + "eu_2023_1230": "eu_machinery", + "eu_machinery": "eu_machinery", + # DORA + "eu_2022_2554": "dora", + "dora": "dora", +} + + +@dataclass +class ObligationMatch: + """Result of obligation extraction.""" + + obligation_id: Optional[str] = None + obligation_title: Optional[str] = None + obligation_text: Optional[str] = None + method: str = "none" # exact_match | embedding_match | llm_extracted | inferred + confidence: float = 0.0 + regulation_id: Optional[str] = None # e.g. "dsgvo" + + def to_dict(self) -> dict: + return { + "obligation_id": self.obligation_id, + "obligation_title": self.obligation_title, + "obligation_text": self.obligation_text, + "method": self.method, + "confidence": self.confidence, + "regulation_id": self.regulation_id, + } + + +@dataclass +class _ObligationEntry: + """Internal representation of a loaded obligation.""" + + id: str + title: str + description: str + regulation_id: str + articles: list[str] = field(default_factory=list) # normalized: ["art. 30", "§ 38"] + embedding: list[float] = field(default_factory=list) + + +class ObligationExtractor: + """3-Tier obligation extraction from RAG chunks. + + Usage:: + + extractor = ObligationExtractor() + await extractor.initialize() # loads obligations + embeddings + + match = await extractor.extract( + chunk_text="...", + regulation_code="eu_2016_679", + article="Art. 30", + paragraph="Abs. 1", + ) + """ + + def __init__(self): + self._article_lookup: dict[str, list[str]] = {} # "dsgvo/art. 
    def __init__(self):
        # All state is populated lazily by initialize():
        #   _article_lookup: "dsgvo/art. 30" → ["DSGVO-OBL-001", ...]
        #   _obligations:    obligation id → _ObligationEntry
        #   _obligation_embeddings / _obligation_ids: parallel lists
        self._article_lookup: dict[str, list[str]] = {}  # "dsgvo/art. 30" → ["DSGVO-OBL-001"]
        self._obligations: dict[str, _ObligationEntry] = {}  # id → entry
        self._obligation_embeddings: list[list[float]] = []
        self._obligation_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load all obligations from v2 JSON files and compute embeddings.

        Idempotent: a second call returns immediately.
        """
        if self._initialized:
            return

        self._load_obligations()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "ObligationExtractor initialized: %d obligations, %d article lookups, %d embeddings",
            len(self._obligations),
            len(self._article_lookup),
            sum(1 for e in self._obligation_embeddings if e),
        )

    async def extract(
        self,
        chunk_text: str,
        regulation_code: str,
        article: Optional[str] = None,
        paragraph: Optional[str] = None,
    ) -> ObligationMatch:
        """Extract obligation from a chunk using 3-tier strategy.

        Tiers are tried fastest-first; the first hit wins. Tier 3 always
        returns a result (possibly with confidence 0.0).

        NOTE(review): *paragraph* is accepted but never used in this method —
        confirm whether it is reserved for future use or should feed Tier 1.
        """
        if not self._initialized:
            await self.initialize()

        reg_id = _normalize_regulation(regulation_code)

        # Tier 1: Exact match via article lookup
        if article:
            match = self._tier1_exact(reg_id, article)
            if match:
                return match

        # Tier 2: Embedding similarity
        match = await self._tier2_embedding(chunk_text, reg_id)
        if match:
            return match

        # Tier 3: LLM extraction
        match = await self._tier3_llm(chunk_text, regulation_code, article)
        return match

    # -----------------------------------------------------------------------
    # Tier 1: Exact Match
    # -----------------------------------------------------------------------

    def _tier1_exact(self, reg_id: Optional[str], article: str) -> Optional[ObligationMatch]:
        """Look up obligation by regulation + article.

        Returns None when the regulation is unknown, the article key is not
        indexed, or the indexed id has no loaded entry.
        """
        if not reg_id:
            return None

        norm_article = _normalize_article(article)
        key = f"{reg_id}/{norm_article}"

        obl_ids = self._article_lookup.get(key)
        if not obl_ids:
            return None

        # Take the first match (highest priority)
        # NOTE(review): "priority" here is load order from the JSON files —
        # confirm the files list obligations in priority order.
        obl_id = obl_ids[0]
        entry = self._obligations.get(obl_id)
        if not entry:
            return None

        return ObligationMatch(
            obligation_id=entry.id,
            obligation_title=entry.title,
            obligation_text=entry.description,
            method="exact_match",
            confidence=1.0,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, chunk_text: str, reg_id: Optional[str]
    ) -> Optional[ObligationMatch]:
        """Find nearest obligation by embedding similarity.

        A +0.05 same-regulation bonus only influences WHICH candidate wins;
        it is subtracted again before comparing against the acceptance
        threshold, so the bonus cannot push a weak match over the line.
        """
        if not self._obligation_embeddings:
            return None

        # Truncate to keep the embedding request bounded.
        chunk_embedding = await _get_embedding(chunk_text[:2000])
        if not chunk_embedding:
            return None

        best_idx = -1
        best_score = 0.0

        for i, obl_emb in enumerate(self._obligation_embeddings):
            if not obl_emb:
                continue
            # Prefer same-regulation matches
            obl_id = self._obligation_ids[i]
            entry = self._obligations.get(obl_id)
            score = _cosine_sim(chunk_embedding, obl_emb)

            # Domain bonus: +0.05 if same regulation
            if entry and reg_id and entry.regulation_id == reg_id:
                score += 0.05

            # Strictly-greater comparison: ties keep the earlier index.
            if score > best_score:
                best_score = score
                best_idx = i

        if best_idx < 0:
            return None

        # Remove domain bonus for threshold comparison
        raw_score = best_score
        obl_id = self._obligation_ids[best_idx]
        entry = self._obligations.get(obl_id)
        if entry and reg_id and entry.regulation_id == reg_id:
            raw_score -= 0.05

        if raw_score >= EMBEDDING_MATCH_THRESHOLD:
            return ObligationMatch(
                obligation_id=entry.id if entry else obl_id,
                obligation_title=entry.title if entry else None,
                obligation_text=entry.description if entry else None,
                method="embedding_match",
                confidence=round(min(raw_score, 1.0), 3),
                regulation_id=entry.regulation_id if entry else reg_id,
            )

        return None
    # -----------------------------------------------------------------------
    # Tier 3: LLM Extraction
    # -----------------------------------------------------------------------

    async def _tier3_llm(
        self, chunk_text: str, regulation_code: str, article: Optional[str]
    ) -> ObligationMatch:
        """Use local LLM to extract the obligation from the chunk.

        Always returns an ObligationMatch: confidence 0.0 when the LLM call
        yields nothing, otherwise a fixed 0.60 for LLM-extracted text.
        The extracted obligation has no obligation_id (it is free text, not
        a catalog entry).
        """
        prompt = f"""Analysiere den folgenden Gesetzestext und extrahiere die zentrale rechtliche Pflicht.

Text:
{chunk_text[:3000]}

Quelle: {regulation_code} {article or ''}

Antworte NUR als JSON:
{{
    "obligation_text": "Die zentrale Pflicht in einem Satz",
    "actor": "Wer muss handeln (z.B. Verantwortlicher, Auftragsverarbeiter)",
    "action": "Was muss getan werden",
    "normative_strength": "muss|soll|kann"
}}"""

        system_prompt = (
            "Du bist ein Rechtsexperte fuer EU-Datenschutz- und Digitalrecht. "
            "Extrahiere die zentrale rechtliche Pflicht aus Gesetzestexten. "
            "Antworte ausschliesslich als JSON."
        )

        result_text = await _llm_ollama(prompt, system_prompt)
        if not result_text:
            return ObligationMatch(
                method="llm_extracted",
                confidence=0.0,
                regulation_id=_normalize_regulation(regulation_code),
            )

        parsed = _parse_json(result_text)
        # Fall back to a raw-text prefix when the JSON lacks the field.
        obligation_text = parsed.get("obligation_text", result_text[:500])

        return ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text=obligation_text,
            method="llm_extracted",
            confidence=0.60,
            regulation_id=_normalize_regulation(regulation_code),
        )

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_obligations(self) -> None:
        """Load all obligation files from v2 framework.

        Populates ``_obligations`` and ``_article_lookup``. Missing
        directory/manifest/regulation files are logged and skipped —
        Tier 1 then simply finds nothing.
        """
        v2_dir = _find_obligations_dir()
        if not v2_dir:
            logger.warning("Obligations v2 directory not found — Tier 1 disabled")
            return

        manifest_path = v2_dir / "_manifest.json"
        if not manifest_path.exists():
            logger.warning("Manifest not found at %s", manifest_path)
            return

        with open(manifest_path) as f:
            manifest = json.load(f)

        for reg_info in manifest.get("regulations", []):
            reg_id = reg_info["id"]
            reg_file = v2_dir / reg_info["file"]
            if not reg_file.exists():
                logger.warning("Regulation file not found: %s", reg_file)
                continue

            with open(reg_file) as f:
                data = json.load(f)

            for obl in data.get("obligations", []):
                obl_id = obl["id"]
                entry = _ObligationEntry(
                    id=obl_id,
                    title=obl.get("title", ""),
                    description=obl.get("description", ""),
                    regulation_id=reg_id,
                )

                # Build article lookup from legal_basis
                for basis in obl.get("legal_basis", []):
                    article_raw = basis.get("article", "")
                    if article_raw:
                        norm_art = _normalize_article(article_raw)
                        key = f"{reg_id}/{norm_art}"
                        if key not in self._article_lookup:
                            self._article_lookup[key] = []
                        self._article_lookup[key].append(obl_id)
                        entry.articles.append(norm_art)

                self._obligations[obl_id] = entry

        logger.info(
            "Loaded %d obligations from %d regulations",
            len(self._obligations),
            len(manifest.get("regulations", [])),
        )

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all obligation descriptions.

        ``_obligation_ids`` and ``_obligation_embeddings`` are parallel
        lists; failed batches leave empty vectors which Tier 2 skips.
        """
        if not self._obligations:
            return

        self._obligation_ids = list(self._obligations.keys())
        texts = [
            f"{self._obligations[oid].title}: {self._obligations[oid].description}"
            for oid in self._obligation_ids
        ]

        logger.info("Computing embeddings for %d obligations...", len(texts))
        self._obligation_embeddings = await _get_embeddings_batch(texts)
        valid = sum(1 for e in self._obligation_embeddings if e)
        logger.info("Got %d/%d valid embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Stats
    # -----------------------------------------------------------------------

    def stats(self) -> dict:
        """Return initialization statistics (counts, regulations, state)."""
        return {
            "total_obligations": len(self._obligations),
            "article_lookups": len(self._article_lookup),
            "embeddings_valid": sum(1 for e in self._obligation_embeddings if e),
            "regulations": list(
                {e.regulation_id for e in self._obligations.values()}
            ),
            "initialized": self._initialized,
        }
len(self._article_lookup), + "embeddings_valid": sum(1 for e in self._obligation_embeddings if e), + "regulations": list( + {e.regulation_id for e in self._obligations.values()} + ), + "initialized": self._initialized, + } + + +# --------------------------------------------------------------------------- +# Module-level helpers (reusable by other modules) +# --------------------------------------------------------------------------- + + +def _normalize_regulation(regulation_code: str) -> Optional[str]: + """Map a RAG regulation_code to obligation framework regulation ID.""" + if not regulation_code: + return None + code = regulation_code.lower().strip() + + # Direct lookup + if code in _REGULATION_CODE_TO_ID: + return _REGULATION_CODE_TO_ID[code] + + # Prefix matching for families + for prefix, reg_id in [ + ("eu_2016_679", "dsgvo"), + ("eu_2024_1689", "ai_act"), + ("eu_2022_2555", "nis2"), + ("eu_2022_2065", "dsa"), + ("eu_2023_2854", "data_act"), + ("eu_2023_1230", "eu_machinery"), + ("eu_2022_2554", "dora"), + ]: + if code.startswith(prefix): + return reg_id + + return None + + +def _normalize_article(article: str) -> str: + """Normalize article references for consistent lookup. + + Examples: + "Art. 30" → "art. 30" + "§ 38 BDSG" → "§ 38" + "Article 10" → "art. 10" + "Art. 30 Abs. 1" → "art. 30" + "Artikel 35" → "art. 35" + """ + if not article: + return "" + s = article.strip() + + # Remove trailing law name: "§ 38 BDSG" → "§ 38" + s = re.sub(r"\s+(DSGVO|BDSG|TTDSG|DSA|NIS2|DORA|AI.?Act)\s*$", "", s, flags=re.IGNORECASE) + + # Remove paragraph references: "Art. 30 Abs. 1" → "Art. 30" + s = re.sub(r"\s+(Abs|Absatz|para|paragraph|lit|Satz)\.?\s+.*$", "", s, flags=re.IGNORECASE) + + # Normalize "Article" / "Artikel" → "Art." + s = re.sub(r"^(Article|Artikel)\s+", "Art. 
", s, flags=re.IGNORECASE) + + return s.lower().strip() + + +def _cosine_sim(a: list[float], b: list[float]) -> float: + """Compute cosine similarity between two vectors.""" + if not a or not b or len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +def _find_obligations_dir() -> Optional[Path]: + """Locate the obligations v2 directory.""" + candidates = [ + Path(__file__).resolve().parent.parent.parent.parent + / "ai-compliance-sdk" / "policies" / "obligations" / "v2", + Path("/app/ai-compliance-sdk/policies/obligations/v2"), + Path("ai-compliance-sdk/policies/obligations/v2"), + ] + for p in candidates: + if p.is_dir() and (p / "_manifest.json").exists(): + return p + return None + + +async def _get_embedding(text: str) -> list[float]: + """Get embedding vector for a single text.""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{EMBEDDING_URL}/embed", + json={"texts": [text]}, + ) + resp.raise_for_status() + embeddings = resp.json().get("embeddings", []) + return embeddings[0] if embeddings else [] + except Exception: + return [] + + +async def _get_embeddings_batch( + texts: list[str], batch_size: int = 32 +) -> list[list[float]]: + """Get embeddings for multiple texts in batches.""" + all_embeddings: list[list[float]] = [] + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + f"{EMBEDDING_URL}/embed", + json={"texts": batch}, + ) + resp.raise_for_status() + embeddings = resp.json().get("embeddings", []) + all_embeddings.extend(embeddings) + except Exception as e: + logger.warning("Batch embedding failed for %d texts: %s", len(batch), e) + all_embeddings.extend([[] for _ in batch]) + return all_embeddings + + +async 
def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str: + """Call local Ollama for LLM extraction.""" + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + payload = { + "model": OLLAMA_MODEL, + "messages": messages, + "stream": False, + "format": "json", + "options": {"num_predict": 512}, + "think": False, + } + + try: + async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client: + resp = await client.post(f"{OLLAMA_URL}/api/chat", json=payload) + if resp.status_code != 200: + logger.error( + "Ollama chat failed %d: %s", resp.status_code, resp.text[:300] + ) + return "" + data = resp.json() + return data.get("message", {}).get("content", "") + except Exception as e: + logger.warning("Ollama call failed: %s", e) + return "" + + +def _parse_json(text: str) -> dict: + """Extract JSON from LLM response text.""" + # Try direct parse + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Try extracting JSON block + match = re.search(r"\{[^{}]*\}", text, re.DOTALL) + if match: + try: + return json.loads(match.group()) + except json.JSONDecodeError: + pass + + return {} diff --git a/control-pipeline/services/pattern_matcher.py b/control-pipeline/services/pattern_matcher.py new file mode 100644 index 0000000..c302d35 --- /dev/null +++ b/control-pipeline/services/pattern_matcher.py @@ -0,0 +1,532 @@ +"""Pattern Matcher — Obligation-to-Control-Pattern Linking. + +Maps obligations (from the ObligationExtractor) to control patterns +using two tiers: + + Tier 1: KEYWORD MATCH — obligation_match_keywords from patterns (~70%) + Tier 2: EMBEDDING — cosine similarity with domain bonus (~25%) + +Part of the Multi-Layer Control Architecture (Phase 5 of 8). 
@dataclass
class ControlPattern:
    """Python representation of a control pattern loaded from YAML.

    Field order is part of the construction contract (positional use);
    list-valued fields default to fresh empty lists per instance.
    """

    # Identity & classification
    id: str
    name: str
    name_de: str
    domain: str
    category: str
    # Descriptive templates used by the control composer
    description: str
    objective_template: str
    rationale_template: str
    requirements_template: list[str] = field(default_factory=list)
    test_procedure_template: list[str] = field(default_factory=list)
    evidence_template: list[str] = field(default_factory=list)
    # Defaults applied when the generated control does not override them
    severity_default: str = "medium"
    implementation_effort_default: str = "m"
    # Matching & composition metadata
    obligation_match_keywords: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    composable_with: list[str] = field(default_factory=list)
    open_anchor_refs: list[dict] = field(default_factory=list)


@dataclass
class PatternMatchResult:
    """Result of matching one obligation against the pattern catalog."""

    pattern: Optional[ControlPattern] = None
    pattern_id: Optional[str] = None
    method: str = "none"  # keyword | embedding | combined | none
    confidence: float = 0.0
    keyword_hits: int = 0
    total_keywords: int = 0
    embedding_score: float = 0.0
    domain_bonus_applied: bool = False
    composable_patterns: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Serialize for API/logging; scores rounded to 3 decimals."""
        payload = dict(
            pattern_id=self.pattern_id,
            method=self.method,
            confidence=round(self.confidence, 3),
            keyword_hits=self.keyword_hits,
            total_keywords=self.total_keywords,
            embedding_score=round(self.embedding_score, 3),
            domain_bonus_applied=self.domain_bonus_applied,
            composable_patterns=self.composable_patterns,
        )
        return payload
    async def match_top_n(
        self,
        obligation_text: str,
        regulation_id: Optional[str] = None,
        n: int = 3,
    ) -> list[PatternMatchResult]:
        """Return top-N pattern matches sorted by confidence descending.

        Keyword and embedding scores are merged per pattern: when both are
        positive, the max of the two plus a 0.05 "combined" boost is used
        (capped at 1.0).

        NOTE(review): unlike _tier1_keyword, the keyword confidence here
        carries no DOMAIN_BONUS — confirm that asymmetry is intended.
        """
        if not self._initialized:
            await self.initialize()

        if not obligation_text or not self._patterns:
            return []

        keyword_scores = self._keyword_scores(obligation_text, regulation_id)
        embedding_scores = await self._embedding_scores(obligation_text, regulation_id)

        # Merge scores
        all_pattern_ids = set(keyword_scores.keys()) | set(embedding_scores.keys())
        results: list[PatternMatchResult] = []

        for pid in all_pattern_ids:
            pattern = self._by_id.get(pid)
            if not pattern:
                continue

            kw_score = keyword_scores.get(pid, (0, 0, 0.0))  # (hits, total, score)
            emb_score = embedding_scores.get(pid, (0.0, False))  # (score, bonus_applied)

            kw_hits, kw_total, kw_confidence = kw_score
            emb_confidence, bonus_applied = emb_score

            # Combined confidence: max of keyword and embedding, with boost if both
            if kw_confidence > 0 and emb_confidence > 0:
                combined = max(kw_confidence, emb_confidence) + 0.05
                method = "combined"
            elif kw_confidence > 0:
                combined = kw_confidence
                method = "keyword"
            else:
                combined = emb_confidence
                method = "embedding"

            results.append(PatternMatchResult(
                pattern=pattern,
                pattern_id=pid,
                method=method,
                confidence=min(combined, 1.0),
                keyword_hits=kw_hits,
                total_keywords=kw_total,
                embedding_score=emb_confidence,
                domain_bonus_applied=bonus_applied,
                composable_patterns=[
                    p for p in pattern.composable_with if p in self._by_id
                ],
            ))

        # Sort by confidence descending
        results.sort(key=lambda r: r.confidence, reverse=True)
        return results[:n]

    # -----------------------------------------------------------------------
    # Tier 1: Keyword Match
    # -----------------------------------------------------------------------

    def _tier1_keyword(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> Optional[PatternMatchResult]:
        """Match by counting keyword hits in the obligation text.

        Returns None when no pattern reaches KEYWORD_MATCH_MIN_HITS hits.
        A regulation→domain affinity adds DOMAIN_BONUS to the confidence
        (capped at 1.0) but does not change which pattern is selected.
        """
        scores = self._keyword_scores(obligation_text, regulation_id)
        if not scores:
            return None

        # Find best match (highest keyword-ratio confidence).
        best_pid = max(scores, key=lambda pid: scores[pid][2])
        hits, total, confidence = scores[best_pid]

        if hits < KEYWORD_MATCH_MIN_HITS:
            return None

        pattern = self._by_id.get(best_pid)
        if not pattern:
            return None

        # Check domain bonus
        # NOTE(review): _domain_matches is defined outside this chunk —
        # presumably consults _REGULATION_DOMAIN_AFFINITY; verify.
        bonus_applied = False
        if regulation_id and self._domain_matches(pattern.domain, regulation_id):
            confidence = min(confidence + DOMAIN_BONUS, 1.0)
            bonus_applied = True

        return PatternMatchResult(
            pattern=pattern,
            pattern_id=best_pid,
            method="keyword",
            confidence=confidence,
            keyword_hits=hits,
            total_keywords=total,
            domain_bonus_applied=bonus_applied,
        )
+ """ + text_lower = text.lower() + hits_by_pattern: dict[str, int] = {} + + for keyword, pattern_ids in self._keyword_index.items(): + if keyword in text_lower: + for pid in pattern_ids: + hits_by_pattern[pid] = hits_by_pattern.get(pid, 0) + 1 + + result: dict[str, tuple[int, int, float]] = {} + for pid, hits in hits_by_pattern.items(): + pattern = self._by_id.get(pid) + if not pattern: + continue + total = len(pattern.obligation_match_keywords) + confidence = hits / total if total > 0 else 0.0 + result[pid] = (hits, total, confidence) + + return result + + # ----------------------------------------------------------------------- + # Tier 2: Embedding Match + # ----------------------------------------------------------------------- + + async def _tier2_embedding( + self, obligation_text: str, regulation_id: Optional[str] + ) -> Optional[PatternMatchResult]: + """Match by embedding similarity against pattern objective_templates.""" + scores = await self._embedding_scores(obligation_text, regulation_id) + if not scores: + return None + + best_pid = max(scores, key=lambda pid: scores[pid][0]) + emb_score, bonus_applied = scores[best_pid] + + if emb_score < EMBEDDING_PATTERN_THRESHOLD: + return None + + pattern = self._by_id.get(best_pid) + if not pattern: + return None + + return PatternMatchResult( + pattern=pattern, + pattern_id=best_pid, + method="embedding", + confidence=min(emb_score, 1.0), + embedding_score=emb_score, + domain_bonus_applied=bonus_applied, + ) + + async def _embedding_scores( + self, obligation_text: str, regulation_id: Optional[str] + ) -> dict[str, tuple[float, bool]]: + """Compute embedding similarity scores for all patterns. + + Returns dict: pattern_id → (score, domain_bonus_applied). 
+ """ + if not self._pattern_embeddings: + return {} + + chunk_embedding = await _get_embedding(obligation_text[:2000]) + if not chunk_embedding: + return {} + + result: dict[str, tuple[float, bool]] = {} + for i, pat_emb in enumerate(self._pattern_embeddings): + if not pat_emb: + continue + pid = self._pattern_ids[i] + pattern = self._by_id.get(pid) + if not pattern: + continue + + score = _cosine_sim(chunk_embedding, pat_emb) + + # Domain bonus + bonus_applied = False + if regulation_id and self._domain_matches(pattern.domain, regulation_id): + score += DOMAIN_BONUS + bonus_applied = True + + result[pid] = (score, bonus_applied) + + return result + + # ----------------------------------------------------------------------- + # Score combination + # ----------------------------------------------------------------------- + + def _combine_results( + self, + keyword_result: Optional[PatternMatchResult], + embedding_result: Optional[PatternMatchResult], + ) -> PatternMatchResult: + """Combine keyword and embedding results into the best match.""" + if not keyword_result and not embedding_result: + return PatternMatchResult() + + if not keyword_result: + return embedding_result + if not embedding_result: + return keyword_result + + # Both matched — check if they agree + if keyword_result.pattern_id == embedding_result.pattern_id: + # Same pattern: boost confidence + combined_confidence = min( + max(keyword_result.confidence, embedding_result.confidence) + 0.05, + 1.0, + ) + return PatternMatchResult( + pattern=keyword_result.pattern, + pattern_id=keyword_result.pattern_id, + method="combined", + confidence=combined_confidence, + keyword_hits=keyword_result.keyword_hits, + total_keywords=keyword_result.total_keywords, + embedding_score=embedding_result.embedding_score, + domain_bonus_applied=( + keyword_result.domain_bonus_applied + or embedding_result.domain_bonus_applied + ), + ) + + # Different patterns: pick the one with higher confidence + if 
keyword_result.confidence >= embedding_result.confidence: + return keyword_result + return embedding_result + + # ----------------------------------------------------------------------- + # Domain affinity + # ----------------------------------------------------------------------- + + @staticmethod + def _domain_matches(pattern_domain: str, regulation_id: str) -> bool: + """Check if a pattern's domain has affinity with a regulation.""" + affine_domains = _REGULATION_DOMAIN_AFFINITY.get(regulation_id, []) + return pattern_domain in affine_domains + + # ----------------------------------------------------------------------- + # Initialization helpers + # ----------------------------------------------------------------------- + + def _load_patterns(self) -> None: + """Load control patterns from YAML files.""" + patterns_dir = _find_patterns_dir() + if not patterns_dir: + logger.warning("Control patterns directory not found") + return + + for yaml_file in sorted(patterns_dir.glob("*.yaml")): + if yaml_file.name.startswith("_"): + continue + try: + with open(yaml_file) as f: + data = yaml.safe_load(f) + if not data or "patterns" not in data: + continue + for p in data["patterns"]: + pattern = ControlPattern( + id=p["id"], + name=p["name"], + name_de=p["name_de"], + domain=p["domain"], + category=p["category"], + description=p["description"], + objective_template=p["objective_template"], + rationale_template=p["rationale_template"], + requirements_template=p.get("requirements_template", []), + test_procedure_template=p.get("test_procedure_template", []), + evidence_template=p.get("evidence_template", []), + severity_default=p.get("severity_default", "medium"), + implementation_effort_default=p.get("implementation_effort_default", "m"), + obligation_match_keywords=p.get("obligation_match_keywords", []), + tags=p.get("tags", []), + composable_with=p.get("composable_with", []), + open_anchor_refs=p.get("open_anchor_refs", []), + ) + self._patterns.append(pattern) + 
self._by_id[pattern.id] = pattern + domain_list = self._by_domain.setdefault(pattern.domain, []) + domain_list.append(pattern) + except Exception as e: + logger.error("Failed to load %s: %s", yaml_file.name, e) + + logger.info("Loaded %d patterns from %s", len(self._patterns), patterns_dir) + + def _build_keyword_index(self) -> None: + """Build reverse index: keyword → [pattern_ids].""" + for pattern in self._patterns: + for kw in pattern.obligation_match_keywords: + lower_kw = kw.lower() + if lower_kw not in self._keyword_index: + self._keyword_index[lower_kw] = [] + self._keyword_index[lower_kw].append(pattern.id) + + async def _compute_embeddings(self) -> None: + """Compute embeddings for all pattern objective templates.""" + if not self._patterns: + return + + self._pattern_ids = [p.id for p in self._patterns] + texts = [ + f"{p.name_de}: {p.objective_template}" + for p in self._patterns + ] + + logger.info("Computing embeddings for %d patterns...", len(texts)) + self._pattern_embeddings = await _get_embeddings_batch(texts) + valid = sum(1 for e in self._pattern_embeddings if e) + logger.info("Got %d/%d valid pattern embeddings", valid, len(texts)) + + # ----------------------------------------------------------------------- + # Public helpers + # ----------------------------------------------------------------------- + + def get_pattern(self, pattern_id: str) -> Optional[ControlPattern]: + """Get a pattern by its ID.""" + return self._by_id.get(pattern_id.upper()) + + def get_patterns_by_domain(self, domain: str) -> list[ControlPattern]: + """Get all patterns for a domain.""" + return self._by_domain.get(domain.upper(), []) + + def stats(self) -> dict: + """Return matcher statistics.""" + return { + "total_patterns": len(self._patterns), + "domains": list(self._by_domain.keys()), + "keywords": len(self._keyword_index), + "embeddings_valid": sum(1 for e in self._pattern_embeddings if e), + "initialized": self._initialized, + } + + +def _find_patterns_dir() -> 
Optional[Path]: + """Locate the control_patterns directory.""" + candidates = [ + Path(__file__).resolve().parent.parent.parent.parent + / "ai-compliance-sdk" / "policies" / "control_patterns", + Path("/app/ai-compliance-sdk/policies/control_patterns"), + Path("ai-compliance-sdk/policies/control_patterns"), + ] + for p in candidates: + if p.is_dir(): + return p + return None diff --git a/control-pipeline/services/pipeline_adapter.py b/control-pipeline/services/pipeline_adapter.py new file mode 100644 index 0000000..ceb8c04 --- /dev/null +++ b/control-pipeline/services/pipeline_adapter.py @@ -0,0 +1,670 @@ +"""Pipeline Adapter — New 10-Stage Pipeline Integration. + +Bridges the existing 7-stage control_generator pipeline with the new +multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer). + +New pipeline flow: + chunk → license_classify + → obligation_extract (Stage 4 — NEW) + → pattern_match (Stage 5 — NEW) + → control_compose (Stage 6 — replaces old Stage 3) + → harmonize → anchor → store + crosswalk → mark processed + +Can be used in two modes: + 1. INLINE: Called from _process_batch() to enrich the pipeline + 2. STANDALONE: Process chunks directly through new stages + +Part of the Multi-Layer Control Architecture (Phase 7 of 8). 
+""" + +import hashlib +import json +import logging +from dataclasses import dataclass, field +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.orm import Session + +from services.control_composer import ComposedControl, ControlComposer +from services.obligation_extractor import ObligationExtractor, ObligationMatch +from services.pattern_matcher import PatternMatcher, PatternMatchResult + +logger = logging.getLogger(__name__) + + +@dataclass +class PipelineChunk: + """Input chunk for the new pipeline stages.""" + + text: str + collection: str = "" + regulation_code: str = "" + article: Optional[str] = None + paragraph: Optional[str] = None + license_rule: int = 3 + license_info: dict = field(default_factory=dict) + source_citation: Optional[dict] = None + chunk_hash: str = "" + + def compute_hash(self) -> str: + if not self.chunk_hash: + self.chunk_hash = hashlib.sha256(self.text.encode()).hexdigest() + return self.chunk_hash + + +@dataclass +class PipelineResult: + """Result of processing a chunk through the new pipeline.""" + + chunk: PipelineChunk + obligation: ObligationMatch = field(default_factory=ObligationMatch) + pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult) + control: Optional[ComposedControl] = None + crosswalk_written: bool = False + error: Optional[str] = None + + def to_dict(self) -> dict: + return { + "chunk_hash": self.chunk.chunk_hash, + "obligation": self.obligation.to_dict() if self.obligation else None, + "pattern": self.pattern_result.to_dict() if self.pattern_result else None, + "control": self.control.to_dict() if self.control else None, + "crosswalk_written": self.crosswalk_written, + "error": self.error, + } + + +class PipelineAdapter: + """Integrates ObligationExtractor + PatternMatcher + ControlComposer. 
class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()

        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        self.db = db  # optional; only required for write_crosswalk()
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize all sub-components (idempotent)."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Run a single chunk through the new 3-stage pipeline.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Exceptions never propagate; they are recorded on ``result.error``.
        """
        if not self._initialized:
            await self.initialize()

        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)

        try:
            # Stage 4: extract the obligation from the raw chunk text.
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )

            # Stage 5: match against the pattern library; fall back to the
            # first 500 characters of the chunk if no obligation text exists.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )

            # Stage 6: compose the control. Raw chunk text is only handed to
            # the composer under license rules 1 and 2.
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )

        except Exception as exc:
            logger.error("Pipeline processing failed: %s", exc)
            result.error = str(exc)

        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process multiple chunks sequentially through the pipeline."""
        return [await self.process_chunk(chunk) for chunk in chunks]

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extractions + crosswalk_matrix rows for a chunk.

        Must be called AFTER the control is stored in canonical_controls.
        Commits on success; rolls back and returns False on any error.
        """
        if not self.db or not result.control:
            return False

        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result

        try:
            # 1. Record the obligation extraction for this chunk.
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation.confidence if obligation else 0,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "pattern_match_score": pattern.confidence if pattern else 0,
                    "control_uuid": control_uuid,
                },
            )

            # 2. Record the regulation -> master-control crosswalk row.
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "master_control_id": result.control.control_id,
                    "master_control_uuid": control_uuid,
                    # Crosswalk confidence is the weaker of the two stages.
                    "confidence": min(
                        obligation.confidence if obligation else 0,
                        pattern.confidence if pattern else 0,
                    ),
                    "source": "auto",
                },
            )

            # 3. Backfill pattern_id / obligation_ids onto the stored control.
            if result.control.pattern_id or result.control.obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.control.pattern_id,
                        "obligation_ids": json.dumps(result.control.obligation_ids),
                        "control_uuid": control_uuid,
                    },
                )

            self.db.commit()
            result.crosswalk_written = True
            return True

        except Exception as exc:
            logger.error("Failed to write crosswalk: %s", exc)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }
class MigrationPasses:
    """Non-destructive migration passes for existing controls.

    Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
    Pass 2: Pattern Classification (keyword-based matching)
    Pass 3: Quality Triage (categorize by linkage completeness)
    Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
    Pass 5: Deduplication (mark duplicate controls)

    Usage::

        migration = MigrationPasses(db)
        await migration.initialize()

        result = await migration.run_pass1_obligation_linkage(limit=100)
        result = await migration.run_pass2_pattern_classification(limit=100)
        result = migration.run_pass3_quality_triage()
        result = migration.run_pass4_crosswalk_backfill()
        result = migration.run_pass5_deduplication()
    """

    def __init__(self, db: Session):
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._initialized = False

    async def initialize(self) -> None:
        """Load obligations + patterns (idempotent).

        Embeddings are deliberately NOT computed: the passes below use only
        the deterministic / keyword tiers.
        """
        if self._initialized:
            return
        self._extractor._load_obligations()
        self._matcher._load_patterns()
        self._matcher._build_keyword_index()
        self._initialized = True

    # -------------------------------------------------------------------
    # Pass 1: Obligation Linkage (deterministic)
    # -------------------------------------------------------------------

    async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
        """Link existing controls to obligations via source_citation article.

        For each control with a source citation: extract regulation +
        article, look them up in the obligation framework (exact tier-1
        match only), and set ``obligation_ids``.

        Args:
            limit: Maximum number of controls to process; 0 means no limit.

        Returns:
            Counters: total / linked / no_match / no_citation.
        """
        if not self._initialized:
            await self.initialize()

        query = """
            SELECT id, control_id, source_citation, generation_metadata
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
        """
        if limit > 0:
            # int() hardens the f-string interpolation: LIMIT cannot be
            # parameterized here, so force the value to be an integer.
            query += f" LIMIT {int(limit)}"

        rows = self.db.execute(text(query)).fetchall()

        stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}

        for row in rows:
            control_uuid = str(row[0])
            citation = row[2]
            metadata = row[3]

            # Extract regulation + article from citation or metadata.
            reg_code, article = _extract_regulation_article(citation, metadata)
            if not reg_code:
                stats["no_citation"] += 1
                continue

            # Tier 1: exact regulation+article lookup.
            match = self._extractor._tier1_exact(reg_code, article or "")
            if match and match.obligation_id:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET obligation_ids = :obl_ids
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "obl_ids": json.dumps([match.obligation_id]),
                        "uuid": control_uuid,
                    },
                )
                stats["linked"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 1: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 2: Pattern Classification (keyword-based)
    # -------------------------------------------------------------------

    async def run_pass2_pattern_classification(self, limit: int = 0) -> dict:
        """Classify existing controls into patterns via keyword matching.

        For each control without a pattern_id: keyword-match title +
        objective against the pattern library and assign the best match
        (requires at least 2 keyword hits).

        Args:
            limit: Maximum number of controls to process; 0 means no limit.
        """
        if not self._initialized:
            await self.initialize()

        query = """
            SELECT id, control_id, title, objective
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (pattern_id IS NULL OR pattern_id = '')
        """
        if limit > 0:
            query += f" LIMIT {int(limit)}"

        rows = self.db.execute(text(query)).fetchall()

        stats = {"total": len(rows), "classified": 0, "no_match": 0}

        for row in rows:
            control_uuid = str(row[0])
            title = row[2] or ""
            objective = row[3] or ""

            # Keyword match over the concatenated title + objective.
            result = self._matcher._tier1_keyword(f"{title} {objective}", None)

            if result and result.pattern_id and result.keyword_hits >= 2:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = :pattern_id
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.pattern_id,
                        "uuid": control_uuid,
                    },
                )
                stats["classified"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 2: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 3: Quality Triage
    # -------------------------------------------------------------------

    def run_pass3_quality_triage(self) -> dict:
        """Categorize controls by linkage completeness.

        Sets generation_metadata.triage_status:
        - "review": has both obligation_id + pattern_id
        - "needs_obligation": has pattern_id but no obligation_id
        - "needs_pattern": has obligation_id but no pattern_id
        - "legacy_unlinked": has neither

        Returns:
            Mapping category -> number of rows updated.
        """
        categories = {
            "review": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"review"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_obligation": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_obligation"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_pattern": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_pattern"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
            "legacy_unlinked": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"legacy_unlinked"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
        }

        stats = {}
        for category, sql in categories.items():
            result = self.db.execute(text(sql))
            stats[category] = result.rowcount

        self.db.commit()
        logger.info("Pass 3: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 4: Crosswalk Backfill
    # -------------------------------------------------------------------

    def run_pass4_crosswalk_backfill(self) -> dict:
        """Create crosswalk_matrix rows for controls with obligation + pattern.

        Idempotent: the NOT EXISTS guard skips rows already present.
        """
        result = self.db.execute(text("""
            INSERT INTO crosswalk_matrix (
                regulation_code, obligation_id, pattern_id,
                master_control_id, master_control_uuid,
                confidence, source
            )
            SELECT
                COALESCE(
                    (generation_metadata::jsonb->>'source_regulation'),
                    ''
                ) AS regulation_code,
                obl.value::text AS obligation_id,
                cc.pattern_id,
                cc.control_id,
                cc.id,
                0.80,
                'migrated'
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
              AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
              AND NOT EXISTS (
                  SELECT 1 FROM crosswalk_matrix cw
                  WHERE cw.master_control_uuid = cc.id
                    AND cw.obligation_id = obl.value::text
              )
        """))

        rows_inserted = result.rowcount
        self.db.commit()
        logger.info("Pass 4: %d crosswalk rows inserted", rows_inserted)
        return {"rows_inserted": rows_inserted}

    # -------------------------------------------------------------------
    # Pass 5: Deduplication
    # -------------------------------------------------------------------

    def run_pass5_deduplication(self) -> dict:
        """Mark duplicate controls (same obligation + same pattern).

        Groups controls by (obligation_id, pattern_id); the first id of each
        group (highest evidence_confidence, then newest) is kept, the rest
        are marked deprecated with a machine-readable reason.
        """
        groups = self.db.execute(text("""
            SELECT cc.pattern_id,
                   obl.value::text AS obligation_id,
                   array_agg(cc.id ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC) AS ids,
                   count(*) AS cnt
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
            GROUP BY cc.pattern_id, obl.value::text
            HAVING count(*) > 1
        """)).fetchall()

        stats = {"groups_found": len(groups), "controls_deprecated": 0}

        for group in groups:
            ids = group[2]  # array of UUIDs; first element is the keeper
            if len(ids) <= 1:
                continue

            for dep_id in ids[1:]:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET release_state = 'deprecated',
                            generation_metadata = jsonb_set(
                                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
                            )
                        WHERE id = CAST(:uuid AS uuid)
                          AND release_state != 'deprecated'
                    """),
                    {"uuid": str(dep_id)},
                )
                stats["controls_deprecated"] += 1

        self.db.commit()
        logger.info("Pass 5: %s", stats)
        return stats

    def migration_status(self) -> dict:
        """Return overall migration progress as counts + coverage percentages."""
        row = self.db.execute(text("""
            SELECT
                count(*) AS total,
                count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
                count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
                count(*) FILTER (
                    WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                      AND pattern_id IS NOT NULL AND pattern_id != ''
                ) AS fully_linked,
                count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
            FROM canonical_controls
        """)).fetchone()

        return {
            "total_controls": row[0],
            "has_obligation": row[1],
            "has_pattern": row[2],
            "fully_linked": row[3],
            "deprecated": row[4],
            "coverage_obligation_pct": round(row[1] / max(row[0], 1) * 100, 1),
            "coverage_pattern_pct": round(row[2] / max(row[0], 1) * 100, 1),
            "coverage_full_pct": round(row[3] / max(row[0], 1) * 100, 1),
        }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _extract_regulation_article(
    citation: Optional[str], metadata: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Extract (regulation_code, article) from a control's citation/metadata.

    The citation is preferred; generation_metadata then fills in whichever
    of the two values the citation did not provide.
    """
    from services.obligation_extractor import _normalize_regulation

    reg_code = None
    article = None

    # Try citation first (JSON string or dict).
    if citation:
        try:
            c = json.loads(citation) if isinstance(citation, str) else citation
            if isinstance(c, dict):
                article = c.get("article") or c.get("source_article")
                source = c.get("source", "")
                if source:
                    reg_code = _normalize_regulation(source)
        except (json.JSONDecodeError, TypeError):
            pass

    # BUGFIX: previously metadata was consulted only when reg_code was
    # missing, so an article could never be backfilled from metadata once
    # the citation had supplied a regulation code.
    if metadata and (not reg_code or not article):
        try:
            m = json.loads(metadata) if isinstance(metadata, str) else metadata
            if isinstance(m, dict):
                if not reg_code:
                    src_reg = m.get("source_regulation", "")
                    if src_reg:
                        reg_code = _normalize_regulation(src_reg)
                if not article:
                    article = m.get("source_article")
        except (json.JSONDecodeError, TypeError):
            pass

    return reg_code, article
"""Compliance RAG Client — Proxy to Go SDK RAG Search.

Lightweight HTTP client that queries the Go AI Compliance SDK's
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
models or direct Qdrant access in Python.

Error-tolerant: RAG failures never break the calling function.
"""

import logging
import os
from dataclasses import dataclass
from typing import List, Optional

import httpx

logger = logging.getLogger(__name__)

# Base URL of the Go SDK; overridable for local runs via the environment.
SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")
RAG_SEARCH_TIMEOUT = 15.0  # seconds


@dataclass
class RAGSearchResult:
    """A single search result from the compliance corpus."""

    text: str
    regulation_code: str
    regulation_name: str
    regulation_short: str
    category: str
    article: str
    paragraph: str
    source_url: str
    score: float
    collection: str = ""  # filled by the client with the queried collection
+ """ + payload = { + "query": query, + "collection": collection, + "top_k": top_k, + } + if regulations: + payload["regulations"] = regulations + + try: + async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client: + resp = await client.post(self._search_url, json=payload) + + if resp.status_code != 200: + logger.warning( + "RAG search returned %d: %s", resp.status_code, resp.text[:200] + ) + return [] + + data = resp.json() + results = [] + for r in data.get("results", []): + results.append(RAGSearchResult( + text=r.get("text", ""), + regulation_code=r.get("regulation_code", ""), + regulation_name=r.get("regulation_name", ""), + regulation_short=r.get("regulation_short", ""), + category=r.get("category", ""), + article=r.get("article", ""), + paragraph=r.get("paragraph", ""), + source_url=r.get("source_url", ""), + score=r.get("score", 0.0), + collection=collection, + )) + return results + + except Exception as e: + logger.warning("RAG search failed: %s", e) + return [] + + async def search_with_rerank( + self, + query: str, + collection: str = "bp_compliance_ce", + regulations: Optional[List[str]] = None, + top_k: int = 5, + ) -> List[RAGSearchResult]: + """ + Search with optional cross-encoder re-ranking. + + Fetches top_k*4 results from RAG, then re-ranks with cross-encoder + and returns top_k. Falls back to regular search if reranker is disabled. 
+ """ + from .reranker import get_reranker + + reranker = get_reranker() + if reranker is None: + return await self.search(query, collection, regulations, top_k) + + # Fetch more candidates for re-ranking + candidates = await self.search( + query, collection, regulations, top_k=max(top_k * 4, 20) + ) + if not candidates: + return [] + + texts = [c.text for c in candidates] + try: + ranked_indices = reranker.rerank(query, texts, top_k=top_k) + return [candidates[i] for i in ranked_indices] + except Exception as e: + logger.warning("Reranking failed, returning unranked: %s", e) + return candidates[:top_k] + + async def scroll( + self, + collection: str, + offset: Optional[str] = None, + limit: int = 100, + ) -> tuple[List[RAGSearchResult], Optional[str]]: + """ + Scroll through ALL chunks in a collection (paginated). + + Returns (chunks, next_offset). next_offset is None when done. + """ + scroll_url = self._search_url.replace("/search", "/scroll") + params = {"collection": collection, "limit": str(limit)} + if offset: + params["offset"] = offset + + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.get(scroll_url, params=params) + + if resp.status_code != 200: + logger.warning( + "RAG scroll returned %d: %s", resp.status_code, resp.text[:200] + ) + return [], None + + data = resp.json() + results = [] + for r in data.get("chunks", []): + results.append(RAGSearchResult( + text=r.get("text", ""), + regulation_code=r.get("regulation_code", ""), + regulation_name=r.get("regulation_name", ""), + regulation_short=r.get("regulation_short", ""), + category=r.get("category", ""), + article=r.get("article", ""), + paragraph=r.get("paragraph", ""), + source_url=r.get("source_url", ""), + score=0.0, + collection=collection, + )) + next_offset = data.get("next_offset") or None + return results, next_offset + + except Exception as e: + logger.warning("RAG scroll failed: %s", e) + return [], None + + def format_for_prompt( + self, results: 
List[RAGSearchResult], max_results: int = 5 + ) -> str: + """Format search results as Markdown for inclusion in an LLM prompt.""" + if not results: + return "" + + lines = ["## Relevanter Rechtskontext\n"] + for i, r in enumerate(results[:max_results]): + header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})" + if r.article: + header += f" — {r.article}" + lines.append(header) + text = r.text[:400] + "..." if len(r.text) > 400 else r.text + lines.append(f" > {text}\n") + + return "\n".join(lines) + + +# Singleton +_rag_client: Optional[ComplianceRAGClient] = None + + +def get_rag_client() -> ComplianceRAGClient: + """Get the shared RAG client instance.""" + global _rag_client + if _rag_client is None: + _rag_client = ComplianceRAGClient() + return _rag_client diff --git a/control-pipeline/services/reranker.py b/control-pipeline/services/reranker.py new file mode 100644 index 0000000..49e9a65 --- /dev/null +++ b/control-pipeline/services/reranker.py @@ -0,0 +1,85 @@ +""" +Cross-Encoder Re-Ranking for RAG Search Results. + +Uses BGE Reranker v2 (BAAI/bge-reranker-v2-m3, MIT license) to re-rank +search results from Qdrant for improved retrieval quality. + +Lazy-loads the model on first use. Disabled by default (RERANK_ENABLED=false). 
+""" + +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + +RERANK_ENABLED = os.getenv("RERANK_ENABLED", "false").lower() == "true" +RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3") + + +class Reranker: + """Cross-encoder reranker using sentence-transformers.""" + + def __init__(self, model_name: str = RERANK_MODEL): + self._model = None # Lazy init + self._model_name = model_name + + def _ensure_model(self) -> None: + """Load model on first use.""" + if self._model is not None: + return + try: + from sentence_transformers import CrossEncoder + + logger.info("Loading reranker model: %s", self._model_name) + self._model = CrossEncoder(self._model_name) + logger.info("Reranker model loaded successfully") + except ImportError: + logger.error( + "sentence-transformers not installed. " + "Install with: pip install sentence-transformers" + ) + raise + except Exception as e: + logger.error("Failed to load reranker model: %s", e) + raise + + def rerank( + self, query: str, texts: list[str], top_k: int = 5 + ) -> list[int]: + """ + Return indices of top_k texts sorted by relevance (highest first). + + Args: + query: The search query. + texts: List of candidate texts to re-rank. + top_k: Number of top results to return. + + Returns: + List of indices into the original texts list, sorted by relevance. + """ + if not texts: + return [] + + self._ensure_model() + + pairs = [[query, text] for text in texts] + scores = self._model.predict(pairs) + + # Sort by score descending, return indices + ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True) + return ranked[:top_k] + + +# Module-level singleton +_reranker: Optional[Reranker] = None + + +def get_reranker() -> Optional[Reranker]: + """Get the shared reranker instance. 
"""
Too-Close Similarity Detector — checks whether a candidate text is too similar
to a protected source text (copyright / license compliance).

Five metrics:
  1. Exact-phrase — longest identical token sequence
  2. Token overlap — Jaccard similarity of token sets
  3. 3-gram Jaccard — Jaccard similarity of character 3-grams
  4. Embedding cosine — via bge-m3 (Ollama or embedding-service)
  5. LCS ratio — Longest Common Subsequence / max(len_a, len_b)

Decision:
  PASS — no fail + max 1 warn
  WARN — max 2 warn, no fail → human review
  FAIL — any fail threshold → block, rewrite required
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Thresholds
# ---------------------------------------------------------------------------

# Per-metric warn/fail cut-offs. max_exact_run is a token count (int);
# all other metrics are ratios in [0, 1]. A metric at or above "fail"
# hard-blocks; at or above "warn" it counts toward the WARN verdict.
THRESHOLDS = {
    "max_exact_run": {"warn": 8, "fail": 12},
    "token_overlap": {"warn": 0.20, "fail": 0.30},
    "ngram_jaccard": {"warn": 0.10, "fail": 0.18},
    "embedding_cosine": {"warn": 0.86, "fail": 0.92},
    "lcs_ratio": {"warn": 0.35, "fail": 0.50},
}

# ---------------------------------------------------------------------------
# Tokenisation helpers
# ---------------------------------------------------------------------------

# \w+ with re.UNICODE keeps umlauts/diacritics intact (German legal text).
_WORD_RE = re.compile(r"\w+", re.UNICODE)


def _tokenize(text: str) -> list[str]:
    # Lower-cased word tokens; punctuation is dropped entirely.
    return [t.lower() for t in _WORD_RE.findall(text)]
-> set[str]: + text = text.lower() + return {text[i : i + n] for i in range(len(text) - n + 1)} if len(text) >= n else set() + + +# --------------------------------------------------------------------------- +# Metric implementations +# --------------------------------------------------------------------------- + +def max_exact_run(tokens_a: list[str], tokens_b: list[str]) -> int: + """Longest contiguous identical token sequence between a and b.""" + if not tokens_a or not tokens_b: + return 0 + + best = 0 + set_b = set(tokens_b) + + for i in range(len(tokens_a)): + if tokens_a[i] not in set_b: + continue + for j in range(len(tokens_b)): + if tokens_a[i] != tokens_b[j]: + continue + run = 0 + ii, jj = i, j + while ii < len(tokens_a) and jj < len(tokens_b) and tokens_a[ii] == tokens_b[jj]: + run += 1 + ii += 1 + jj += 1 + if run > best: + best = run + return best + + +def token_overlap_jaccard(tokens_a: list[str], tokens_b: list[str]) -> float: + """Jaccard similarity of token sets.""" + set_a, set_b = set(tokens_a), set(tokens_b) + if not set_a and not set_b: + return 0.0 + return len(set_a & set_b) / len(set_a | set_b) + + +def ngram_jaccard(text_a: str, text_b: str, n: int = 3) -> float: + """Jaccard similarity of character n-grams.""" + grams_a = _char_ngrams(text_a, n) + grams_b = _char_ngrams(text_b, n) + if not grams_a and not grams_b: + return 0.0 + return len(grams_a & grams_b) / len(grams_a | grams_b) + + +def lcs_ratio(tokens_a: list[str], tokens_b: list[str]) -> float: + """LCS length / max(len_a, len_b).""" + m, n = len(tokens_a), len(tokens_b) + if m == 0 or n == 0: + return 0.0 + + # Space-optimised LCS (two rows) + prev = [0] * (n + 1) + curr = [0] * (n + 1) + for i in range(1, m + 1): + for j in range(1, n + 1): + if tokens_a[i - 1] == tokens_b[j - 1]: + curr[j] = prev[j - 1] + 1 + else: + curr[j] = max(prev[j], curr[j - 1]) + prev, curr = curr, [0] * (n + 1) + + return prev[n] / max(m, n) + + +async def embedding_cosine(text_a: str, text_b: str, 
embedding_url: str | None = None) -> float: + """Cosine similarity via embedding service (bge-m3). + + Falls back to 0.0 if the service is unreachable. + """ + url = embedding_url or "http://embedding-service:8087" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{url}/embed", + json={"texts": [text_a, text_b]}, + ) + resp.raise_for_status() + embeddings = resp.json().get("embeddings", []) + if len(embeddings) < 2: + return 0.0 + return _cosine(embeddings[0], embeddings[1]) + except Exception: + logger.warning("Embedding service unreachable, skipping cosine check") + return 0.0 + + +def _cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + + +# --------------------------------------------------------------------------- +# Decision engine +# --------------------------------------------------------------------------- + +@dataclass +class SimilarityReport: + max_exact_run: int + token_overlap: float + ngram_jaccard: float + embedding_cosine: float + lcs_ratio: float + status: str # PASS, WARN, FAIL + details: dict # per-metric status + + +def _classify(value: float | int, metric: str) -> str: + t = THRESHOLDS[metric] + if value >= t["fail"]: + return "FAIL" + if value >= t["warn"]: + return "WARN" + return "PASS" + + +async def check_similarity( + source_text: str, + candidate_text: str, + embedding_url: str | None = None, +) -> SimilarityReport: + """Run all 5 metrics and return an aggregate report.""" + tok_src = _tokenize(source_text) + tok_cand = _tokenize(candidate_text) + + m_exact = max_exact_run(tok_src, tok_cand) + m_token = token_overlap_jaccard(tok_src, tok_cand) + m_ngram = ngram_jaccard(source_text, candidate_text) + m_embed = await embedding_cosine(source_text, candidate_text, embedding_url) + m_lcs = 
lcs_ratio(tok_src, tok_cand) + + details = { + "max_exact_run": _classify(m_exact, "max_exact_run"), + "token_overlap": _classify(m_token, "token_overlap"), + "ngram_jaccard": _classify(m_ngram, "ngram_jaccard"), + "embedding_cosine": _classify(m_embed, "embedding_cosine"), + "lcs_ratio": _classify(m_lcs, "lcs_ratio"), + } + + fail_count = sum(1 for v in details.values() if v == "FAIL") + warn_count = sum(1 for v in details.values() if v == "WARN") + + if fail_count > 0: + status = "FAIL" + elif warn_count > 2: + status = "FAIL" + elif warn_count > 1: + status = "WARN" + elif warn_count == 1: + status = "PASS" + else: + status = "PASS" + + return SimilarityReport( + max_exact_run=m_exact, + token_overlap=round(m_token, 4), + ngram_jaccard=round(m_ngram, 4), + embedding_cosine=round(m_embed, 4), + lcs_ratio=round(m_lcs, 4), + status=status, + details=details, + ) diff --git a/control-pipeline/services/v1_enrichment.py b/control-pipeline/services/v1_enrichment.py new file mode 100644 index 0000000..b5ab409 --- /dev/null +++ b/control-pipeline/services/v1_enrichment.py @@ -0,0 +1,331 @@ +"""V1 Control Enrichment Service — Match Eigenentwicklung controls to regulations. + +Finds regulatory coverage for v1 controls (generation_strategy='ungrouped', +pipeline_version=1, no source_citation) by embedding similarity search. + +Reuses embedding + Qdrant helpers from control_dedup.py. 
+""" + +import logging +from typing import Optional + +from sqlalchemy import text + +from db.session import SessionLocal +from services.control_dedup import ( + get_embedding, + qdrant_search_cross_regulation, +) + +logger = logging.getLogger(__name__) + +# Similarity threshold — lower than dedup (0.85) since we want informational matches +# Typical top scores for v1 controls are 0.70-0.77 +V1_MATCH_THRESHOLD = 0.70 +V1_MAX_MATCHES = 5 + + +def _is_eigenentwicklung_query() -> str: + """SQL WHERE clause identifying v1 Eigenentwicklung controls.""" + return """ + generation_strategy = 'ungrouped' + AND (pipeline_version = '1' OR pipeline_version IS NULL) + AND source_citation IS NULL + AND parent_control_uuid IS NULL + AND release_state NOT IN ('rejected', 'merged', 'deprecated') + """ + + +async def count_v1_controls() -> int: + """Count how many v1 Eigenentwicklung controls exist.""" + with SessionLocal() as db: + row = db.execute(text(f""" + SELECT COUNT(*) AS cnt + FROM canonical_controls + WHERE {_is_eigenentwicklung_query()} + """)).fetchone() + return row.cnt if row else 0 + + +async def enrich_v1_matches( + dry_run: bool = True, + batch_size: int = 100, + offset: int = 0, +) -> dict: + """Find regulatory matches for v1 Eigenentwicklung controls. + + Args: + dry_run: If True, only count — don't write matches. + batch_size: Number of v1 controls to process per call. + offset: Pagination offset (v1 control index). + + Returns: + Stats dict with counts, sample matches, and pagination info. + """ + with SessionLocal() as db: + # 1. 
Load v1 controls (paginated) + v1_controls = db.execute(text(f""" + SELECT id, control_id, title, objective, category + FROM canonical_controls + WHERE {_is_eigenentwicklung_query()} + ORDER BY control_id + LIMIT :limit OFFSET :offset + """), {"limit": batch_size, "offset": offset}).fetchall() + + # Count total for pagination + total_row = db.execute(text(f""" + SELECT COUNT(*) AS cnt + FROM canonical_controls + WHERE {_is_eigenentwicklung_query()} + """)).fetchone() + total_v1 = total_row.cnt if total_row else 0 + + if not v1_controls: + return { + "dry_run": dry_run, + "processed": 0, + "total_v1": total_v1, + "message": "Kein weiterer Batch — alle v1 Controls verarbeitet.", + } + + if dry_run: + return { + "dry_run": True, + "total_v1": total_v1, + "offset": offset, + "batch_size": batch_size, + "sample_controls": [ + { + "control_id": r.control_id, + "title": r.title, + "category": r.category, + } + for r in v1_controls[:20] + ], + } + + # 2. Process each v1 control + processed = 0 + matches_inserted = 0 + errors = [] + sample_matches = [] + + for v1 in v1_controls: + try: + # Build search text + search_text = f"{v1.title} — {v1.objective}" + + # Get embedding + embedding = await get_embedding(search_text) + if not embedding: + errors.append({ + "control_id": v1.control_id, + "error": "Embedding fehlgeschlagen", + }) + continue + + # Search Qdrant (cross-regulation, no pattern filter) + # Collection is atomic_controls_dedup (contains ~51k atomare Controls) + results = await qdrant_search_cross_regulation( + embedding, top_k=20, + collection="atomic_controls_dedup", + ) + + # For each hit: resolve to a regulatory parent with source_citation. + # Atomic controls in Qdrant usually have parent_control_uuid → parent + # has the source_citation. We deduplicate by parent to avoid + # listing the same regulation multiple times. 
+ rank = 0 + seen_parents: set[str] = set() + + for hit in results: + score = hit.get("score", 0) + if score < V1_MATCH_THRESHOLD: + continue + + payload = hit.get("payload", {}) + matched_uuid = payload.get("control_uuid") + if not matched_uuid or matched_uuid == str(v1.id): + continue + + # Try the matched control itself first, then its parent + matched_row = db.execute(text(""" + SELECT c.id, c.control_id, c.title, c.source_citation, + c.severity, c.category, c.parent_control_uuid + FROM canonical_controls c + WHERE c.id = CAST(:uuid AS uuid) + """), {"uuid": matched_uuid}).fetchone() + + if not matched_row: + continue + + # Resolve to regulatory control (one with source_citation) + reg_row = matched_row + if not reg_row.source_citation and reg_row.parent_control_uuid: + # Look up parent — the parent has the source_citation + parent_row = db.execute(text(""" + SELECT id, control_id, title, source_citation, + severity, category, parent_control_uuid + FROM canonical_controls + WHERE id = CAST(:uuid AS uuid) + AND source_citation IS NOT NULL + """), {"uuid": str(reg_row.parent_control_uuid)}).fetchone() + if parent_row: + reg_row = parent_row + + if not reg_row.source_citation: + continue + + # Deduplicate by parent UUID + parent_key = str(reg_row.id) + if parent_key in seen_parents: + continue + seen_parents.add(parent_key) + + rank += 1 + if rank > V1_MAX_MATCHES: + break + + # Extract source info + source_citation = reg_row.source_citation or {} + matched_source = source_citation.get("source") if isinstance(source_citation, dict) else None + matched_article = source_citation.get("article") if isinstance(source_citation, dict) else None + + # Insert match — link to the regulatory parent (not the atomic child) + db.execute(text(""" + INSERT INTO v1_control_matches + (v1_control_uuid, matched_control_uuid, similarity_score, + match_rank, matched_source, matched_article, match_method) + VALUES + (CAST(:v1_uuid AS uuid), CAST(:matched_uuid AS uuid), :score, + :rank, 
:source, :article, 'embedding') + ON CONFLICT (v1_control_uuid, matched_control_uuid) DO UPDATE + SET similarity_score = EXCLUDED.similarity_score, + match_rank = EXCLUDED.match_rank + """), { + "v1_uuid": str(v1.id), + "matched_uuid": str(reg_row.id), + "score": round(score, 3), + "rank": rank, + "source": matched_source, + "article": matched_article, + }) + matches_inserted += 1 + + # Collect sample + if len(sample_matches) < 20: + sample_matches.append({ + "v1_control_id": v1.control_id, + "v1_title": v1.title, + "matched_control_id": reg_row.control_id, + "matched_title": reg_row.title, + "matched_source": matched_source, + "matched_article": matched_article, + "similarity_score": round(score, 3), + "match_rank": rank, + }) + + processed += 1 + + except Exception as e: + logger.warning("V1 enrichment error for %s: %s", v1.control_id, e) + errors.append({ + "control_id": v1.control_id, + "error": str(e), + }) + + db.commit() + + # Pagination + next_offset = offset + batch_size if len(v1_controls) == batch_size else None + + return { + "dry_run": False, + "offset": offset, + "batch_size": batch_size, + "next_offset": next_offset, + "total_v1": total_v1, + "processed": processed, + "matches_inserted": matches_inserted, + "errors": errors[:10], + "sample_matches": sample_matches, + } + + +async def get_v1_matches(control_uuid: str) -> list[dict]: + """Get all regulatory matches for a specific v1 control. + + Args: + control_uuid: The UUID of the v1 control. + + Returns: + List of match dicts with control details. 
+ """ + with SessionLocal() as db: + rows = db.execute(text(""" + SELECT + m.similarity_score, + m.match_rank, + m.matched_source, + m.matched_article, + m.match_method, + c.control_id AS matched_control_id, + c.title AS matched_title, + c.objective AS matched_objective, + c.severity AS matched_severity, + c.category AS matched_category, + c.source_citation AS matched_source_citation + FROM v1_control_matches m + JOIN canonical_controls c ON c.id = m.matched_control_uuid + WHERE m.v1_control_uuid = CAST(:uuid AS uuid) + ORDER BY m.match_rank + """), {"uuid": control_uuid}).fetchall() + + return [ + { + "matched_control_id": r.matched_control_id, + "matched_title": r.matched_title, + "matched_objective": r.matched_objective, + "matched_severity": r.matched_severity, + "matched_category": r.matched_category, + "matched_source": r.matched_source, + "matched_article": r.matched_article, + "matched_source_citation": r.matched_source_citation, + "similarity_score": float(r.similarity_score), + "match_rank": r.match_rank, + "match_method": r.match_method, + } + for r in rows + ] + + +async def get_v1_enrichment_stats() -> dict: + """Get overview stats for v1 enrichment.""" + with SessionLocal() as db: + total_v1 = db.execute(text(f""" + SELECT COUNT(*) AS cnt FROM canonical_controls + WHERE {_is_eigenentwicklung_query()} + """)).fetchone() + + matched_v1 = db.execute(text(f""" + SELECT COUNT(DISTINCT m.v1_control_uuid) AS cnt + FROM v1_control_matches m + JOIN canonical_controls c ON c.id = m.v1_control_uuid + WHERE {_is_eigenentwicklung_query().replace('release_state', 'c.release_state').replace('generation_strategy', 'c.generation_strategy').replace('pipeline_version', 'c.pipeline_version').replace('source_citation', 'c.source_citation').replace('parent_control_uuid', 'c.parent_control_uuid')} + """)).fetchone() + + total_matches = db.execute(text(""" + SELECT COUNT(*) AS cnt FROM v1_control_matches + """)).fetchone() + + avg_score = db.execute(text(""" + SELECT 
AVG(similarity_score) AS avg_score FROM v1_control_matches + """)).fetchone() + + return { + "total_v1_controls": total_v1.cnt if total_v1 else 0, + "v1_with_matches": matched_v1.cnt if matched_v1 else 0, + "v1_without_matches": (total_v1.cnt if total_v1 else 0) - (matched_v1.cnt if matched_v1 else 0), + "total_matches": total_matches.cnt if total_matches else 0, + "avg_similarity_score": round(float(avg_score.avg_score), 3) if avg_score and avg_score.avg_score else None, + } diff --git a/control-pipeline/tests/__init__.py b/control-pipeline/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml index fed4330..e934de1 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -56,6 +56,7 @@ services: - "8091:8091" # Voice Service (WSS) - "8093:8093" # AI Compliance SDK - "8097:8097" # RAG Service (NEU) + - "8098:8098" # Control Pipeline - "8443:8443" # Jitsi Meet - "3008:3008" # Admin Core - "3010:3010" # Portal Dashboard @@ -376,6 +377,50 @@ services: networks: - breakpilot-network + # ========================================================= + # CONTROL PIPELINE (Entwickler-only, nicht kundenrelevant) + # ========================================================= + control-pipeline: + build: + context: ./control-pipeline + dockerfile: Dockerfile + container_name: bp-core-control-pipeline + platform: linux/arm64 + expose: + - "8098" + environment: + PORT: 8098 + DATABASE_URL: postgresql://${POSTGRES_USER:-breakpilot}:${POSTGRES_PASSWORD:-breakpilot123}@postgres:5432/${POSTGRES_DB:-breakpilot_db} + SCHEMA_SEARCH_PATH: compliance,core,public + QDRANT_URL: http://qdrant:6333 + EMBEDDING_SERVICE_URL: http://embedding-service:8087 + ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-} + CONTROL_GEN_ANTHROPIC_MODEL: ${CONTROL_GEN_ANTHROPIC_MODEL:-claude-sonnet-4-6} + DECOMPOSITION_LLM_MODEL: ${DECOMPOSITION_LLM_MODEL:-claude-haiku-4-5-20251001} + OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434} + 
# =========================================================
# CORE: Control Pipeline on port 8098 (developer-only)
# =========================================================
server {
    listen 8098 ssl;
    http2 on;
    server_name macmini localhost;

    ssl_certificate /etc/nginx/certs/macmini.crt;
    ssl_certificate_key /etc/nginx/certs/macmini.key;
    # TLS 1.2+ only, AEAD cipher suites.
    ssl_protocols TLSv1.2 TLSv1.3;
    ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
    ssl_prefer_server_ciphers off;

    location / {
        # NOTE(review): upstream is set via a variable, presumably so the
        # hostname is resolved at request time and nginx can start before
        # the container exists — confirm a resolver is configured.
        set $upstream_pipeline bp-core-control-pipeline:8098;
        proxy_pass http://$upstream_pipeline;
        proxy_http_version 1.1;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto https;
        # Generous 30-minute timeouts for long-running pipeline requests.
        proxy_read_timeout 1800s;
        proxy_send_timeout 1800s;
    }
}

# =========================================================
# CORE: Edu-Search on port 8089
# =========================================================