feat: control-pipeline Service aus Compliance-Repo migriert

Control-Pipeline (Pass 0a/0b, BatchDedup, Generator) als eigenstaendiger
Service in Core, damit das Compliance-Repo unabhaengig refaktoriert werden kann.
Schreibt weiterhin ins compliance-Schema der shared PostgreSQL.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-09 14:40:47 +02:00
parent 68692ade4e
commit e3ab428b91
34 changed files with 16574 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
# Container image for the control-pipeline service (slim Python 3.11 base).
FROM python:3.11-slim
WORKDIR /app
# curl is installed for the HEALTHCHECK below; the apt package lists are
# removed in the same layer so they do not end up in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies before copying the sources so this layer stays cached
# when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# 8098 is also the default PORT of the service configuration.
EXPOSE 8098
# Probes the /health route exposed by main.py.
# NOTE(review): that route answers with a "degraded" body rather than an error
# status when the database is down, so `curl -f` presumably only detects a
# dead process — confirm this is intended.
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
CMD curl -f http://127.0.0.1:8098/health || exit 1
CMD ["python", "main.py"]

View File

@@ -0,0 +1,8 @@
from fastapi import APIRouter
from api.control_generator_routes import router as generator_router
from api.canonical_control_routes import router as canonical_router
# Aggregate router of the api package: bundles the control-generator and
# canonical-control routers so the application can mount them with a single
# include_router call.
router = APIRouter()
router.include_router(generator_router)
router.include_router(canonical_router)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,67 @@
import os
class Settings:
    """Control-pipeline configuration sourced from environment variables.

    Every attribute is resolved once, when the class body runs at import
    time; each one falls back to the default given here when the variable
    is not set.
    """

    # --- Database (compliance schema) ---
    # NOTE(review): the fallback URL embeds development credentials; make
    # sure DATABASE_URL is always set explicitly outside local development.
    DATABASE_URL: str = os.environ.get(
        "DATABASE_URL",
        "postgresql://breakpilot:breakpilot123@localhost:5432/breakpilot_db",
    )
    SCHEMA_SEARCH_PATH: str = os.environ.get(
        "SCHEMA_SEARCH_PATH",
        "compliance,core,public",
    )

    # --- Qdrant (vector search for dedup) ---
    QDRANT_URL: str = os.environ.get("QDRANT_URL", "http://localhost:6333")
    QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY", "")

    # --- Embedding service ---
    EMBEDDING_SERVICE_URL: str = os.environ.get(
        "EMBEDDING_SERVICE_URL",
        "http://embedding-service:8087",
    )

    # --- LLM: Anthropic ---
    ANTHROPIC_API_KEY: str = os.environ.get("ANTHROPIC_API_KEY", "")
    CONTROL_GEN_ANTHROPIC_MODEL: str = os.environ.get(
        "CONTROL_GEN_ANTHROPIC_MODEL",
        "claude-sonnet-4-6",
    )
    DECOMPOSITION_LLM_MODEL: str = os.environ.get(
        "DECOMPOSITION_LLM_MODEL",
        "claude-haiku-4-5-20251001",
    )
    CONTROL_GEN_LLM_TIMEOUT: int = int(
        os.environ.get("CONTROL_GEN_LLM_TIMEOUT", "180")
    )

    # --- LLM: Ollama (fallback) ---
    OLLAMA_URL: str = os.environ.get(
        "OLLAMA_URL",
        "http://host.docker.internal:11434",
    )
    CONTROL_GEN_OLLAMA_MODEL: str = os.environ.get(
        "CONTROL_GEN_OLLAMA_MODEL",
        "qwen3.5:35b-a3b",
    )

    # --- SDK service (for the RAG search proxy) ---
    SDK_URL: str = os.environ.get("SDK_URL", "http://ai-compliance-sdk:8090")

    # --- Auth ---
    JWT_SECRET: str = os.environ.get("JWT_SECRET", "")

    # --- Server ---
    PORT: int = int(os.environ.get("PORT", "8098"))
    LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
    ENVIRONMENT: str = os.environ.get("ENVIRONMENT", "development")

    # --- Pipeline ---
    DECOMPOSITION_BATCH_SIZE: int = int(
        os.environ.get("DECOMPOSITION_BATCH_SIZE", "5")
    )
    DECOMPOSITION_LLM_TIMEOUT: int = int(
        os.environ.get("DECOMPOSITION_LLM_TIMEOUT", "120")
    )


# Shared instance imported by the rest of the service.
settings = Settings()

View File

View File

@@ -0,0 +1,205 @@
"""
Source-Type-Klassifikation fuer Regulierungen und Frameworks.
Dreistufiges Modell der normativen Verbindlichkeit:
Stufe 1 — GESETZ (law):
Rechtlich bindend. Bussgeld bei Verstoss.
Beispiele: DSGVO, NIS2, AI Act, CRA
Stufe 2 — LEITLINIE (guideline):
Offizielle Auslegungshilfe von Aufsichtsbehoerden.
Beweislastumkehr: Wer abweicht, muss begruenden warum.
Beispiele: EDPB-Leitlinien, BSI-Standards, WP29-Dokumente
Stufe 3 — FRAMEWORK (framework):
Freiwillige Best Practices, nicht rechtsverbindlich.
Aber: Koennen als "Stand der Technik" herangezogen werden.
Beispiele: ENISA, NIST, OWASP, OECD, CISA
Mapping: source_regulation (aus control_parent_links) -> source_type
"""
# --- Type definitions ---
SOURCE_TYPE_LAW = "law"  # law / regulation / directive — normative_strength is kept
SOURCE_TYPE_GUIDELINE = "guideline"  # guideline / standard — capped at "should"
SOURCE_TYPE_FRAMEWORK = "framework"  # framework / best practice — capped at "may"

# Maximum normative_strength allowed per source_type.
# The DB constraint allows: must, should, may (NOT "can").
NORMATIVE_STRENGTH_CAP: dict[str, str] = {
    SOURCE_TYPE_LAW: "must",  # no cap
    SOURCE_TYPE_GUIDELINE: "should",  # at most "should"
    SOURCE_TYPE_FRAMEWORK: "may",  # at most "may"
}

# Ordering used for comparisons (higher = stronger).
STRENGTH_ORDER: dict[str, int] = {
    "may": 1,  # the value stored in the DB
    "can": 1,  # alias — cap_normative_strength normalizes it to "may"
    "should": 2,
    "must": 3,
}


def cap_normative_strength(original: str, source_type: str) -> str:
    """Limit a normative strength to what the given source type allows.

    The alias ``"can"`` is normalized to ``"may"`` first, because the
    database constraint only accepts ``must``, ``should`` and ``may``.

    Args:
        original: Strength as produced upstream. Values missing from
            STRENGTH_ORDER are ranked like ``"must"``; they are returned
            unchanged unless the cap is lower.
        source_type: One of the SOURCE_TYPE_* values. Unknown source types
            are not capped.

    Returns:
        The cap when ``original`` is stronger than the cap, otherwise
        ``original`` (with ``"can"`` rewritten to ``"may"``).

    Examples:
        cap_normative_strength("must", "framework") -> "may"
        cap_normative_strength("should", "law") -> "should"
        cap_normative_strength("must", "guideline") -> "should"
        cap_normative_strength("can", "law") -> "may"
    """
    # "can" ranks the same as "may" but must never reach the database.
    if original == "can":
        original = "may"

    cap = NORMATIVE_STRENGTH_CAP.get(source_type, "must")
    cap_level = STRENGTH_ORDER.get(cap, 3)
    original_level = STRENGTH_ORDER.get(original, 3)
    if original_level > cap_level:
        return cap
    return original
def get_highest_source_type(source_types: list[str]) -> str:
    """Return the most binding source type found in ``source_types``.

    Ranking is framework < guideline < law, so a law beats everything
    else. Values that are not one of the three known types rank below
    framework. An empty input yields "framework".

    Examples:
        get_highest_source_type(["framework", "law"]) -> "law"
        get_highest_source_type(["framework", "guideline"]) -> "guideline"
    """
    if not source_types:
        return SOURCE_TYPE_FRAMEWORK

    ranking = {
        SOURCE_TYPE_FRAMEWORK: 1,
        SOURCE_TYPE_GUIDELINE: 2,
        SOURCE_TYPE_LAW: 3,
    }

    def rank(kind: str) -> int:
        return ranking.get(kind, 0)

    return max(source_types, key=rank)
# ============================================================================
# Classification: source_regulation -> source_type
#
# This map is used for the backfill and for future pipeline runs.
# Register new regulations here!
# ============================================================================
SOURCE_REGULATION_CLASSIFICATION: dict[str, str] = {
    # --- EU regulations (directly binding) ---
    "DSGVO (EU) 2016/679": SOURCE_TYPE_LAW,
    "KI-Verordnung (EU) 2024/1689": SOURCE_TYPE_LAW,
    "Cyber Resilience Act (CRA)": SOURCE_TYPE_LAW,
    # NOTE(review): by its name this entry is a directive, not a regulation;
    # it is nevertheless classified as law like everything in this section.
    "NIS2-Richtlinie (EU) 2022/2555": SOURCE_TYPE_LAW,
    "Data Act": SOURCE_TYPE_LAW,
    "Data Governance Act (DGA)": SOURCE_TYPE_LAW,
    "Markets in Crypto-Assets (MiCA)": SOURCE_TYPE_LAW,
    "Maschinenverordnung (EU) 2023/1230": SOURCE_TYPE_LAW,
    "Batterieverordnung (EU) 2023/1542": SOURCE_TYPE_LAW,
    "AML-Verordnung": SOURCE_TYPE_LAW,
    # --- EU directives (binding after national transposition) ---
    # Treated like laws for compliance purposes
    # --- National laws ---
    "Bundesdatenschutzgesetz (BDSG)": SOURCE_TYPE_LAW,
    "Telekommunikationsgesetz": SOURCE_TYPE_LAW,
    "Telekommunikationsgesetz Oesterreich": SOURCE_TYPE_LAW,
    "Gewerbeordnung (GewO)": SOURCE_TYPE_LAW,
    "Handelsgesetzbuch (HGB)": SOURCE_TYPE_LAW,
    "Abgabenordnung (AO)": SOURCE_TYPE_LAW,
    "IFRS-Übernahmeverordnung": SOURCE_TYPE_LAW,
    "Österreichisches Datenschutzgesetz (DSG)": SOURCE_TYPE_LAW,
    "LOPDGDD - Ley Orgánica de Protección de Datos (Spanien)": SOURCE_TYPE_LAW,
    "Loi Informatique et Libertés (Frankreich)": SOURCE_TYPE_LAW,
    "Információs önrendelkezési jog törvény (Ungarn)": SOURCE_TYPE_LAW,
    # NOTE(review): listed among the national laws but looks like an EU
    # guidance document — confirm that "law" is the intended classification.
    "EU Blue Guide 2022": SOURCE_TYPE_LAW,
    # --- EDPB/WP29 guidelines (official interpretation aids) ---
    "EDPB Leitlinien 01/2019 (Zertifizierung)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 01/2020 (Datentransfers)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 01/2020 (Vernetzte Fahrzeuge)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 01/2022 (BCR)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 01/2024 (Berechtigtes Interesse)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 04/2019 (Data Protection by Design)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 05/2020 - Einwilligung": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 07/2020 (Datentransfers)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 08/2020 (Social Media)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 09/2022 (Data Breach)": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien 09/2022 - Meldung von Datenschutzverletzungen": SOURCE_TYPE_GUIDELINE,
    "EDPB Empfehlungen 01/2020 - Ergaenzende Massnahmen fuer Datentransfers": SOURCE_TYPE_GUIDELINE,
    "EDPB Leitlinien - Berechtigtes Interesse (Art. 6(1)(f))": SOURCE_TYPE_GUIDELINE,
    "WP244 Leitlinien (Profiling)": SOURCE_TYPE_GUIDELINE,
    "WP251 Leitlinien (Profiling)": SOURCE_TYPE_GUIDELINE,
    "WP260 Leitlinien (Transparenz)": SOURCE_TYPE_GUIDELINE,
    # --- BSI standards (technical guidelines issued by an authority) ---
    "BSI-TR-03161-1": SOURCE_TYPE_GUIDELINE,
    "BSI-TR-03161-2": SOURCE_TYPE_GUIDELINE,
    "BSI-TR-03161-3": SOURCE_TYPE_GUIDELINE,
    # --- ENISA (EU agency, but its recommendations are not legally binding) ---
    "ENISA Cybersecurity State 2024": SOURCE_TYPE_FRAMEWORK,
    "ENISA ICS/SCADA Dependencies": SOURCE_TYPE_FRAMEWORK,
    "ENISA Supply Chain Good Practices": SOURCE_TYPE_FRAMEWORK,
    "ENISA Threat Landscape Supply Chain": SOURCE_TYPE_FRAMEWORK,
    # --- NIST (US standards, used internationally as best practice) ---
    "NIST AI Risk Management Framework": SOURCE_TYPE_FRAMEWORK,
    "NIST Cybersecurity Framework 2.0": SOURCE_TYPE_FRAMEWORK,
    "NIST SP 800-207 (Zero Trust)": SOURCE_TYPE_FRAMEWORK,
    "NIST SP 800-218 (SSDF)": SOURCE_TYPE_FRAMEWORK,
    "NIST SP 800-53 Rev. 5": SOURCE_TYPE_FRAMEWORK,
    "NIST SP 800-63-3": SOURCE_TYPE_FRAMEWORK,
    # --- OWASP (community standards) ---
    "OWASP API Security Top 10 (2023)": SOURCE_TYPE_FRAMEWORK,
    "OWASP ASVS 4.0": SOURCE_TYPE_FRAMEWORK,
    "OWASP MASVS 2.0": SOURCE_TYPE_FRAMEWORK,
    "OWASP SAMM 2.0": SOURCE_TYPE_FRAMEWORK,
    "OWASP Top 10 (2021)": SOURCE_TYPE_FRAMEWORK,
    # --- Other frameworks ---
    "OECD KI-Empfehlung": SOURCE_TYPE_FRAMEWORK,
    "CISA Secure by Design": SOURCE_TYPE_FRAMEWORK,
}
def classify_source_regulation(source_regulation: str) -> str:
    """Classify a source_regulation as law, guideline or framework.

    An exact match against SOURCE_REGULATION_CLASSIFICATION wins. Unknown
    sources are guessed from keywords, checked in the order law ->
    guideline. Anything else, including an empty name, falls back to
    "framework" as the most conservative result.

    The short law indicators "act", "ley" and "loi" only count as whole
    words. Matched as substrings they would turn names such as
    "OWASP Proactive Controls" or "... Good Practices" into laws.

    Args:
        source_regulation: Regulation name as stored on the parent link.

    Returns:
        One of SOURCE_TYPE_LAW, SOURCE_TYPE_GUIDELINE, SOURCE_TYPE_FRAMEWORK.
    """
    if not source_regulation:
        return SOURCE_TYPE_FRAMEWORK

    # Exact match
    exact = SOURCE_REGULATION_CLASSIFICATION.get(source_regulation)
    if exact is not None:
        return exact

    # Heuristic for unknown sources
    lower = source_regulation.lower()
    # Whole words of the name: every non-alphanumeric character separates.
    words = set(
        "".join(ch if ch.isalnum() else " " for ch in lower).split()
    )

    # Long indicators stay substring matches so that German compounds such
    # as "Maschinenverordnung" or "Datenschutzgesetz" are still recognised.
    law_substrings = (
        "verordnung", "richtlinie", "gesetz", "directive", "regulation",
        "(eu)", "(eg)", "törvény", "código",
    )
    law_words = ("act", "ley", "loi")
    if any(ind in lower for ind in law_substrings) or any(
        word in words for word in law_words
    ):
        return SOURCE_TYPE_LAW

    guideline_indicators = (
        "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung",
    )
    if any(ind in lower for ind in guideline_indicators):
        return SOURCE_TYPE_GUIDELINE

    # Known framework issuers (enisa, nist, owasp, oecd, cisa, iso, ...) and
    # everything unknown end up here: framework has the weakest binding
    # force, so it is the conservative default.
    return SOURCE_TYPE_FRAMEWORK

View File

View File

@@ -0,0 +1,37 @@
"""Database session factory for control-pipeline.
Connects to the shared PostgreSQL with search_path set to compliance schema.
"""
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from config import settings
# Engine for the shared PostgreSQL database. The connection URL comes from
# the service settings; the schema search path is applied per connection by
# the "connect" listener registered on this engine.
engine = create_engine(
    settings.DATABASE_URL,
    pool_pre_ping=True,  # check a pooled connection before handing it out
    pool_size=5,
    max_overflow=10,  # extra connections allowed beyond pool_size
    echo=False,  # no SQL statement logging
)
@event.listens_for(engine, "connect")
def set_search_path(dbapi_connection, connection_record):
    """Apply the configured schema search path to every new DB connection.

    Registered on the engine's "connect" event, so it runs once per raw
    DBAPI connection.

    The value is passed as a bound parameter to ``set_config`` instead of
    being interpolated into a ``SET`` statement, and the cursor is closed
    even when the statement fails.

    Args:
        dbapi_connection: Raw DBAPI connection that was just opened.
        connection_record: Pool record of the connection (unused).
    """
    cursor = dbapi_connection.cursor()
    try:
        # is_local=false: the setting applies to the whole session.
        cursor.execute(
            "SELECT set_config('search_path', %s, false)",
            (settings.SCHEMA_SEARCH_PATH,),
        )
    finally:
        cursor.close()
    # Commit so the setting is not undone by a later rollback of the
    # transaction the DBAPI opened implicitly.
    dbapi_connection.commit()
# Session factory bound to the shared engine.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def get_db():
    """Yield a database session and close it afterwards (FastAPI dependency).

    The session is closed when the generator is resumed or torn down after
    the ``yield``, whether or not the consumer raised.
    """
    with SessionLocal() as session:
        yield session

88
control-pipeline/main.py Normal file
View File

@@ -0,0 +1,88 @@
import logging
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from config import settings
from db.session import engine
def _resolve_log_level(name: str) -> int:
    """Translate a level name such as "debug" or "INFO" into a logging level.

    The lookup is case-insensitive. Anything that does not resolve to one
    of the integer level constants of the ``logging`` module falls back to
    ``logging.INFO``. A plain ``getattr(logging, name)`` would return the
    function ``logging.debug`` for "debug" and make ``basicConfig`` fail.
    """
    level = getattr(logging, str(name).upper(), None)
    return level if isinstance(level, int) else logging.INFO


# Root logging configuration for the whole service.
logging.basicConfig(
    level=_resolve_log_level(settings.LOG_LEVEL),
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger("control-pipeline")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: check DB connectivity on startup, log shutdown.

    A failed database check is logged but does not stop the service from
    starting. Only PostgreSQL is checked here; Qdrant is not contacted.

    Args:
        app: The FastAPI application (unused; required by the lifespan
            protocol).
    """
    # Imported locally so the block does not depend on a module-level import.
    from sqlalchemy import text

    logger.info("Control-Pipeline starting up ...")
    # Verify database connection
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        logger.info("Database connection OK")
    except Exception as exc:
        logger.error("Database connection failed: %s", exc)
    yield
    logger.info("Control-Pipeline shutting down ...")
# ASGI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="BreakPilot Control Pipeline",
    description="Control generation, decomposition, and deduplication pipeline for the BreakPilot compliance platform.",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS
# NOTE(review): wildcard origins combined with allow_credentials=True lets
# any website issue credentialed requests against this API. Presumably the
# service is only reachable internally — confirm, or restrict allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Routers
# Imported after the app is created rather than at the top of the file,
# hence the E402 suppression.
from api import router as api_router  # noqa: E402
app.include_router(api_router)
# Health
@app.get("/health")
async def health():
    """Liveness probe that also reports database reachability.

    The handler always returns a JSON body. ``status`` is "healthy" when a
    ``SELECT 1`` succeeds and "degraded" otherwise. A failed check is
    logged instead of being silently discarded.

    Returns:
        Dict with ``status``, ``service``, ``version`` and a
        ``dependencies`` map containing the PostgreSQL state.
    """
    # Imported locally so the block does not depend on a module-level import.
    from sqlalchemy import text

    db_ok = False
    try:
        # NOTE(review): this is a blocking call inside an async handler;
        # it occupies the event loop for the duration of the round trip.
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        db_ok = True
    except Exception as exc:
        # Best effort: the probe must still answer, but leave a trace.
        logger.warning("Health check: database unavailable: %s", exc)
    status = "healthy" if db_ok else "degraded"
    return {
        "status": status,
        "service": "control-pipeline",
        "version": "1.0.0",
        "dependencies": {
            "postgres": "ok" if db_ok else "unavailable",
        },
    }
# Script entry point: serve the app with uvicorn when run as `python main.py`.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",  # listen on all interfaces
        port=settings.PORT,
        reload=False,
        # NOTE(review): fixed at "info"; settings.LOG_LEVEL is only used for
        # logging.basicConfig — confirm the mismatch is intended.
        log_level="info",
    )

View File

@@ -0,0 +1,22 @@
# Web Framework
fastapi>=0.123.0
uvicorn[standard]>=0.27.0
# Database
SQLAlchemy>=2.0.36
psycopg2-binary>=2.9.10
# HTTP Client
httpx>=0.28.0
# Validation
pydantic>=2.5.0
# AI - Anthropic Claude
anthropic>=0.75.0
# Vector DB (dedup)
qdrant-client>=1.7.0
# Auth
python-jose[cryptography]>=3.3.0

View File

View File

@@ -0,0 +1,187 @@
"""
Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls.
Two-stage search:
Stage A: RAG-internal search for open-source chunks matching the control topic
Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed)
Only open-source references (Rule 1+2) are accepted as anchors.
"""
import logging
from dataclasses import dataclass
from typing import List, Optional
import httpx
from .rag_client import ComplianceRAGClient, get_rag_client
from .control_generator import (
GeneratedControl,
REGULATION_LICENSE_MAP,
_RULE2_PREFIXES,
_RULE3_PREFIXES,
_classify_regulation,
)
logger = logging.getLogger(__name__)
# License rules that are safe to reference as open anchors (Rule 1+2).
# Compared against the "rule" entry of the dict returned by
# _classify_regulation for a regulation code.
_OPEN_SOURCE_RULES = {1, 2}
@dataclass
class OpenAnchor:
    """Reference to an open source that backs a generated control."""

    framework: str  # display name of the source, e.g. "OWASP" or "NIST"
    ref: str  # article/category of a RAG hit, or a snippet of a web result
    url: str  # link to the reference; empty when no URL could be determined
class AnchorFinder:
    """Finds open-source references to anchor generated controls.

    Stage A queries the internal RAG index; stage B falls back to the
    DuckDuckGo Instant Answer API when stage A yields too few anchors.
    """

    # Registrable domains of accepted sources and the framework name each
    # maps to. A URL matches when its host is the domain or a subdomain.
    _FRAMEWORK_DOMAINS = (
        ("owasp.org", "OWASP"),
        ("nist.gov", "NIST"),
        ("enisa.europa.eu", "ENISA"),
        ("cisa.gov", "CISA"),
        ("eur-lex.europa.eu", "EU Law"),
    )

    def __init__(self, rag_client: Optional["ComplianceRAGClient"] = None):
        """Use ``rag_client`` if given, otherwise the shared RAG client."""
        self.rag = rag_client or get_rag_client()

    async def find_anchors(
        self,
        control: "GeneratedControl",
        skip_web: bool = False,
        min_anchors: int = 2,
    ) -> List["OpenAnchor"]:
        """Find open-source anchors for a control.

        Args:
            control: Control to anchor; its title and tags drive the search.
            skip_web: When True, never run the web search stage.
            min_anchors: The web search only runs when the RAG stage found
                fewer anchors than this.

        Returns:
            RAG anchors first, followed by web anchors that are not
            duplicates (by framework + ref) of anything already collected.
        """
        # Stage A: RAG-internal search
        anchors = await self._search_rag_for_open_anchors(control)

        # Stage B: Web search if not enough anchors
        if len(anchors) < min_anchors and not skip_web:
            seen_keys = {(a.framework, a.ref) for a in anchors}
            for web_anchor in await self._search_web(control):
                key = (web_anchor.framework, web_anchor.ref)
                if key in seen_keys:
                    continue
                # Record the key so web results are also deduplicated
                # against each other, not only against the RAG anchors.
                seen_keys.add(key)
                anchors.append(web_anchor)
        return anchors

    async def _search_rag_for_open_anchors(
        self, control: "GeneratedControl"
    ) -> List["OpenAnchor"]:
        """Search RAG for chunks from open sources matching the control topic.

        Returns at most five anchors, one per regulation/reference pair.
        """
        # Build search query from control title + first 3 tags
        tags_str = " ".join(control.tags[:3]) if control.tags else ""
        query = f"{control.title} {tags_str}".strip()
        results = await self.rag.search_with_rerank(
            query=query,
            collection="bp_compliance_ce",
            top_k=15,
        )

        anchors: List[OpenAnchor] = []
        seen: set[str] = set()
        for r in results:
            if not r.regulation_code:
                continue
            # Only accept open-source references
            license_info = _classify_regulation(r.regulation_code)
            if license_info.get("rule") not in _OPEN_SOURCE_RULES:
                continue
            # Build reference key for dedup
            ref = r.article or r.category or ""
            key = f"{r.regulation_code}:{ref}"
            if key in seen:
                continue
            seen.add(key)

            framework_name = license_info.get(
                "name",
                r.regulation_name or r.regulation_short or r.regulation_code,
            )
            url = r.source_url or self._build_reference_url(r.regulation_code, ref)
            anchors.append(OpenAnchor(
                framework=framework_name,
                ref=ref,
                url=url,
            ))
            if len(anchors) >= 5:
                break
        return anchors

    async def _search_web(self, control: "GeneratedControl") -> List["OpenAnchor"]:
        """Search the DuckDuckGo Instant Answer API for open references.

        Best effort: any failure is logged as a warning and produces an
        empty list. At most three anchors are returned, and only for URLs
        whose host belongs to a known open-source domain.
        """
        keywords = f"{control.title} security control OWASP NIST"
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(
                    "https://api.duckduckgo.com/",
                    params={
                        "q": keywords,
                        "format": "json",
                        "no_html": "1",
                        "skip_disambig": "1",
                    },
                )
            if resp.status_code != 200:
                return []
            data = resp.json()

            anchors: List[OpenAnchor] = []
            # Parse RelatedTopics
            for topic in data.get("RelatedTopics", [])[:10]:
                url = topic.get("FirstURL", "")
                text = topic.get("Text", "")
                if not url:
                    continue
                # Only accept known open-source domains
                framework = self._identify_framework_from_url(url)
                if framework:
                    anchors.append(OpenAnchor(
                        framework=framework,
                        ref=text[:100] if text else url,
                        url=url,
                    ))
                if len(anchors) >= 3:
                    break
            return anchors
        except Exception as e:
            logger.warning("Web anchor search failed: %s", e)
            return []

    @staticmethod
    def _identify_framework_from_url(url: str) -> Optional[str]:
        """Return the framework a URL belongs to, or None if it is unknown.

        The decision is based on the URL's host name. A known domain that
        merely appears in the path, the query string, or as a prefix of an
        unrelated host ("owasp.org.evil.example") does not count.
        """
        from urllib.parse import urlparse

        try:
            # Scheme-less input ("owasp.org/x") needs a "//" prefix so the
            # first component is parsed as the host.
            parsed = urlparse(url if "://" in url else f"//{url}")
            host = (parsed.hostname or "").lower()
        except ValueError:
            return None
        if not host:
            return None

        for domain, framework in AnchorFinder._FRAMEWORK_DOMAINS:
            if host == domain or host.endswith("." + domain):
                return framework
        return None

    @staticmethod
    def _build_reference_url(regulation_code: str, ref: str) -> str:
        """Build a generic landing-page URL for a known regulation code.

        ``ref`` is accepted for interface symmetry but not used. Unknown
        codes yield an empty string.
        """
        code = regulation_code.lower()
        if code.startswith("owasp"):
            return "https://owasp.org/www-project-application-security-verification-standard/"
        if code.startswith("nist"):
            return "https://csrc.nist.gov/publications"
        if code.startswith("enisa"):
            return "https://www.enisa.europa.eu/publications"
        if code.startswith("eu_"):
            return "https://eur-lex.europa.eu/"
        if code == "cisa_secure_by_design":
            return "https://www.cisa.gov/securebydesign"
        return ""

View File

@@ -0,0 +1,618 @@
"""Batch Dedup Runner — Orchestrates deduplication of ~85k atomare Controls.
Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls via:
Phase 1: Intra-Group Dedup — same merge_group_hint → pick best, link rest
(85k → ~52k, mostly title-identical short-circuit, no embeddings)
Phase 2: Cross-Group Dedup — embed masters, search Qdrant for similar
masters with different hints (52k → ~18-25k)
All Pass 0b controls have pattern_id=NULL. The primary grouping key is
merge_group_hint (format: "action_type:norm_obj:trigger_key"), which
encodes the normalized action, object, and trigger.
Usage:
runner = BatchDedupRunner(db)
stats = await runner.run(dry_run=True) # preview
stats = await runner.run(dry_run=False) # execute
stats = await runner.run(hint_filter="implement:multi_factor_auth:none")
"""
import json
import logging
import time
from collections import defaultdict
from sqlalchemy import text
from services.control_dedup import (
canonicalize_text,
ensure_qdrant_collection,
get_embedding,
normalize_action,
normalize_object,
qdrant_search_cross_regulation,
qdrant_upsert,
LINK_THRESHOLD,
REVIEW_THRESHOLD,
)
logger = logging.getLogger(__name__)
# Name of the Qdrant collection used for dedup; the default `collection`
# argument of BatchDedupRunner.
DEDUP_COLLECTION = "atomic_controls_dedup"
# ── Quality Score ────────────────────────────────────────────────────────
def _json_list_len(value) -> int:
    """Return the number of items of a JSON-array field, or 0.

    ``value`` may already be a list/tuple or a JSON-encoded string.
    Missing values, invalid JSON and anything that does not decode to a
    list count as zero items.
    """
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return 0
    return len(value) if isinstance(value, (list, tuple)) else 0


def quality_score(control: dict) -> float:
    """Score a control by richness of requirements, tests, evidence, and objective.

    Higher score = better candidate for master control.

    Weights: 2.0 per requirement, 1.5 per test procedure entry, 1.0 per
    evidence entry, plus one point per 200 characters of objective text,
    capped at 3.0.

    Args:
        control: Control row as a dict. The list fields may be Python
            lists or JSON strings; missing fields count as empty.

    Returns:
        The non-negative score.
    """
    score = 0.0
    score += _json_list_len(control.get("requirements")) * 2.0
    score += _json_list_len(control.get("test_procedure")) * 1.5
    score += _json_list_len(control.get("evidence")) * 1.0

    objective = control.get("objective") or ""
    score += min(len(objective) / 200, 3.0)
    return score
# ── Batch Dedup Runner ───────────────────────────────────────────────────
class BatchDedupRunner:
"""Batch dedup orchestrator for existing Pass 0b atomic controls."""
def __init__(self, db, collection: str = DEDUP_COLLECTION):
    """Store the DB handle and target collection and zero all counters.

    Args:
        db: Database session used for all reads and writes.
        collection: Name of the Qdrant collection used for dedup.
    """
    self.db = db
    self.collection = collection

    # All counters start at zero; the key order is the reporting order.
    counter_names = (
        "total_controls",
        "unique_hints",
        "phase1_groups_processed",
        "masters",
        "linked",
        "review",
        "new_controls",
        "parent_links_transferred",
        "cross_group_linked",
        "cross_group_review",
        "errors",
        "skipped_title_identical",
    )
    self.stats = dict.fromkeys(counter_names, 0)

    # Progress tracking for the phase currently running.
    self._progress_phase = ""
    self._progress_count = 0
    self._progress_total = 0
async def run(
    self,
    dry_run: bool = False,
    hint_filter: str = None,
) -> dict:
    """Run the full batch dedup pipeline.

    Phase 1 deduplicates within each merge_group_hint group. Phase 2
    (skipped in a dry run) looks for duplicates across groups. An error
    in one Phase 1 group is logged and counted, the DB session is rolled
    back, and processing continues with the next group.

    Args:
        dry_run: If True, compute stats but don't modify DB/Qdrant.
        hint_filter: If set, only process groups matching this hint
            prefix. The filter applies to Phase 1 only; Phase 2 always
            looks at all remaining Pass 0b controls.

    Returns:
        Stats dict with counts, plus ``elapsed_seconds``.
    """
    start = time.monotonic()
    logger.info("BatchDedup starting (dry_run=%s, hint_filter=%s)",
                dry_run, hint_filter)
    if not dry_run:
        await ensure_qdrant_collection(collection=self.collection)

    # Phase 1: Intra-group dedup (same merge_group_hint)
    self._progress_phase = "phase1"
    groups = self._load_merge_groups(hint_filter)
    self._progress_total = self.stats["total_controls"]
    for hint, controls in groups:
        try:
            await self._process_hint_group(hint, controls, dry_run)
            self.stats["phase1_groups_processed"] += 1
        except Exception as e:
            logger.error("BatchDedup Phase 1 error on hint %s: %s", hint, e)
            self.stats["errors"] += 1
            try:
                self.db.rollback()
            except Exception as rollback_exc:
                # Keep going, but do not hide that the session may be in
                # a broken state for the following groups.
                logger.warning(
                    "BatchDedup rollback after Phase 1 error failed: %s",
                    rollback_exc,
                )
    logger.info(
        "BatchDedup Phase 1 done: %d masters, %d linked, %d review",
        self.stats["masters"], self.stats["linked"], self.stats["review"],
    )

    # Phase 2: Cross-group dedup via embeddings
    if not dry_run:
        self._progress_phase = "phase2"
        await self._run_cross_group_pass()

    elapsed = time.monotonic() - start
    self.stats["elapsed_seconds"] = round(elapsed, 1)
    logger.info("BatchDedup completed in %.1fs: %s", elapsed, self.stats)
    return self.stats
def _load_merge_groups(self, hint_filter: str = None) -> list:
    """Load all Pass 0b controls grouped by merge_group_hint, largest first.

    Deprecated and duplicate controls are excluded. Controls without a
    hint are collected under the empty-string key. Also updates the
    ``total_controls`` and ``unique_hints`` stats.

    Args:
        hint_filter: Optional literal prefix a hint must start with.
            LIKE wildcards in it (``%`` and ``_``, the latter being
            common in hints) are escaped, so they match themselves.

    Returns:
        List of ``(hint, controls)`` tuples sorted by group size,
        descending; each control is a dict of the selected columns.
    """
    conditions = [
        "decomposition_method = 'pass0b'",
        "release_state != 'deprecated'",
        "release_state != 'duplicate'",
    ]
    params = {}
    if hint_filter:
        # "!" is the escape character; escape it first so the escapes
        # added for "%" and "_" are not escaped again.
        escaped = (
            hint_filter.replace("!", "!!")
            .replace("%", "!%")
            .replace("_", "!_")
        )
        conditions.append(
            "generation_metadata->>'merge_group_hint' LIKE :hf ESCAPE '!'"
        )
        params["hf"] = f"{escaped}%"
    where = " AND ".join(conditions)

    rows = self.db.execute(text(f"""
        SELECT id::text, control_id, title, objective,
               pattern_id, requirements::text, test_procedure::text,
               evidence::text, release_state,
               generation_metadata->>'merge_group_hint' as merge_group_hint,
               generation_metadata->>'action_object_class' as action_object_class
        FROM canonical_controls
        WHERE {where}
        ORDER BY control_id
    """), params).fetchall()

    by_hint = defaultdict(list)
    for r in rows:
        by_hint[r[9] or ""].append({
            "uuid": r[0],
            "control_id": r[1],
            "title": r[2],
            "objective": r[3],
            "pattern_id": r[4],
            "requirements": r[5],
            "test_procedure": r[6],
            "evidence": r[7],
            "release_state": r[8],
            "merge_group_hint": r[9] or "",
            "action_object_class": r[10] or "",
        })

    self.stats["total_controls"] = len(rows)
    self.stats["unique_hints"] = len(by_hint)
    sorted_groups = sorted(by_hint.items(), key=lambda x: len(x[1]), reverse=True)
    logger.info("BatchDedup loaded %d controls in %d hint groups",
                len(rows), len(sorted_groups))
    return sorted_groups
def _sub_group_by_merge_hint(self, controls: list) -> dict:
"""Group controls by merge_group_hint composite key."""
groups = defaultdict(list)
for c in controls:
hint = c["merge_group_hint"]
if hint:
groups[hint].append(c)
else:
groups[f"__no_hint_{c['uuid']}"].append(c)
return dict(groups)
async def _process_hint_group(
    self,
    hint: str,
    controls: list,
    dry_run: bool,
):
    """Deduplicate one group of controls that share a merge_group_hint.

    The control with the highest quality score becomes the master and is
    indexed. Candidates with the same title (ignoring case and outer
    whitespace) are linked to it directly; the others are verified via
    embedding. A group with fewer than two controls just yields a master.

    Args:
        hint: The shared merge_group_hint (used for progress logging).
        controls: Control dicts of the group.
        dry_run: When True only counters change; nothing is written.
    """
    if len(controls) >= 2:
        # Best candidate first; it becomes the master of the group.
        ranked = sorted(controls, key=quality_score, reverse=True)
        master, candidates = ranked[0], ranked[1:]
        self.stats["masters"] += 1
        if not dry_run:
            await self._embed_and_index(master)

        for candidate in candidates:
            if candidate["title"].strip().lower() != master["title"].strip().lower():
                # Different title within the same hint: still a likely
                # duplicate, but verify it via embedding.
                await self._check_and_link_within_group(master, candidate, dry_run)
                continue
            # Identical title: link directly, no embedding needed.
            self.stats["linked"] += 1
            self.stats["skipped_title_identical"] += 1
            if not dry_run:
                await self._mark_duplicate(master, candidate, confidence=1.0)
    else:
        # Singleton: always a master.
        self.stats["masters"] += 1
        if not dry_run:
            await self._embed_and_index(controls[0])

    # NOTE(review): progress advances by one per group, while the Phase 1
    # total set by run() counts controls — confirm which unit is intended.
    self._progress_count += 1
    self._log_progress(hint)
async def _check_and_link_within_group(
    self,
    master: dict,
    candidate: dict,
    dry_run: bool,
):
    """Check if candidate (same hint group) is duplicate of master via embedding.

    The candidate is embedded and compared with the dedup collection:

    - no embedding or no usable hit: link to ``master`` (confidence 0.90),
      because the shared hint already implies the same action and object;
    - best score above LINK_THRESHOLD: link to the matched control;
    - best score above REVIEW_THRESHOLD: write a review entry;
    - otherwise: index the candidate as a new master.

    Hits that point back at the candidate itself (possible when it was
    indexed by an earlier run) or that carry no control uuid are ignored,
    so a control is never marked as a duplicate of itself.

    Args:
        master: Master control of the hint group.
        candidate: Control to check.
        dry_run: When True only counters change; nothing is written.
    """
    parts = candidate["merge_group_hint"].split(":", 2)
    action = parts[0] if len(parts) > 0 else ""
    obj = parts[1] if len(parts) > 1 else ""
    canonical = canonicalize_text(action, obj, candidate["title"])
    embedding = await get_embedding(canonical)
    if not embedding:
        # Can't embed → link anyway (same hint = same action+object)
        self.stats["linked"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=0.90)
        return

    # Search the dedup collection (unfiltered — pattern_id is NULL)
    results = await qdrant_search_cross_regulation(
        embedding, top_k=3, collection=self.collection,
    )
    # Keep result order, drop self-matches and hits without a target uuid.
    usable = [
        hit for hit in (results or [])
        if (hit.get("payload") or {}).get("control_uuid")
        not in ("", None, candidate["uuid"])
    ]
    if not usable:
        # No usable Qdrant match (master might not be indexed yet) → link to master
        self.stats["linked"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=0.90)
        return

    best = usable[0]
    best_score = best.get("score", 0.0)
    best_payload = best.get("payload") or {}
    best_uuid = best_payload.get("control_uuid", "")
    if best_score > LINK_THRESHOLD:
        self.stats["linked"] += 1
        if not dry_run:
            await self._mark_duplicate_to(best_uuid, candidate, confidence=best_score)
    elif best_score > REVIEW_THRESHOLD:
        self.stats["review"] += 1
        if not dry_run:
            self._write_review(candidate, best_payload, best_score)
    else:
        # Very different despite same hint → new master
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._index_with_embedding(candidate, embedding)
async def _run_cross_group_pass(self):
    """Phase 2: Find cross-group duplicates among surviving masters.

    After Phase 1 many masters still have similar semantics despite
    different merge_group_hints (e.g. different German spellings). Each
    remaining control is embedded and compared against the dedup
    collection. The first match above LINK_THRESHOLD makes the current
    control a duplicate of the match; a match above REVIEW_THRESHOLD is
    written to the review queue instead.

    Matches are skipped when they point at the control itself, carry no
    control uuid, or were already merged away earlier in this pass. The
    last rule prevents two controls from ending up as duplicates of each
    other, which would leave no surviving master.

    NOTE(review): a control that already serves as master for others can
    still be merged into a third control later in the pass, which yields
    a chain (A -> B -> C) — confirm that consumers follow
    merged_into_uuid transitively.
    """
    logger.info("BatchDedup Phase 2: Cross-group pass starting...")
    rows = self.db.execute(text("""
        SELECT id::text, control_id, title,
               generation_metadata->>'merge_group_hint' as merge_group_hint
        FROM canonical_controls
        WHERE decomposition_method = 'pass0b'
          AND release_state != 'duplicate'
          AND release_state != 'deprecated'
        ORDER BY control_id
    """)).fetchall()
    self._progress_total = len(rows)
    self._progress_count = 0
    logger.info("BatchDedup Cross-group: %d masters to check", len(rows))

    cross_linked = 0
    cross_review = 0
    # UUIDs marked as duplicate during this pass; never valid as masters.
    merged_away: set[str] = set()

    for i, r in enumerate(rows):
        uuid = r[0]
        hint = r[3] or ""
        parts = hint.split(":", 2)
        action = parts[0] if len(parts) > 0 else ""
        obj = parts[1] if len(parts) > 1 else ""
        canonical = canonicalize_text(action, obj, r[2])
        embedding = await get_embedding(canonical)

        results = None
        if embedding:
            results = await qdrant_search_cross_regulation(
                embedding, top_k=5, collection=self.collection,
            )

        for match in results or []:
            match_score = match.get("score", 0.0)
            match_payload = match.get("payload") or {}
            match_uuid = match_payload.get("control_uuid", "")
            # Skip self-matches, unidentifiable hits and controls that
            # are no longer masters.
            if not match_uuid or match_uuid == uuid or match_uuid in merged_away:
                continue

            if match_score > LINK_THRESHOLD:
                # The current control becomes a duplicate of the match.
                try:
                    self.db.execute(text("""
                        UPDATE canonical_controls
                        SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
                        WHERE id = CAST(:dup AS uuid)
                          AND release_state != 'duplicate'
                    """), {"master": match_uuid, "dup": uuid})
                    self.db.execute(text("""
                        INSERT INTO control_parent_links
                            (control_uuid, parent_control_uuid, link_type, confidence)
                        VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
                        ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
                    """), {"cu": match_uuid, "pu": uuid, "conf": match_score})
                    # Transfer parent links
                    transferred = self._transfer_parent_links(match_uuid, uuid)
                    self.stats["parent_links_transferred"] += transferred
                    self.db.commit()
                    cross_linked += 1
                    merged_away.add(uuid)
                except Exception as e:
                    logger.error("BatchDedup cross-group link error %s -> %s: %s",
                                 uuid, match_uuid, e)
                    self.db.rollback()
                    self.stats["errors"] += 1
                break  # Only one cross-link per control
            elif match_score > REVIEW_THRESHOLD:
                self._write_review(
                    {"control_id": r[1], "title": r[2], "objective": "",
                     "merge_group_hint": hint, "pattern_id": None},
                    match_payload, match_score,
                )
                cross_review += 1
                break

        # Updated for every row, including those without embedding or hits.
        self._progress_count = i + 1
        if (i + 1) % 500 == 0:
            logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review",
                        i + 1, len(rows), cross_linked, cross_review)

    self.stats["cross_group_linked"] = cross_linked
    self.stats["cross_group_review"] = cross_review
    logger.info("BatchDedup Cross-group complete: %d linked, %d review",
                cross_linked, cross_review)
# ── Qdrant Helpers ───────────────────────────────────────────────────
async def _embed_and_index(self, control: dict):
    """Compute embedding and index a control in the dedup Qdrant collection.

    The canonical text is built from the action and object parts of the
    control's merge_group_hint plus its title. When no embedding can be
    obtained the control is silently left unindexed. Payload construction
    and the upsert are delegated to ``_index_with_embedding`` so both
    indexing paths write identical payloads.

    Args:
        control: Control dict with ``uuid``, ``control_id``, ``title``
            and ``merge_group_hint``.
    """
    parts = control["merge_group_hint"].split(":", 2)
    action = parts[0]
    obj = parts[1] if len(parts) > 1 else ""
    canonical = canonicalize_text(action, obj, control["title"])
    embedding = await get_embedding(canonical)
    if not embedding:
        return
    await self._index_with_embedding(control, embedding)
async def _index_with_embedding(self, control: dict, embedding: list):
    """Upsert a control into the dedup collection using a given embedding.

    Args:
        control: Control dict with ``uuid``, ``control_id``, ``title``
            and ``merge_group_hint``.
        embedding: Pre-computed embedding vector for the control.
    """
    hint = control["merge_group_hint"]
    # The hint has the form "action:object:trigger"; only the first two
    # segments are needed here.
    action, _, remainder = hint.partition(":")
    obj = remainder.partition(":")[0]

    normalized_action = normalize_action(action)
    normalized_object = normalize_object(obj)
    canonical_text = canonicalize_text(action, obj, control["title"])

    payload = {
        "control_uuid": control["uuid"],
        "control_id": control["control_id"],
        "title": control["title"],
        "pattern_id": control.get("pattern_id"),
        "action_normalized": normalized_action,
        "object_normalized": normalized_object,
        "canonical_text": canonical_text,
        "merge_group_hint": hint,
    }
    await qdrant_upsert(
        point_id=control["uuid"],
        embedding=embedding,
        payload=payload,
        collection=self.collection,
    )
# ── DB Write Helpers ─────────────────────────────────────────────────
async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float):
    """Mark ``candidate`` as a duplicate of ``master`` and commit.

    In one transaction this sets the candidate's ``release_state`` to
    ``'duplicate'`` with ``merged_into_uuid`` pointing at the master, inserts
    a ``dedup_merge`` parent link (master -> candidate), and copies the
    candidate's parent links to the master via ``_transfer_parent_links``.

    Args:
        master: Surviving control; only ``master["uuid"]`` is read.
        candidate: Control being merged away; only ``candidate["uuid"]`` is read.
        confidence: Confidence stored on the ``dedup_merge`` link.

    Raises:
        Exception: Any error is logged, the session is rolled back, and the
            error is re-raised.
    """
    try:
        self.db.execute(text("""
UPDATE canonical_controls
SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
WHERE id = CAST(:cand AS uuid)
"""), {"master": master["uuid"], "cand": candidate["uuid"]})
        self.db.execute(text("""
INSERT INTO control_parent_links
(control_uuid, parent_control_uuid, link_type, confidence)
VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
"""), {"master": master["uuid"], "cand_parent": candidate["uuid"], "conf": confidence})
        transferred = self._transfer_parent_links(master["uuid"], candidate["uuid"])
        self.stats["parent_links_transferred"] += transferred
        self.db.commit()
    except Exception as e:
        # Separate the two UUIDs so the log line stays readable.
        logger.error("BatchDedup _mark_duplicate error %s -> %s: %s",
                     candidate["uuid"], master["uuid"], e)
        self.db.rollback()
        raise
async def _mark_duplicate_to(self, master_uuid: str, candidate: dict, confidence: float):
    """Mark ``candidate`` as a duplicate of the master given by UUID and commit.

    Same writes as ``_mark_duplicate``, for the case where only the master's
    UUID is known: the candidate row is set to ``'duplicate'``, a
    ``dedup_merge`` parent link is inserted, and the candidate's parent links
    are copied to the master.

    Args:
        master_uuid: UUID of the surviving control.
        candidate: Control being merged away; only ``candidate["uuid"]`` is read.
        confidence: Confidence stored on the ``dedup_merge`` link.

    Raises:
        Exception: Any error is logged, the session is rolled back, and the
            error is re-raised.
    """
    try:
        self.db.execute(text("""
UPDATE canonical_controls
SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
WHERE id = CAST(:cand AS uuid)
"""), {"master": master_uuid, "cand": candidate["uuid"]})
        self.db.execute(text("""
INSERT INTO control_parent_links
(control_uuid, parent_control_uuid, link_type, confidence)
VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
"""), {"master": master_uuid, "cand_parent": candidate["uuid"], "conf": confidence})
        transferred = self._transfer_parent_links(master_uuid, candidate["uuid"])
        self.stats["parent_links_transferred"] += transferred
        self.db.commit()
    except Exception as e:
        # Separate the two UUIDs so the log line stays readable.
        logger.error("BatchDedup _mark_duplicate_to error %s -> %s: %s",
                     candidate["uuid"], master_uuid, e)
        self.db.rollback()
        raise
def _transfer_parent_links(self, master_uuid: str, duplicate_uuid: str) -> int:
    """Copy the duplicate's ``decomposition`` parent links onto the master.

    The duplicate's own rows are left in place; for each of them an
    equivalent row is inserted for ``master_uuid`` (conflicts are ignored by
    the ``ON CONFLICT ... DO NOTHING`` clause). A link whose parent is the
    master itself is skipped. Does not commit.

    Args:
        master_uuid: UUID of the control that receives the links.
        duplicate_uuid: UUID of the control whose links are read.

    Returns:
        Number of insert statements issued. Rows skipped by the conflict
        clause are still counted.
    """
    rows = self.db.execute(text("""
SELECT parent_control_uuid::text, link_type, confidence,
source_regulation, source_article, obligation_candidate_id::text
FROM control_parent_links
WHERE control_uuid = CAST(:dup AS uuid)
AND link_type = 'decomposition'
"""), {"dup": duplicate_uuid}).fetchall()
    transferred = 0
    for r in rows:
        parent_uuid = r[0]
        if parent_uuid == master_uuid:
            continue
        self.db.execute(text("""
INSERT INTO control_parent_links
(control_uuid, parent_control_uuid, link_type, confidence,
source_regulation, source_article, obligation_candidate_id)
VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), :lt, :conf,
:sr, :sa, CAST(:oci AS uuid))
ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
"""), {
            "cu": master_uuid,
            "pu": parent_uuid,
            "lt": r[1],
            # Only a missing confidence defaults to 1.0; a stored 0 must stay 0.
            "conf": float(r[2]) if r[2] is not None else 1.0,
            "sr": r[3],
            "sa": r[4],
            "oci": r[5],
        })
        transferred += 1
    return transferred
def _write_review(self, candidate: dict, matched_payload: dict, score: float):
    """Insert a ``control_dedup_reviews`` row for a borderline match and commit.

    Args:
        candidate: Control under review; ``control_id`` and ``title`` are
            required, ``objective``, ``merge_group_hint`` and ``pattern_id``
            are optional.
        matched_payload: Payload of the matched control; ``control_uuid`` and
            ``control_id`` are read with ``.get``.
        score: Similarity score stored with the review.

    Raises:
        Exception: Any error is logged, the session is rolled back, and the
            error is re-raised.
    """
    try:
        details = json.dumps({
            "merge_group_hint": candidate.get("merge_group_hint", ""),
            "pattern_id": candidate.get("pattern_id"),
        })
        params = {
            "ccid": candidate["control_id"],
            "ct": candidate["title"],
            "co": candidate.get("objective", ""),
            "mcu": matched_payload.get("control_uuid"),
            "mci": matched_payload.get("control_id"),
            "ss": score,
            "dd": details,
        }
        statement = text("""
INSERT INTO control_dedup_reviews
(candidate_control_id, candidate_title, candidate_objective,
matched_control_uuid, matched_control_id,
similarity_score, dedup_stage, dedup_details)
VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci,
:ss, 'batch_dedup', CAST(:dd AS jsonb))
""")
        self.db.execute(statement, params)
        self.db.commit()
    except Exception as e:
        logger.error("BatchDedup _write_review error: %s", e)
        self.db.rollback()
        raise
# ── Progress ─────────────────────────────────────────────────────────
def _log_progress(self, hint: str):
    """Emit a progress log line on every 500th processed control.

    Args:
        hint: Accepted for the callers' benefit; not used by this method.
    """
    processed = self._progress_count
    if processed <= 0 or processed % 500 != 0:
        return
    stats = self.stats
    logger.info(
        "BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d",
        self._progress_phase, self._progress_count, self._progress_total,
        stats["masters"], stats["linked"], stats["review"],
    )
def get_status(self) -> dict:
    """Return a snapshot of the current progress (for the status endpoint).

    Returns:
        Dict with ``phase``, ``progress`` and ``total``, followed by every
        entry of ``self.stats``. A stats key with one of those three names
        overrides the progress value.
    """
    status = {
        "phase": self._progress_phase,
        "progress": self._progress_count,
        "total": self._progress_total,
    }
    status.update(self.stats)
    return status

View File

@@ -0,0 +1,438 @@
"""
Citation Backfill Service — enrich existing controls with article/paragraph provenance.
3-tier matching strategy:
Tier 1 — Hash match: sha256(source_original_text) → RAG chunk lookup
Tier 2 — Regex parse: split concatenated "DSGVO Art. 35" → regulation + article
Tier 3 — Ollama LLM: ask local LLM to identify article/paragraph from text
"""
import hashlib
import json
import logging
import os
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
import httpx
from sqlalchemy import text
from sqlalchemy.orm import Session
from .rag_client import ComplianceRAGClient, RAGSearchResult
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Ollama settings; each is read from the environment with a built-in default.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
# Used as the httpx client timeout in _llm_ollama().
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))
# Collections scanned when no control maps to a specific collection, or when
# a control carries a regulation code with an unknown prefix.
ALL_COLLECTIONS = [
    "bp_compliance_ce",
    "bp_compliance_gesetze",
    "bp_compliance_datenschutz",
    "bp_dsfa_corpus",
    "bp_legal_templates",
]
# System prompt sent with every backfill LLM request (runtime text, German).
BACKFILL_SYSTEM_PROMPT = (
    "Du bist ein Rechtsexperte. Deine Aufgabe ist es, aus einem Gesetzestext "
    "den genauen Artikel und Absatz zu bestimmen. Antworte NUR mit validem JSON."
)
# Regex to split concatenated source like "DSGVO Art. 35" or "NIS2 Artikel 21 Abs. 2"
# Group 1 is the regulation name, group 2 everything from "Art."/"Artikel" on.
_SOURCE_ARTICLE_RE = re.compile(
    r"^(.+?)\s+(Art(?:ikel)?\.?\s*\d+.*)$", re.IGNORECASE
)
@dataclass
class MatchResult:
    """Article/paragraph provenance found for one control."""

    article: str  # article reference; may be "" when only a paragraph was found
    paragraph: str  # paragraph reference; "" when none was determined
    method: str  # tier that produced the match: "hash", "regex" or "llm"
@dataclass
class BackfillResult:
    """Counters and error messages collected over one backfill run."""

    total_controls: int = 0  # controls loaded for backfill
    matched_hash: int = 0  # matches found via the RAG hash index
    matched_regex: int = 0  # matches parsed from the concatenated source name
    matched_llm: int = 0  # matches returned by the LLM
    unmatched: int = 0  # controls for which no tier produced a match
    updated: int = 0  # controls for which an UPDATE was issued (non-dry runs)
    errors: list = field(default_factory=list)  # human-readable error messages
class CitationBackfill:
    """Backfill article/paragraph into existing control source_citations.

    Controls are matched in three tiers (RAG hash lookup, regex parsing of
    the concatenated source name, Ollama LLM). Matches are written back to
    ``canonical_controls`` unless the run is a dry run.
    """

    def __init__(self, db: Session, rag_client: ComplianceRAGClient):
        self.db = db
        self.rag = rag_client
        # sha256 hex digest of a RAG chunk's text -> that chunk.
        self._rag_index: dict[str, RAGSearchResult] = {}

    async def run(self, dry_run: bool = True, limit: int = 0) -> BackfillResult:
        """Match every control that lacks an article and optionally persist it.

        Args:
            dry_run: When true, matches are only counted and logged.
            limit: Maximum number of controls to load; ``0`` means no limit.

        Returns:
            Counters per match tier plus collected error messages. Errors on
            single controls are recorded and do not abort the run. When the
            final commit fails the session is rolled back and ``updated`` is
            reset to 0, because nothing was persisted.
        """
        result = BackfillResult()

        controls = self._load_controls_needing_backfill(limit)
        result.total_controls = len(controls)
        logger.info("Backfill: %d controls need article/paragraph enrichment", len(controls))
        if not controls:
            return result

        # The hash index is only worth building when a control has source text.
        needed_hashes: set[str] = {
            hashlib.sha256(src.encode()).hexdigest()
            for src in (ctrl.get("source_original_text") for ctrl in controls)
            if src
        }
        if needed_hashes:
            logger.info("Building targeted RAG hash index for %d source texts...", len(needed_hashes))
            await self._build_rag_index_targeted(controls)
            logger.info("RAG index built: %d chunks indexed, %d hashes needed", len(self._rag_index), len(needed_hashes))
        else:
            logger.info("No source_original_text found — skipping RAG index build")

        for i, ctrl in enumerate(controls):
            if i > 0 and i % 100 == 0:
                logger.info("Backfill progress: %d/%d processed", i, result.total_controls)
            try:
                match = await self._match_control(ctrl)
                if match:
                    if match.method == "hash":
                        result.matched_hash += 1
                    elif match.method == "regex":
                        result.matched_regex += 1
                    elif match.method == "llm":
                        result.matched_llm += 1
                    if not dry_run:
                        self._update_control(ctrl, match)
                        result.updated += 1
                    else:
                        logger.debug(
                            "DRY RUN: Would update %s with article=%s paragraph=%s (method=%s)",
                            ctrl["control_id"], match.article, match.paragraph, match.method,
                        )
                else:
                    result.unmatched += 1
            except Exception as e:
                error_msg = f"Error backfilling {ctrl.get('control_id', '?')}: {e}"
                logger.error(error_msg)
                result.errors.append(error_msg)

        if not dry_run:
            try:
                self.db.commit()
            except Exception as e:
                logger.error("Backfill commit failed: %s", e)
                result.errors.append(f"Commit failed: {e}")
                # All updates shared this one transaction, so none survived.
                result.updated = 0
                try:
                    self.db.rollback()
                except Exception as rollback_error:
                    logger.error("Backfill rollback failed: %s", rollback_error)

        logger.info(
            "Backfill complete: %d total, hash=%d regex=%d llm=%d unmatched=%d updated=%d",
            result.total_controls, result.matched_hash, result.matched_regex,
            result.matched_llm, result.unmatched, result.updated,
        )
        return result

    def _load_controls_needing_backfill(self, limit: int = 0) -> list[dict]:
        """Load controls whose source_citation exists but has no 'article'.

        Only rows with ``license_rule`` 1 or 2 are selected, ordered by
        ``control_id``.

        Args:
            limit: Maximum number of rows; values <= 0 disable the limit.

        Returns:
            One dict per row with ``id`` converted to ``str``. JSON columns
            that arrive as strings are decoded; undecodable ones become ``{}``.
        """
        query = """
SELECT id, control_id, source_citation, source_original_text,
generation_metadata, license_rule
FROM canonical_controls
WHERE license_rule IN (1, 2)
AND source_citation IS NOT NULL
AND (
source_citation->>'article' IS NULL
OR source_citation->>'article' = ''
)
ORDER BY control_id
"""
        params: dict = {}
        if limit > 0:
            # Bind the limit instead of formatting it into the SQL text.
            query += " LIMIT :limit"
            params["limit"] = int(limit)
        result = self.db.execute(text(query), params)
        cols = result.keys()
        controls = []
        for row in result:
            ctrl = dict(zip(cols, row))
            ctrl["id"] = str(ctrl["id"])
            # Parse JSON fields
            for jf in ("source_citation", "generation_metadata"):
                if isinstance(ctrl.get(jf), str):
                    try:
                        ctrl[jf] = json.loads(ctrl[jf])
                    except (json.JSONDecodeError, TypeError):
                        ctrl[jf] = {}
            controls.append(ctrl)
        return controls

    async def _build_rag_index_targeted(self, controls: list[dict]):
        """Fill ``self._rag_index`` by scrolling the relevant collections.

        The collections come from ``_map_regulations_to_collections``; when
        that yields nothing, every entry of ``ALL_COLLECTIONS`` is scrolled.
        Chunks whose stripped text is shorter than 50 characters are not
        indexed.
        """
        regulation_to_collection = self._map_regulations_to_collections(controls)
        collections_to_search = set(regulation_to_collection.values()) or set(ALL_COLLECTIONS)
        logger.info("Targeted index: searching %d collections: %s",
                    len(collections_to_search), ", ".join(collections_to_search))
        for collection in collections_to_search:
            offset = None
            page = 0
            seen_offsets: set[str] = set()
            while True:
                chunks, next_offset = await self.rag.scroll(
                    collection=collection, offset=offset, limit=200,
                )
                if not chunks:
                    break
                for chunk in chunks:
                    if chunk.text and len(chunk.text.strip()) >= 50:
                        h = hashlib.sha256(chunk.text.encode()).hexdigest()
                        self._rag_index[h] = chunk
                page += 1
                if page % 50 == 0:
                    logger.info("Indexing %s: page %d (%d chunks so far)",
                                collection, page, len(self._rag_index))
                if not next_offset:
                    break
                # A repeated offset would make the scroll loop forever.
                if next_offset in seen_offsets:
                    logger.warning("Scroll loop in %s at page %d — stopping", collection, page)
                    break
                seen_offsets.add(next_offset)
                offset = next_offset
            logger.info("Indexed collection %s: %d pages", collection, page)

    def _map_regulations_to_collections(self, controls: list[dict]) -> dict[str, str]:
        """Map each control's regulation code to a likely Qdrant collection.

        The code is read from ``generation_metadata["source_regulation"]``
        and matched by prefix. A code with no known prefix adds every entry
        of ``ALL_COLLECTIONS`` under synthetic ``_all_<collection>`` keys.
        Controls without a regulation code contribute nothing.
        """
        # Heuristic: regulation code prefix -> collection
        collection_map = {
            "eu_": "bp_compliance_gesetze",
            "dsgvo": "bp_compliance_datenschutz",
            "bdsg": "bp_compliance_gesetze",
            "ttdsg": "bp_compliance_gesetze",
            "nist_": "bp_compliance_ce",
            "owasp": "bp_compliance_ce",
            "bsi_": "bp_compliance_ce",
            "enisa": "bp_compliance_ce",
            "at_": "bp_compliance_recht",
            "fr_": "bp_compliance_recht",
            "es_": "bp_compliance_recht",
        }
        result: dict[str, str] = {}
        for ctrl in controls:
            meta = ctrl.get("generation_metadata") or {}
            reg = meta.get("source_regulation", "")
            if not reg:
                continue
            for prefix, coll in collection_map.items():
                if reg.startswith(prefix):
                    result[reg] = coll
                    break
            else:
                # Unknown regulation — search all
                for coll in ALL_COLLECTIONS:
                    result[f"_all_{coll}"] = coll
        return result

    async def _match_control(self, ctrl: dict) -> Optional[MatchResult]:
        """Run the 3-tier matching (hash, then regex, then LLM) for one control.

        Returns:
            The first match found, or ``None``. The LLM tier only runs when
            the control has ``source_original_text``.
        """
        # Tier 1: Hash match against RAG index
        source_text = ctrl.get("source_original_text")
        if source_text:
            h = hashlib.sha256(source_text.encode()).hexdigest()
            chunk = self._rag_index.get(h)
            if chunk and (chunk.article or chunk.paragraph):
                return MatchResult(
                    article=chunk.article or "",
                    paragraph=chunk.paragraph or "",
                    method="hash",
                )
        # Tier 2: Regex parse concatenated source
        citation = ctrl.get("source_citation") or {}
        source_str = citation.get("source", "")
        parsed = _parse_concatenated_source(source_str)
        if parsed and parsed["article"]:
            return MatchResult(
                article=parsed["article"],
                paragraph="",  # Regex can't extract paragraph from concatenated format
                method="regex",
            )
        # Tier 3: Ollama LLM
        if source_text:
            return await self._llm_match(ctrl)
        return None

    async def _llm_match(self, ctrl: dict) -> Optional[MatchResult]:
        """Ask Ollama to identify article/paragraph from the source text.

        Returns:
            A match with ``method="llm"`` when the reply is a JSON object
            holding a non-empty article or paragraph, otherwise ``None``.
            Failures are logged as warnings and also yield ``None``.
        """
        citation = ctrl.get("source_citation") or {}
        regulation_name = citation.get("source", "")
        metadata = ctrl.get("generation_metadata") or {}
        regulation_code = metadata.get("source_regulation", "")
        source_text = ctrl.get("source_original_text", "")
        prompt = f"""Analysiere den folgenden Gesetzestext und bestimme den genauen Artikel und Absatz.
Gesetz: {regulation_name} (Code: {regulation_code})
Text:
---
{source_text[:2000]}
---
Antworte NUR mit JSON:
{{"article": "Art. XX", "paragraph": "Abs. Y"}}
Falls kein spezifischer Absatz erkennbar ist, setze paragraph auf "".
Falls kein Artikel erkennbar ist, setze article auf "".
Bei deutschen Gesetzen mit § verwende: "§ XX" statt "Art. XX"."""
        try:
            raw = await _llm_ollama(prompt, BACKFILL_SYSTEM_PROMPT)
            data = _parse_json(raw)
            if isinstance(data, dict):
                # JSON null or a non-string value must not leak into the result.
                article = str(data.get("article") or "")
                paragraph = str(data.get("paragraph") or "")
                if article or paragraph:
                    return MatchResult(
                        article=article,
                        paragraph=paragraph,
                        method="llm",
                    )
        except Exception as e:
            logger.warning("LLM match failed for %s: %s", ctrl.get("control_id"), e)
        return None

    def _update_control(self, ctrl: dict, match: MatchResult):
        """Write the matched article/paragraph to the control row.

        Updates ``source_citation`` and ``generation_metadata`` (the dicts on
        ``ctrl`` are modified in place) and issues the UPDATE. The caller is
        responsible for committing.
        """
        citation = ctrl.get("source_citation") or {}
        # Clean the source name: remove concatenated article if present
        source_str = citation.get("source", "")
        parsed = _parse_concatenated_source(source_str)
        if parsed:
            citation["source"] = parsed["name"]
        # Add separate article/paragraph fields
        citation["article"] = match.article
        citation["paragraph"] = match.paragraph
        # Update generation_metadata
        metadata = ctrl.get("generation_metadata") or {}
        if match.article:
            metadata["source_article"] = match.article
            metadata["source_paragraph"] = match.paragraph
        metadata["backfill_method"] = match.method
        metadata["backfill_at"] = datetime.now(timezone.utc).isoformat()
        self.db.execute(
            text("""
UPDATE canonical_controls
SET source_citation = :citation,
generation_metadata = :metadata,
updated_at = NOW()
WHERE id = CAST(:id AS uuid)
"""),
            {
                "id": ctrl["id"],
                "citation": json.dumps(citation),
                "metadata": json.dumps(metadata),
            },
        )
def _parse_concatenated_source(source: str) -> Optional[dict]:
    """Split a concatenated source string into regulation name and article.

    Examples:
        'DSGVO Art. 35' -> {name: 'DSGVO', article: 'Art. 35'}
        'BDSG § 42'     -> {name: 'BDSG', article: '§ 42'}

    Returns:
        Dict with stripped ``name`` and ``article``, or ``None`` when the
        input is empty or matches neither the Art./Artikel nor the § form.
    """
    if source:
        # The Art./Artikel form is tried first, then the § form.
        hit = _SOURCE_ARTICLE_RE.match(source) or re.match(r"^(.+?)\s+(§\s*\d+.*)$", source)
        if hit is not None:
            name, article = hit.groups()
            return {"name": name.strip(), "article": article.strip()}
    return None
async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Send one non-streaming chat request to Ollama for backfill matching.

    Args:
        prompt: User message.
        system_prompt: Optional system message placed before the user message.

    Returns:
        The reply's message content, or ``""`` on a non-200 status or any
        request error (both are logged).
    """
    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    request_body = {
        "model": OLLAMA_MODEL,
        "messages": messages,
        "stream": False,
        "format": "json",
        "options": {"num_predict": 256},
        "think": False,
    }
    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=request_body)
        if resp.status_code == 200:
            data = resp.json()
            msg = data.get("message", {})
            if not isinstance(msg, dict):
                return data.get("response", str(msg))
            return msg.get("content", "")
        logger.error("Ollama backfill failed %d: %s", resp.status_code, resp.text[:300])
        return ""
    except Exception as e:
        logger.error("Ollama backfill request failed: %s", e)
        return ""
def _parse_json(raw: str) -> Optional[dict]:
    """Extract a JSON object from LLM output.

    Candidates are tried in order: the whole text, the first fenced code
    block containing an object, then the first flat ``{...}`` span. A
    candidate only counts when it decodes to a ``dict``; valid JSON of
    another type (list, string, number) is skipped so that callers can rely
    on ``.get``.

    Args:
        raw: Raw LLM reply; may be empty.

    Returns:
        The first decoded object, or ``None`` when there is none.
    """
    if not raw:
        return None

    def _candidates():
        # Evaluated lazily so the regex searches only run when needed.
        yield raw
        fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
        if fenced:
            yield fenced.group(1)
        flat = re.search(r"\{[^{}]*\}", raw)
        if flat:
            yield flat.group(0)

    for candidate in _candidates():
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None

View File

@@ -0,0 +1,546 @@
"""Control Composer — Pattern + Obligation → Master Control.
Takes an obligation (from ObligationExtractor) and a matched control pattern
(from PatternMatcher), then uses LLM to compose a structured, actionable
Master Control. Replaces the old Stage 3 (STRUCTURE/REFORM) with a
pattern-guided approach.
Three composition modes based on license rules:
Rule 1: Obligation + Pattern + original text → full control
Rule 2: Obligation + Pattern + original text + citation → control
Rule 3: Obligation + Pattern (NO original text) → reformulated control
Fallback: No pattern match → basic generation (tagged needs_pattern_assignment)
Part of the Multi-Layer Control Architecture (Phase 6 of 8).
"""
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Optional
from services.obligation_extractor import (
ObligationMatch,
_llm_ollama,
_parse_json,
)
from services.pattern_matcher import (
ControlPattern,
PatternMatchResult,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Ollama model name, read from the environment with a built-in default.
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
# Valid values for generated control fields
# _validate_control() checks the corresponding control fields against these sets.
VALID_SEVERITIES = {"low", "medium", "high", "critical"}
VALID_EFFORTS = {"s", "m", "l", "xl"}
VALID_VERIFICATION = {"code_review", "document", "tool", "hybrid"}
@dataclass
class ComposedControl:
    """Master Control produced by composing an obligation with a pattern."""

    # --- Core fields (match canonical_controls schema) ---
    control_id: str = ""
    title: str = ""
    objective: str = ""
    rationale: str = ""
    scope: dict = field(default_factory=dict)
    requirements: list = field(default_factory=list)
    test_procedure: list = field(default_factory=list)
    evidence: list = field(default_factory=list)
    severity: str = "medium"
    risk_score: float = 5.0
    implementation_effort: str = "m"
    open_anchors: list = field(default_factory=list)
    release_state: str = "draft"
    tags: list = field(default_factory=list)

    # --- 3-Rule License fields ---
    license_rule: Optional[int] = None
    source_original_text: Optional[str] = None
    source_citation: Optional[dict] = None
    customer_visible: bool = True

    # --- Classification ---
    verification_method: Optional[str] = None
    category: Optional[str] = None
    target_audience: Optional[list] = None

    # --- Pattern + Obligation linkage ---
    pattern_id: Optional[str] = None
    obligation_ids: list = field(default_factory=list)

    # --- Metadata ---
    generation_metadata: dict = field(default_factory=dict)
    composition_method: str = "pattern_guided"  # pattern_guided | template_only | fallback

    def to_dict(self) -> dict:
        """Serialize for DB storage or API response.

        Returns:
            Dict with one entry per dataclass field, in declaration order.
            Values are the attribute objects themselves, not copies.
        """
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
class ControlComposer:
    """Composes Master Controls from obligations + patterns.

    Usage::

        composer = ControlComposer()
        control = await composer.compose(
            obligation=obligation_match,
            pattern_result=pattern_match_result,
            chunk_text="...",
            license_rule=1,
            source_citation={...},
        )
    """

    async def compose(
        self,
        obligation: ObligationMatch,
        pattern_result: PatternMatchResult,
        chunk_text: Optional[str] = None,
        license_rule: int = 3,
        source_citation: Optional[dict] = None,
        regulation_code: Optional[str] = None,
    ) -> ComposedControl:
        """Compose a Master Control from obligation + pattern.

        Args:
            obligation: The extracted obligation (from ObligationExtractor).
            pattern_result: The matched pattern (from PatternMatcher).
            chunk_text: Original RAG chunk text (only used for Rules 1-2).
            license_rule: 1=free, 2=citation, 3=restricted.
            source_citation: Citation metadata for Rule 2.
            regulation_code: Source regulation code.

        Returns:
            ComposedControl ready for storage. Without a matched pattern the
            fallback composition is used.
        """
        pattern = pattern_result.pattern if pattern_result else None
        if pattern:
            control = await self._compose_with_pattern(
                obligation, pattern, chunk_text, license_rule, source_citation,
            )
        else:
            control = await self._compose_fallback(
                obligation, chunk_text, license_rule, source_citation,
            )
        # Set linkage fields
        control.pattern_id = pattern.id if pattern else None
        if obligation.obligation_id:
            control.obligation_ids = [obligation.obligation_id]
        # Set license fields
        control.license_rule = license_rule
        if license_rule in (1, 2) and chunk_text:
            control.source_original_text = chunk_text
        if license_rule == 2 and source_citation:
            control.source_citation = source_citation
        if license_rule == 3:
            # Rule 3 controls carry no source material and are hidden.
            control.customer_visible = False
            control.source_original_text = None
            control.source_citation = None
        # Build metadata
        control.generation_metadata = {
            "composition_method": control.composition_method,
            "pattern_id": control.pattern_id,
            "pattern_confidence": round(pattern_result.confidence, 3) if pattern_result else 0,
            "pattern_method": pattern_result.method if pattern_result else "none",
            "obligation_id": obligation.obligation_id,
            "obligation_method": obligation.method,
            "obligation_confidence": round(obligation.confidence, 3),
            "license_rule": license_rule,
            "regulation_code": regulation_code,
        }
        # Validate and fix fields
        _validate_control(control)
        return control

    async def compose_batch(
        self,
        items: list[dict],
    ) -> list[ComposedControl]:
        """Compose multiple controls, one after another.

        Args:
            items: List of dicts with keys: obligation, pattern_result,
                chunk_text, license_rule, source_citation, regulation_code.
                Only ``obligation`` is required.

        Returns:
            List of ComposedControl instances, in input order.
        """
        results = []
        for item in items:
            control = await self.compose(
                obligation=item["obligation"],
                pattern_result=item.get("pattern_result", PatternMatchResult()),
                chunk_text=item.get("chunk_text"),
                license_rule=item.get("license_rule", 3),
                source_citation=item.get("source_citation"),
                regulation_code=item.get("regulation_code"),
            )
            results.append(control)
        return results

    @staticmethod
    def _llm_value(parsed: dict, key: str, default):
        """Return ``parsed[key]`` unless it is missing, null or empty."""
        value = parsed.get(key)
        return value if value else default

    # -----------------------------------------------------------------------
    # Pattern-guided composition
    # -----------------------------------------------------------------------
    async def _compose_with_pattern(
        self,
        obligation: ObligationMatch,
        pattern: ControlPattern,
        chunk_text: Optional[str],
        license_rule: int,
        source_citation: Optional[dict],
    ) -> ComposedControl:
        """Use the LLM to fill the pattern template with obligation details.

        Falls back to ``_compose_from_template`` when the LLM returns nothing
        or no usable JSON object. Fields that the LLM omits, or sends as null
        or empty, take the pattern's default.
        """
        prompt = _build_compose_prompt(obligation, pattern, chunk_text, license_rule)
        system_prompt = _compose_system_prompt(license_rule)
        llm_result = await _llm_ollama(prompt, system_prompt)
        parsed = _parse_json(llm_result) if llm_result else None
        # A list or scalar is valid JSON but unusable here.
        if not parsed or not isinstance(parsed, dict):
            return self._compose_from_template(obligation, pattern)

        pick = self._llm_value
        control = ComposedControl(
            title=str(pick(parsed, "title", pattern.name_de))[:255],
            objective=pick(parsed, "objective", pattern.objective_template),
            rationale=pick(parsed, "rationale", pattern.rationale_template),
            requirements=_ensure_list(pick(parsed, "requirements", pattern.requirements_template)),
            test_procedure=_ensure_list(pick(parsed, "test_procedure", pattern.test_procedure_template)),
            evidence=_ensure_list(pick(parsed, "evidence", pattern.evidence_template)),
            severity=pick(parsed, "severity", pattern.severity_default),
            implementation_effort=pick(parsed, "implementation_effort", pattern.implementation_effort_default),
            category=pick(parsed, "category", pattern.category),
            tags=_ensure_list(pick(parsed, "tags", pattern.tags)),
            target_audience=_ensure_list(parsed.get("target_audience", [])),
            verification_method=parsed.get("verification_method"),
            open_anchors=_anchors_from_pattern(pattern),
            composition_method="pattern_guided",
        )
        return control

    def _compose_from_template(
        self,
        obligation: ObligationMatch,
        pattern: ControlPattern,
    ) -> ComposedControl:
        """Fill the template directly, without the LLM (used when it fails).

        The title is the pattern name, followed by the obligation title when
        there is one. An obligation text longer than 20 characters is
        appended (first 200 characters) to the objective.
        """
        obl_title = obligation.obligation_title or ""
        obl_text = obligation.obligation_text or ""
        title = f"{pattern.name_de}"
        if obl_title:
            # Keep pattern name and obligation title visibly apart.
            title = f"{pattern.name_de} — {obl_title}"
        objective = pattern.objective_template
        if obl_text and len(obl_text) > 20:
            objective = f"{pattern.objective_template} Bezug: {obl_text[:200]}"
        return ComposedControl(
            title=title[:255],
            objective=objective,
            rationale=pattern.rationale_template,
            requirements=list(pattern.requirements_template),
            test_procedure=list(pattern.test_procedure_template),
            evidence=list(pattern.evidence_template),
            severity=pattern.severity_default,
            implementation_effort=pattern.implementation_effort_default,
            category=pattern.category,
            tags=list(pattern.tags),
            open_anchors=_anchors_from_pattern(pattern),
            composition_method="template_only",
        )

    # -----------------------------------------------------------------------
    # Fallback (no pattern)
    # -----------------------------------------------------------------------
    async def _compose_fallback(
        self,
        obligation: ObligationMatch,
        chunk_text: Optional[str],
        license_rule: int,
        source_citation: Optional[dict],
    ) -> ComposedControl:
        """Generate a control without a pattern template (old-style).

        When the LLM returns nothing or no usable JSON object, the control is
        built from the obligation text and fixed defaults. The result is
        always marked ``needs_review``.
        """
        prompt = _build_fallback_prompt(obligation, chunk_text, license_rule)
        system_prompt = _compose_system_prompt(license_rule)
        llm_result = await _llm_ollama(prompt, system_prompt)
        parsed = _parse_json(llm_result) if llm_result else None
        # _parse_json yields None on failure; treat that like an empty reply.
        if not isinstance(parsed, dict):
            parsed = {}

        obl_text = obligation.obligation_text or ""
        pick = self._llm_value
        default_title = obl_text[:100] if obl_text else "Untitled Control"
        control = ComposedControl(
            title=str(pick(parsed, "title", default_title))[:255],
            objective=pick(parsed, "objective", obl_text[:500]),
            rationale=pick(parsed, "rationale", "Aus gesetzlicher Pflicht abgeleitet."),
            requirements=_ensure_list(parsed.get("requirements", [])),
            test_procedure=_ensure_list(parsed.get("test_procedure", [])),
            evidence=_ensure_list(parsed.get("evidence", [])),
            severity=pick(parsed, "severity", "medium"),
            implementation_effort=pick(parsed, "implementation_effort", "m"),
            category=parsed.get("category"),
            tags=_ensure_list(parsed.get("tags", [])),
            target_audience=_ensure_list(parsed.get("target_audience", [])),
            verification_method=parsed.get("verification_method"),
            composition_method="fallback",
            release_state="needs_review",
        )
        return control
# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------
def _compose_system_prompt(license_rule: int) -> str:
    """Return the system prompt for the given license rule.

    Rule 3 gets the strict prompt that tells the model to reformulate and not
    name the source; every other value gets the general prompt.
    """
    if license_rule != 3:
        return (
            "Du bist ein Security-Compliance-Experte. "
            "Erstelle ein praxisorientiertes, umsetzbares Security Control. "
            "Antworte NUR mit validem JSON."
        )
    return (
        "Du bist ein Security-Compliance-Experte. Deine Aufgabe ist es, "
        "eigenstaendige Security Controls zu formulieren. "
        "Du formulierst IMMER in eigenen Worten. "
        "KOPIERE KEINE Saetze aus dem Quelltext. "
        "Verwende eigene Begriffe und Struktur. "
        "NENNE NICHT die Quelle. Keine proprietaeren Bezeichner. "
        "Antworte NUR mit validem JSON."
    )
def _build_compose_prompt(
    obligation: ObligationMatch,
    pattern: ControlPattern,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Build the LLM prompt for pattern-guided composition.

    The prompt holds the obligation section, the pattern section, a context
    section and the JSON answer schema. For license rule 3 the original text
    is never included; otherwise its first 2000 characters are, when given.
    """
    duty_block = _obligation_section(obligation)
    template_block = _pattern_section(pattern)

    context_block = "KONTEXT: Kein Originaltext verfuegbar."
    if license_rule == 3:
        context_block = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    elif chunk_text:
        context_block = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"

    return f"""Erstelle ein PRAXISORIENTIERTES Security Control.
{duty_block}
{template_block}
{context_block}
AUFGABE:
Fuelle das Muster mit pflicht-spezifischen Details.
Das Ergebnis muss UMSETZBAR sein — keine Gesetzesparaphrase.
Formuliere konkret und handlungsorientiert.
Antworte als JSON:
{{
"title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
"objective": "Was soll erreicht werden? (1-3 Saetze)",
"rationale": "Warum ist das wichtig? (1-2 Saetze)",
"requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
"test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
"evidence": ["Nachweis 1", "Nachweis 2", ...],
"severity": "low|medium|high|critical",
"implementation_effort": "s|m|l|xl",
"category": "{pattern.category}",
"tags": ["tag1", "tag2"],
"target_audience": ["unternehmen", "behoerden", "entwickler"],
"verification_method": "code_review|document|tool|hybrid"
}}"""
def _build_fallback_prompt(
    obligation: ObligationMatch,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Build the LLM prompt for fallback composition (no pattern).

    The prompt holds the obligation section, a context section and the JSON
    answer schema. For license rule 3 the original text is never included;
    otherwise its first 2000 characters are, when given.
    """
    duty_block = _obligation_section(obligation)

    context_block = "KONTEXT: Kein Originaltext verfuegbar."
    if license_rule == 3:
        context_block = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    elif chunk_text:
        context_block = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"

    return f"""Erstelle ein Security Control aus der folgenden Pflicht.
{duty_block}
{context_block}
AUFGABE:
Formuliere ein umsetzbares Security Control.
Keine Gesetzesparaphrase — konkrete Massnahmen beschreiben.
Antworte als JSON:
{{
"title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
"objective": "Was soll erreicht werden? (1-3 Saetze)",
"rationale": "Warum ist das wichtig? (1-2 Saetze)",
"requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
"test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
"evidence": ["Nachweis 1", "Nachweis 2", ...],
"severity": "low|medium|high|critical",
"implementation_effort": "s|m|l|xl",
"category": "one of: authentication, encryption, data_protection, etc.",
"tags": ["tag1", "tag2"],
"target_audience": ["unternehmen"],
"verification_method": "code_review|document|tool|hybrid"
}}"""
def _obligation_section(obligation: ObligationMatch) -> str:
    """Render the obligation as the prompt's "PFLICHT" section.

    Only non-empty fields get a line; the text is cut to 500 characters. When
    both title and text are missing, a placeholder line is added.
    """
    title = obligation.obligation_title
    body = obligation.obligation_text

    lines = ["PFLICHT (was das Gesetz verlangt):"]
    if title:
        lines.append(f" Titel: {title}")
    if body:
        lines.append(f" Beschreibung: {body[:500]}")
    if obligation.obligation_id:
        lines.append(f" ID: {obligation.obligation_id}")
    if obligation.regulation_id:
        lines.append(f" Rechtsgrundlage: {obligation.regulation_id}")
    if not (body or title):
        lines.append(" (Keine spezifische Pflicht extrahiert)")
    return "\n".join(lines)
def _pattern_section(pattern: ControlPattern) -> str:
"""Format the pattern for the prompt."""
reqs = "\n ".join(f"- {r}" for r in pattern.requirements_template[:5])
tests = "\n ".join(f"- {t}" for t in pattern.test_procedure_template[:3])
return f"""MUSTER (wie man es typischerweise umsetzt):
Pattern: {pattern.name_de} ({pattern.id})
Domain: {pattern.domain}
Ziel-Template: {pattern.objective_template}
Anforderungs-Template:
{reqs}
Pruefverfahren-Template:
{tests}"""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _ensure_list(value) -> list:
"""Ensure a value is a list of strings."""
if isinstance(value, list):
return [str(v) for v in value if v]
if isinstance(value, str):
return [value]
return []
def _anchors_from_pattern(pattern: ControlPattern) -> list:
"""Convert pattern's open_anchor_refs to control anchor format."""
anchors = []
for ref in pattern.open_anchor_refs:
anchors.append({
"framework": ref.get("framework", ""),
"control_id": ref.get("ref", ""),
"title": "",
"alignment_score": 0.8,
})
return anchors
def _validate_control(control: ComposedControl) -> None:
    """Normalize a composed control in place.

    Out-of-vocabulary enum-like fields are reset to defaults, an out-of-range
    risk score is re-derived from the severity, an overlong title is cut to
    255 characters, and empty content fields receive generic fallbacks.
    """
    # Enum-like fields: fall back to a neutral default when unknown.
    if control.severity not in VALID_SEVERITIES:
        control.severity = "medium"
    if control.implementation_effort not in VALID_EFFORTS:
        control.implementation_effort = "m"

    method = control.verification_method
    if method and method not in VALID_VERIFICATION:
        control.verification_method = None

    # Risk score must lie in [0, 10]; otherwise derive it from the
    # (already corrected) severity.
    score_in_range = 0 <= control.risk_score <= 10
    if not score_in_range:
        control.risk_score = _severity_to_risk(control.severity)

    # 252 characters plus the ellipsis gives exactly 255.
    if len(control.title) > 255:
        control.title = control.title[:252] + "..."

    # The objective falls back to the (possibly truncated) title.
    if not control.objective:
        control.objective = control.title

    fallbacks = (
        ("rationale", "Aus regulatorischer Anforderung abgeleitet."),
        ("requirements", ["Anforderung gemaess Pflichtbeschreibung umsetzen"]),
        ("test_procedure", ["Umsetzung der Anforderungen pruefen"]),
        ("evidence", ["Dokumentation der Umsetzung"]),
    )
    for attr, default in fallbacks:
        if not getattr(control, attr):
            setattr(control, attr, default)
def _severity_to_risk(severity: str) -> float:
"""Map severity to a default risk score."""
return {
"critical": 9.0,
"high": 7.0,
"medium": 5.0,
"low": 3.0,
}.get(severity, 5.0)

View File

@@ -0,0 +1,745 @@
"""Control Deduplication Engine — 4-Stage Matching Pipeline.
Prevents duplicate atomic controls during Pass 0b by checking candidates
against existing controls before insertion.
Stages:
1. Pattern-Gate: pattern_id must match (hard gate)
2. Action-Check: normalized action verb must match (hard gate)
3. Object-Norm: normalized object must match (soft gate with high threshold)
4. Embedding: cosine similarity with tiered thresholds (Qdrant)
Verdicts:
- NEW: create a new atomic control
- LINK: add parent link to existing control (similarity > LINK_THRESHOLD)
- REVIEW: queue for human review (REVIEW_THRESHOLD < sim < LINK_THRESHOLD)
"""
import logging
import os
import re
from dataclasses import dataclass, field
from typing import Optional, Callable, Awaitable
import httpx
logger = logging.getLogger(__name__)
# ── Configuration ────────────────────────────────────────────────────
# Feature switch; only the exact value "true" (case-insensitive) enables it.
DEDUP_ENABLED = os.getenv("DEDUP_ENABLED", "true").lower() == "true"
# Similarity above which a same-action/same-object candidate is linked.
LINK_THRESHOLD = float(os.getenv("DEDUP_LINK_THRESHOLD", "0.92"))
# Similarity above which (and up to LINK_THRESHOLD) a candidate goes to review.
REVIEW_THRESHOLD = float(os.getenv("DEDUP_REVIEW_THRESHOLD", "0.85"))
# Stricter link threshold when the normalized objects differ.
LINK_THRESHOLD_DIFF_OBJECT = float(os.getenv("DEDUP_LINK_THRESHOLD_DIFF_OBJ", "0.95"))
# Link threshold for matches found without the pattern_id filter.
CROSS_REG_LINK_THRESHOLD = float(os.getenv("DEDUP_CROSS_REG_THRESHOLD", "0.95"))
QDRANT_COLLECTION = os.getenv("DEDUP_QDRANT_COLLECTION", "atomic_controls")
QDRANT_URL = os.getenv("QDRANT_URL", "http://host.docker.internal:6333")
# The service's Settings class reads EMBEDDING_SERVICE_URL. Accept that name
# as a fallback so a deployment that sets only one of the two variables points
# both components at the same embedding service.
EMBEDDING_URL = (
    os.getenv("EMBEDDING_URL")
    or os.getenv("EMBEDDING_SERVICE_URL")
    or "http://embedding-service:8087"
)
# ── Result Dataclass ─────────────────────────────────────────────────
@dataclass
class DedupResult:
    """Outcome of the dedup check.

    Produced by ``ControlDedupChecker.check_duplicate``. The ``matched_*``
    fields are filled from the Qdrant payload of the best hit and stay
    ``None`` when no hit was evaluated.
    """
    verdict: str  # "new" | "link" | "review"
    # Identity of the existing control the candidate was compared against.
    matched_control_uuid: Optional[str] = None
    matched_control_id: Optional[str] = None
    matched_title: Optional[str] = None
    stage: str = ""  # which stage decided
    # Score of the best hit as reported by the vector search.
    similarity_score: float = 0.0
    link_type: str = "dedup_merge"  # "dedup_merge" | "cross_regulation"
    # Free-form diagnostics (normalized tokens, thresholds, ...) for the decision.
    details: dict = field(default_factory=dict)
# ── Action Normalization ─────────────────────────────────────────────
_ACTION_SYNONYMS: dict[str, str] = {
    # German → canonical English
    "implementieren": "implement",
    "umsetzen": "implement",
    "einrichten": "implement",
    "einführen": "implement",
    "aufbauen": "implement",
    "bereitstellen": "implement",
    "aktivieren": "implement",
    "konfigurieren": "configure",
    "einstellen": "configure",
    "parametrieren": "configure",
    "testen": "test",
    "prüfen": "test",
    "überprüfen": "test",
    "verifizieren": "test",
    "validieren": "test",
    "kontrollieren": "test",
    "auditieren": "audit",
    "dokumentieren": "document",
    "protokollieren": "log",
    "aufzeichnen": "log",
    "loggen": "log",
    "überwachen": "monitor",
    "monitoring": "monitor",
    "beobachten": "monitor",
    "schulen": "train",
    "trainieren": "train",
    "sensibilisieren": "train",
    "löschen": "delete",
    "entfernen": "delete",
    "verschlüsseln": "encrypt",
    "sperren": "block",
    "beschränken": "restrict",
    "einschränken": "restrict",
    "begrenzen": "restrict",
    "autorisieren": "authorize",
    "genehmigen": "authorize",
    "freigeben": "authorize",
    "authentifizieren": "authenticate",
    "identifizieren": "identify",
    "melden": "report",
    "benachrichtigen": "notify",
    "informieren": "notify",
    "aktualisieren": "update",
    "erneuern": "update",
    "sichern": "backup",
    "wiederherstellen": "restore",
    # English passthrough
    "implement": "implement",
    "configure": "configure",
    "test": "test",
    "verify": "test",
    "validate": "test",
    "audit": "audit",
    "document": "document",
    "log": "log",
    "monitor": "monitor",
    "train": "train",
    "delete": "delete",
    "encrypt": "encrypt",
    "restrict": "restrict",
    "authorize": "authorize",
    "authenticate": "authenticate",
    "report": "report",
    "update": "update",
    "backup": "backup",
    "restore": "restore",
}

# German infinitive/conjugation endings removed to obtain a verb stem.
_ACTION_SUFFIX_RE = re.compile(r"(en|t|st|e|te|tet|end)$")
# Shortest input accepted as an abbreviation of a known verb. Shorter
# fragments (and the empty string) are a prefix of too many verbs.
_MIN_ACTION_ABBREV_LEN = 3
# Shortest stem accepted for the stem-prefix lookup.
_MIN_ACTION_STEM_LEN = 4


def normalize_action(action: str) -> str:
    """Normalize an action verb to a canonical English form.

    Lookup order (first hit wins):

    1. exact match of the stripped, lower-cased input;
    2. exact match of the suffix-stripped stem;
    3. prefix match in dictionary order: the input starts with a known verb,
       or (inputs of at least ``_MIN_ACTION_ABBREV_LEN`` characters) a known
       verb starts with the input;
    4. a known verb starts with the stem (stems of at least
       ``_MIN_ACTION_STEM_LEN`` characters), which covers conjugated forms
       such as "konfiguriert" or "überwacht".

    Falls back to the stripped, lower-cased input. Empty or whitespace-only
    input yields "".

    NOTE: ``index_control`` in this module persists the result as
    ``action_normalized``; changing the mapping affects comparisons against
    points that were indexed earlier.
    """
    if not action:
        return ""
    action = action.strip().lower()

    if action in _ACTION_SYNONYMS:
        return _ACTION_SYNONYMS[action]

    action_base = _ACTION_SUFFIX_RE.sub("", action)
    if action_base in _ACTION_SYNONYMS:
        return _ACTION_SYNONYMS[action_base]

    # Without the length guard "" or "e" would be a prefix of some verb and
    # map to whichever one happens to come first in the dictionary.
    allow_abbreviation = len(action) >= _MIN_ACTION_ABBREV_LEN
    for verb, canonical in _ACTION_SYNONYMS.items():
        if action.startswith(verb) or (allow_abbreviation and verb.startswith(action)):
            return canonical

    # Separate, later pass so it can only add matches and never change the
    # result of the prefix pass above.
    if action_base != action and len(action_base) >= _MIN_ACTION_STEM_LEN:
        for verb, canonical in _ACTION_SYNONYMS.items():
            if verb.startswith(action_base):
                return canonical

    return action  # fallback: return as-is
# ── Object Normalization ─────────────────────────────────────────────
_OBJECT_SYNONYMS: dict[str, str] = {
# Authentication / Access
"mfa": "multi_factor_auth",
"multi-faktor-authentifizierung": "multi_factor_auth",
"mehrfaktorauthentifizierung": "multi_factor_auth",
"multi-factor authentication": "multi_factor_auth",
"two-factor": "multi_factor_auth",
"2fa": "multi_factor_auth",
"passwort": "password_policy",
"kennwort": "password_policy",
"password": "password_policy",
"zugangsdaten": "credentials",
"credentials": "credentials",
"admin-konten": "privileged_access",
"admin accounts": "privileged_access",
"administratorkonten": "privileged_access",
"privilegierte zugriffe": "privileged_access",
"privileged accounts": "privileged_access",
"remote-zugriff": "remote_access",
"fernzugriff": "remote_access",
"remote access": "remote_access",
"session": "session_management",
"sitzung": "session_management",
"sitzungsverwaltung": "session_management",
# Encryption
"verschlüsselung": "encryption",
"encryption": "encryption",
"kryptografie": "encryption",
"kryptografische verfahren": "encryption",
"schlüssel": "key_management",
"key management": "key_management",
"schlüsselverwaltung": "key_management",
"zertifikat": "certificate_management",
"certificate": "certificate_management",
"tls": "transport_encryption",
"ssl": "transport_encryption",
"https": "transport_encryption",
# Network
"firewall": "firewall",
"netzwerk": "network_security",
"network": "network_security",
"vpn": "vpn",
"segmentierung": "network_segmentation",
"segmentation": "network_segmentation",
# Logging / Monitoring
"audit-log": "audit_logging",
"audit log": "audit_logging",
"protokoll": "audit_logging",
"logging": "audit_logging",
"monitoring": "monitoring",
"überwachung": "monitoring",
"alerting": "alerting",
"alarmierung": "alerting",
"siem": "siem",
# Data
"personenbezogene daten": "personal_data",
"personal data": "personal_data",
"sensible daten": "sensitive_data",
"sensitive data": "sensitive_data",
"datensicherung": "backup",
"backup": "backup",
"wiederherstellung": "disaster_recovery",
"disaster recovery": "disaster_recovery",
# Policy / Process
"richtlinie": "policy",
"policy": "policy",
"verfahrensanweisung": "procedure",
"procedure": "procedure",
"prozess": "process",
"schulung": "training",
"training": "training",
"awareness": "awareness",
"sensibilisierung": "awareness",
# Incident
"vorfall": "incident",
"incident": "incident",
"sicherheitsvorfall": "security_incident",
"security incident": "security_incident",
# Vulnerability
"schwachstelle": "vulnerability",
"vulnerability": "vulnerability",
"patch": "patch_management",
"update": "patch_management",
"patching": "patch_management",
}
# Precompile for substring matching (longest first)
_OBJECT_KEYS_SORTED = sorted(_OBJECT_SYNONYMS.keys(), key=len, reverse=True)
def normalize_object(obj: str) -> str:
"""Normalize a compliance object to a canonical token."""
if not obj:
return ""
obj_lower = obj.strip().lower()
# Exact match
if obj_lower in _OBJECT_SYNONYMS:
return _OBJECT_SYNONYMS[obj_lower]
# Substring match (longest first)
for phrase in _OBJECT_KEYS_SORTED:
if phrase in obj_lower:
return _OBJECT_SYNONYMS[phrase]
# Fallback: strip articles/prepositions, join with underscore
cleaned = re.sub(r"\b(der|die|das|den|dem|des|ein|eine|eines|einem|einen"
r"|für|von|zu|auf|in|an|bei|mit|nach|über|unter|the|a|an"
r"|for|of|to|on|in|at|by|with)\b", "", obj_lower)
tokens = [t for t in cleaned.split() if len(t) > 2]
return "_".join(tokens[:4]) if tokens else obj_lower.replace(" ", "_")
# ── Canonicalization ─────────────────────────────────────────────────
def canonicalize_text(action: str, obj: str, title: str = "") -> str:
    """Build a canonical text for embedding.

    The result is ``"<normalized action> <normalized object>"``, optionally
    followed by ``"for"`` and up to five title keywords (lower-cased, filler
    words removed, longer than three characters). Parts that normalize to an
    empty string are omitted so the text never starts with or contains
    doubled separators.
    """
    norm_action = normalize_action(action)
    norm_object = normalize_object(obj)

    # Skip empty parts instead of emitting stray spaces for them.
    parts = [part for part in (norm_action, norm_object) if part]

    if title:
        # Add title keywords (stripped of common filler).
        # "gem." is matched outside the \b...\b group: there is no word
        # boundary between "." and a following space, so inside the group it
        # could never match "gem. Art. 32".
        title_clean = re.sub(
            r"\b(und|oder|für|von|zu|der|die|das|den|dem|des|ein|eine"
            r"|bei|mit|nach|gemäß|laut|entsprechend)\b|\bgem\.",
            "", title.lower()
        )
        title_tokens = [t for t in title_clean.split() if len(t) > 3][:5]
        if title_tokens:
            parts.append("for")
            parts.extend(title_tokens)

    return " ".join(parts)
# ── Embedding Helper ─────────────────────────────────────────────────
async def get_embedding(text: str) -> list[float]:
    """Get the embedding vector for a single text via the embedding service.

    Posts ``{"texts": [text]}`` to ``{EMBEDDING_URL}/embed`` and returns the
    first entry of the response's ``embeddings`` list.

    Returns:
        The embedding, or ``[]`` when the request fails, the service answers
        with an HTTP error status, or the response contains no embeddings.
        Every failure is logged as a warning; nothing is raised.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            # An error response may still carry a JSON body; without this
            # check it would be read as "no embeddings" without any log.
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            if not embeddings:
                logger.warning("Embedding service returned no embeddings")
                return []
            return embeddings[0]
    except Exception as e:
        # Best effort: callers treat [] as "embedding unavailable".
        logger.warning("Embedding failed: %s", e)
        return []
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two equally long vectors.

    Yields 0.0 when either vector is empty, the lengths differ, or one of
    the vectors has zero magnitude.
    """
    comparable = bool(a) and bool(b) and len(a) == len(b)
    if not comparable:
        return 0.0

    magnitude_a = sum(x * x for x in a) ** 0.5
    magnitude_b = sum(y * y for y in b) ** 0.5
    if not (magnitude_a and magnitude_b):
        return 0.0

    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / (magnitude_a * magnitude_b)
# ── Qdrant Helpers ───────────────────────────────────────────────────
async def qdrant_search(
    embedding: list[float],
    pattern_id: str,
    top_k: int = 10,
    collection: Optional[str] = None,
) -> list[dict]:
    """Search Qdrant for similar atomic controls, filtered by pattern_id.

    Returns the ``result`` list of the search response, or ``[]`` when the
    embedding is empty, the response status is not 200, or the request
    raises. Failures are logged as warnings.
    """
    if not embedding:
        return []

    target = collection or QDRANT_COLLECTION
    same_pattern_only = {
        "must": [
            {"key": "pattern_id", "match": {"value": pattern_id}}
        ]
    }
    request_body: dict = {
        "vector": embedding,
        "limit": top_k,
        "with_payload": True,
        "filter": same_pattern_only,
    }
    endpoint = f"{QDRANT_URL}/collections/{target}/points/search"

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(endpoint, json=request_body)
            if resp.status_code == 200:
                return resp.json().get("result", [])
            logger.warning("Qdrant search failed: %d", resp.status_code)
            return []
    except Exception as e:
        logger.warning("Qdrant search error: %s", e)
        return []
async def qdrant_search_cross_regulation(
    embedding: list[float],
    top_k: int = 5,
    collection: Optional[str] = None,
) -> list[dict]:
    """Search Qdrant for similar controls without a pattern_id filter.

    Used for cross-regulation linking (e.g. DSGVO Art. 25 ↔ NIS2 Art. 21).
    Returns the ``result`` list of the search response, or ``[]`` when the
    embedding is empty, the response status is not 200, or the request
    raises. Failures are logged as warnings.
    """
    if not embedding:
        return []

    target = collection or QDRANT_COLLECTION
    request_body: dict = {
        "vector": embedding,
        "limit": top_k,
        "with_payload": True,
    }
    endpoint = f"{QDRANT_URL}/collections/{target}/points/search"

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(endpoint, json=request_body)
            if resp.status_code == 200:
                return resp.json().get("result", [])
            logger.warning("Qdrant cross-reg search failed: %d", resp.status_code)
            return []
    except Exception as e:
        logger.warning("Qdrant cross-reg search error: %s", e)
        return []
async def qdrant_upsert(
    point_id: str,
    embedding: list[float],
    payload: dict,
    collection: Optional[str] = None,
) -> bool:
    """Upsert a single point into a Qdrant collection.

    Args:
        point_id: Id of the point to create or overwrite.
        embedding: Vector to store; an empty vector is rejected.
        payload: Payload stored alongside the vector.
        collection: Target collection, defaults to ``QDRANT_COLLECTION``.

    Returns:
        True when Qdrant answered with status 200, False otherwise (empty
        embedding, non-200 status, request error). Failed requests are
        logged as warnings; nothing is raised.
    """
    if not embedding:
        return False
    coll = collection or QDRANT_COLLECTION
    body = {
        "points": [{
            "id": point_id,
            "vector": embedding,
            "payload": payload,
        }]
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.put(
                f"{QDRANT_URL}/collections/{coll}/points",
                json=body,
            )
            if resp.status_code != 200:
                # Log rejected upserts like the search helpers do; a point
                # that silently fails to index is invisible to later
                # dedup searches.
                logger.warning("Qdrant upsert failed: %s", resp.status_code)
                return False
            return True
    except Exception as e:
        logger.warning("Qdrant upsert error: %s", e)
        return False
async def ensure_qdrant_collection(
    vector_size: int = 1024,
    collection: Optional[str] = None,
) -> bool:
    """Make sure the Qdrant collection exists, creating it when missing.

    A newly created collection uses cosine distance and gets keyword payload
    indexes on the fields used for filtering. Returns True when the
    collection exists or was created, False on a failed creation or a
    request error.
    """
    coll = collection or QDRANT_COLLECTION
    collection_url = f"{QDRANT_URL}/collections/{coll}"
    keyword_fields = ["pattern_id", "action_normalized", "object_normalized", "control_id"]

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Nothing to do when the collection is already there.
            lookup = await client.get(collection_url)
            if lookup.status_code == 200:
                return True

            creation = await client.put(
                collection_url,
                json={
                    "vectors": {"size": vector_size, "distance": "Cosine"},
                },
            )
            if creation.status_code != 200:
                logger.error("Failed to create Qdrant collection: %d", creation.status_code)
                return False

            logger.info("Created Qdrant collection: %s", coll)
            # Responses of the index requests are not inspected.
            for field_name in keyword_fields:
                await client.put(
                    f"{collection_url}/index",
                    json={"field_name": field_name, "field_schema": "keyword"},
                )
            return True
    except Exception as e:
        logger.warning("Qdrant collection check error: %s", e)
        return False
# ── Main Dedup Checker ───────────────────────────────────────────────
class ControlDedupChecker:
    """4-stage dedup checker for atomic controls.

    Usage:
        checker = ControlDedupChecker(db_session)
        result = await checker.check_duplicate(action, obj, title, pattern_id)
        if result.verdict == "link":
            checker.add_parent_link(result.matched_control_uuid, parent_uuid)
        elif result.verdict == "review":
            checker.write_review(control_id, title, objective, result)
        else:
            ...  # insert the new control, then await checker.index_control(...)
    """

    def __init__(
        self,
        db,
        embed_fn: Optional[Callable[[str], Awaitable[list[float]]]] = None,
        search_fn: Optional[Callable] = None,
        cross_search_fn: Optional[Callable] = None,
    ):
        """Create a checker.

        Args:
            db: Session object offering ``execute`` and ``commit``.
            embed_fn: Async text → vector function; defaults to ``get_embedding``.
            search_fn: Async pattern-filtered search; defaults to ``qdrant_search``.
            cross_search_fn: Async unfiltered search called as
                ``fn(embedding, top_k=...)``; defaults to
                ``qdrant_search_cross_regulation``.
        """
        self.db = db
        self._embed = embed_fn or get_embedding
        self._search = search_fn or qdrant_search
        # Kept as None by default and resolved at call time, so the
        # module-level function is looked up when it is actually needed.
        self._cross_search = cross_search_fn
        self._cache: dict[str, list[dict]] = {}  # pattern_id → existing controls

    def _load_existing(self, pattern_id: str) -> list[dict]:
        """Load existing atomic controls with the same pattern_id from the DB.

        Results are cached per pattern_id on this instance; ``index_control``
        drops the cache entry of the pattern it indexes.
        """
        if pattern_id in self._cache:
            return self._cache[pattern_id]
        from sqlalchemy import text
        rows = self.db.execute(text("""
            SELECT id::text, control_id, title, objective,
                   pattern_id,
                   generation_metadata->>'obligation_type' as obligation_type
            FROM canonical_controls
            WHERE parent_control_uuid IS NOT NULL
              AND release_state != 'deprecated'
              AND pattern_id = :pid
        """), {"pid": pattern_id}).fetchall()
        result = [
            {
                "uuid": r[0], "control_id": r[1], "title": r[2],
                "objective": r[3], "pattern_id": r[4],
                "obligation_type": r[5],
            }
            for r in rows
        ]
        self._cache[pattern_id] = result
        return result

    async def check_duplicate(
        self,
        action: str,
        obj: str,
        title: str,
        pattern_id: Optional[str],
    ) -> DedupResult:
        """Run the 4-stage dedup pipeline + cross-regulation linking.

        Returns DedupResult with verdict: new/link/review. Every "new"
        verdict reached after an embedding was computed is passed through
        the cross-regulation check, which may turn it into a link to a
        control of a different pattern.
        """
        # No pattern_id → can't dedup meaningfully
        if not pattern_id:
            return DedupResult(verdict="new", stage="no_pattern")

        # Stage 1: Pattern-Gate
        existing = self._load_existing(pattern_id)
        if not existing:
            return DedupResult(
                verdict="new", stage="pattern_gate",
                details={"reason": "no existing controls with this pattern_id"},
            )

        # Stage 2: Action-Check
        # The action is not stored on the DB rows loaded above, so the
        # comparison uses the action_normalized payload of the best hit.
        norm_action = normalize_action(action)

        # Stage 3: Object-Normalization
        norm_object = normalize_object(obj)

        # Stage 4: Embedding Similarity
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            # Can't compute embedding → default to new
            return DedupResult(
                verdict="new", stage="embedding_unavailable",
                details={"canonical_text": canonical},
            )

        # Search Qdrant
        results = await self._search(embedding, pattern_id, top_k=5)
        if not results:
            # No intra-pattern matches → try cross-regulation
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="no_qdrant_matches",
                details={"canonical_text": canonical, "action": norm_action, "object": norm_object},
            ), pattern_id)

        # Evaluate best match
        best = results[0]
        best_score = best.get("score", 0.0)
        best_payload = best.get("payload", {})
        best_action = best_payload.get("action_normalized", "")
        best_object = best_payload.get("object_normalized", "")

        # Action differs → NEW (even if embedding is high)
        if best_action and norm_action and best_action != norm_action:
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="action_mismatch",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_action": norm_action,
                    "existing_action": best_action,
                    "similarity": best_score,
                },
            ), pattern_id)

        # Object differs → use higher threshold
        if best_object and norm_object and best_object != norm_object:
            if best_score > LINK_THRESHOLD_DIFF_OBJECT:
                return DedupResult(
                    verdict="link", stage="embedding_diff_object",
                    matched_control_uuid=best_payload.get("control_uuid"),
                    matched_control_id=best_payload.get("control_id"),
                    matched_title=best_payload.get("title"),
                    similarity_score=best_score,
                    details={"candidate_object": norm_object, "existing_object": best_object},
                )
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="object_mismatch_below_threshold",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_object": norm_object,
                    "existing_object": best_object,
                    "threshold": LINK_THRESHOLD_DIFF_OBJECT,
                },
            ), pattern_id)

        # Same action + same object → tiered thresholds
        if best_score > LINK_THRESHOLD:
            return DedupResult(
                verdict="link", stage="embedding_match",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
            )
        if best_score > REVIEW_THRESHOLD:
            return DedupResult(
                verdict="review", stage="embedding_review",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
            )
        return await self._check_cross_regulation(embedding, DedupResult(
            verdict="new", stage="embedding_below_threshold",
            similarity_score=best_score,
            details={"threshold": REVIEW_THRESHOLD},
        ), pattern_id)

    async def _check_cross_regulation(
        self,
        embedding: list[float],
        intra_result: DedupResult,
        pattern_id: Optional[str] = None,
    ) -> DedupResult:
        """Second pass: cross-regulation linking for controls deemed 'new'.

        Searches Qdrant WITHOUT pattern_id filter and links when the best
        remaining hit scores above ``CROSS_REG_LINK_THRESHOLD``.

        Args:
            embedding: Candidate embedding.
            intra_result: Verdict of the intra-pattern stages; returned
                unchanged unless a cross-regulation link is found.
            pattern_id: Pattern of the candidate. Hits carrying the same
                pattern_id are ignored: they were already judged by the
                intra-pattern stages, and linking to them here would undo
                e.g. an action-mismatch decision.
        """
        if intra_result.verdict != "new" or not embedding:
            return intra_result

        search = self._cross_search or qdrant_search_cross_regulation
        cross_results = await search(embedding, top_k=5)
        if not cross_results:
            return intra_result

        if pattern_id:
            # Filtered after the search, so only the top_k hits are
            # considered; hits without a stored pattern_id are kept.
            cross_results = [
                hit for hit in cross_results
                if (hit.get("payload") or {}).get("pattern_id") != pattern_id
            ]
            if not cross_results:
                return intra_result

        best = cross_results[0]
        best_score = best.get("score", 0.0)
        if best_score > CROSS_REG_LINK_THRESHOLD:
            best_payload = best.get("payload", {})
            return DedupResult(
                verdict="link",
                stage="cross_regulation",
                matched_control_uuid=best_payload.get("control_uuid"),
                matched_control_id=best_payload.get("control_id"),
                matched_title=best_payload.get("title"),
                similarity_score=best_score,
                link_type="cross_regulation",
                details={
                    "cross_reg_score": best_score,
                    "cross_reg_threshold": CROSS_REG_LINK_THRESHOLD,
                },
            )
        return intra_result

    def add_parent_link(
        self,
        control_uuid: str,
        parent_control_uuid: str,
        link_type: str = "dedup_merge",
        confidence: float = 0.0,
        source_regulation: Optional[str] = None,
        source_article: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Add a parent link to an existing atomic control and commit.

        An already existing (control_uuid, parent_control_uuid) pair is left
        untouched.
        """
        from sqlalchemy import text
        # CAST(:x AS type) instead of :x::type — a bind parameter directly
        # followed by "::" is not reliably recognised by text().
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence,
                 source_regulation, source_article, obligation_candidate_id)
            VALUES (:cu, :pu, :lt, :conf, :sr, :sa, CAST(:oci AS uuid))
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {
            "cu": control_uuid,
            "pu": parent_control_uuid,
            "lt": link_type,
            "conf": confidence,
            "sr": source_regulation,
            "sa": source_article,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    def write_review(
        self,
        candidate_control_id: str,
        candidate_title: str,
        candidate_objective: str,
        result: DedupResult,
        parent_control_uuid: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Write a dedup review queue entry and commit.

        ``result.details`` is stored as JSON together with the matched
        control, similarity score and deciding stage.
        """
        import json
        from sqlalchemy import text
        # CAST(:x AS type) instead of :x::type — see add_parent_link.
        self.db.execute(text("""
            INSERT INTO control_dedup_reviews
                (candidate_control_id, candidate_title, candidate_objective,
                 matched_control_uuid, matched_control_id,
                 similarity_score, dedup_stage, dedup_details,
                 parent_control_uuid, obligation_candidate_id)
            VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci, :ss, :ds,
                    CAST(:dd AS jsonb), CAST(:pcu AS uuid), :oci)
        """), {
            "ccid": candidate_control_id,
            "ct": candidate_title,
            "co": candidate_objective,
            "mcu": result.matched_control_uuid,
            "mci": result.matched_control_id,
            "ss": result.similarity_score,
            "ds": result.stage,
            "dd": json.dumps(result.details),
            "pcu": parent_control_uuid,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    async def index_control(
        self,
        control_uuid: str,
        control_id: str,
        title: str,
        action: str,
        obj: str,
        pattern_id: str,
        collection: Optional[str] = None,
    ) -> bool:
        """Index a new atomic control in Qdrant for future dedup checks.

        Returns True when the point was upserted, False when no embedding
        could be computed or the upsert failed.
        """
        # A control of this pattern has just been created, so a cached
        # (possibly empty) list for the pattern is stale. Without this, later
        # candidates of a previously unseen pattern would stop at the
        # pattern gate and never be compared.
        self._cache.pop(pattern_id, None)

        norm_action = normalize_action(action)
        norm_object = normalize_object(obj)
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            return False
        return await qdrant_upsert(
            point_id=control_uuid,
            embedding=embedding,
            payload={
                "control_uuid": control_uuid,
                "control_id": control_id,
                "title": title,
                "pattern_id": pattern_id,
                "action_normalized": norm_action,
                "object_normalized": norm_object,
                "canonical_text": canonical,
            },
            collection=collection,
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,154 @@
"""
Control Status Transition State Machine.
Enforces that controls cannot be set to "pass" without sufficient evidence.
Prevents Compliance-Theater where controls claim compliance without real proof.
Transition rules:
planned → in_progress : always allowed
in_progress → pass : requires ≥1 evidence with confidence ≥ E2 and
truth_status in (uploaded, observed, validated_internal,
accepted_by_auditor, provided_to_auditor)
in_progress → partial : requires ≥1 evidence (any level)
pass → fail : always allowed (degradation)
any → n/a : requires status_justification
any → planned : always allowed (reset)
"""
from typing import Any, List, Optional, Tuple
# EvidenceDB is an ORM model from compliance — we only need duck-typed objects
# with .confidence_level and .truth_status attributes.
EvidenceDB = Any

# Confidence level ordering for comparisons
CONFIDENCE_ORDER = {"E0": 0, "E1": 1, "E2": 2, "E3": 3, "E4": 4}

# Truth statuses that qualify as "real" evidence for pass transitions
VALID_TRUTH_STATUSES = {"uploaded", "observed", "validated_internal", "accepted_by_auditor", "provided_to_auditor"}


def _enum_or_str(value: Any, default: str) -> str:
    """Return ``value.value`` for enum-like objects, else ``str(value)``, else ``default`` when falsy."""
    if hasattr(value, "value"):
        return value.value
    return str(value) if value else default


def _has_qualifying_evidence(evidence_list: List[EvidenceDB]) -> bool:
    """Tell whether any evidence has confidence >= E2 and an accepted truth status.

    A missing confidence level counts as "E1" and a missing truth status as
    "uploaded"; unknown confidence values rank like E1.
    """
    required_rank = CONFIDENCE_ORDER["E2"]
    for e in evidence_list:
        conf_val = _enum_or_str(getattr(e, "confidence_level", None), "E1")
        truth_val = _enum_or_str(getattr(e, "truth_status", None), "uploaded")
        if CONFIDENCE_ORDER.get(conf_val, 1) >= required_rank and truth_val in VALID_TRUTH_STATUSES:
            return True
    return False


def validate_transition(
    current_status: str,
    new_status: str,
    evidence_list: Optional[List[EvidenceDB]] = None,
    status_justification: Optional[str] = None,
    bypass_for_auto_updater: bool = False,
) -> Tuple[bool, List[str]]:
    """
    Validate whether a control status transition is allowed.

    Rules, evaluated in this order:
      * same status, or any → planned: allowed
      * any → n/a: requires a non-blank status_justification
      * pass → fail and planned → in_progress: allowed
      * any → partial: requires at least 1 evidence record
      * planned/fail → pass: rejected, the control has to go through
        in_progress first
      * any other → pass: requires at least 1 evidence record, one of them
        with confidence >= E2 and a truth status in VALID_TRUTH_STATUSES
      * everything else: allowed

    Args:
        current_status: Current control status value (e.g. "planned", "pass")
        new_status: Requested new status
        evidence_list: List of EvidenceDB objects linked to this control
        status_justification: Text justification (required for n/a transitions)
        bypass_for_auto_updater: If True, skip the evidence checks and the
            "via in_progress" rule (used by CI/CD auto-updater which creates
            evidence atomically with status change)

    Returns:
        Tuple of (allowed: bool, violations: list[str])
    """
    violations: List[str] = []
    evidence_list = evidence_list or []

    # Same status → no-op, always allowed
    if current_status == new_status:
        return True, []

    # Reset to planned is always allowed
    if new_status == "planned":
        return True, []

    # n/a requires justification
    if new_status == "n/a":
        if not status_justification or not status_justification.strip():
            violations.append("Transition to 'n/a' requires a status_justification explaining why this control is not applicable.")
        return len(violations) == 0, violations

    # Degradation: pass → fail is always allowed
    if current_status == "pass" and new_status == "fail":
        return True, []

    # planned → in_progress: always allowed
    if current_status == "planned" and new_status == "in_progress":
        return True, []

    # → partial: needs at least 1 evidence
    if new_status == "partial":
        if not bypass_for_auto_updater and len(evidence_list) == 0:
            violations.append("Transition to 'partial' requires at least 1 evidence record.")
        return len(violations) == 0, violations

    # → pass: strict requirements
    if new_status == "pass":
        if bypass_for_auto_updater:
            return True, []

        # Must be checked before the evidence rules. Placed after them this
        # rule could never fire, because the evidence branch always returns.
        if current_status in ("planned", "fail"):
            violations.append(
                f"Direct transition from '{current_status}' to 'pass' is not allowed. "
                f"Move to 'in_progress' first, then to 'pass' with qualifying evidence."
            )
            return False, violations

        if len(evidence_list) == 0:
            violations.append("Transition to 'pass' requires at least 1 evidence record.")
            return False, violations

        if not _has_qualifying_evidence(evidence_list):
            violations.append(
                "Transition to 'pass' requires at least 1 evidence with confidence >= E2 "
                "and truth_status in (uploaded, observed, validated_internal, accepted_by_auditor, "
                "provided_to_auditor). "
                "Current evidence does not meet this threshold."
            )
        return len(violations) == 0, violations

    # → fail: always allowed
    if new_status == "fail":
        return True, []

    # All other transitions allowed
    return True, []

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,714 @@
"""Framework Decomposition Engine — decomposes framework-container obligations.
Sits between Pass 0a (obligation extraction) and Pass 0b (atomic control
composition). Detects obligations that reference a framework domain (e.g.
"CCM-Praktiken fuer AIS") and decomposes them into concrete sub-obligations
using an internal framework registry.
Three routing types:
atomic → pass through to Pass 0b unchanged
compound → split compound verbs, then Pass 0b
framework_container → decompose via registry, then Pass 0b
The registry is a set of JSON files under compliance/data/frameworks/.
"""
import json
import logging
import os
import re
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Registry loading
# ---------------------------------------------------------------------------
# Directory scanned for framework definitions: "data/frameworks" next to the
# directory that contains this module's package directory.
_REGISTRY_DIR = Path(__file__).resolve().parent.parent / "data" / "frameworks"
# Module-level cache filled by get_registry() / reload_registry().
_REGISTRY: dict[str, dict] = {}  # framework_id → framework dict
def _load_registry() -> dict[str, dict]:
    """Read every ``*.json`` file in the registry directory.

    Files are processed in sorted order and keyed by their ``framework_id``
    (falling back to the file stem). A file that cannot be read or parsed is
    logged with its traceback and skipped. A missing directory yields an
    empty registry and a warning.
    """
    if not _REGISTRY_DIR.is_dir():
        logger.warning("Framework registry dir not found: %s", _REGISTRY_DIR)
        return {}

    loaded: dict[str, dict] = {}
    for json_file in sorted(_REGISTRY_DIR.glob("*.json")):
        try:
            with open(json_file, encoding="utf-8") as handle:
                framework = json.load(handle)
            framework_id = framework.get("framework_id", json_file.stem)
            loaded[framework_id] = framework
            domain_count = len(framework.get("domains", []))
            logger.info(
                "Loaded framework: %s (%d domains)",
                framework_id,
                domain_count,
            )
        except Exception:
            logger.exception("Failed to load framework file: %s", json_file)
    return loaded
def get_registry() -> dict[str, dict]:
    """Return the module-wide framework registry, loading it on first use.

    An empty registry counts as "not loaded", so loading is retried on every
    call until at least one framework was found.
    """
    global _REGISTRY
    if _REGISTRY:
        return _REGISTRY
    _REGISTRY = _load_registry()
    return _REGISTRY
def reload_registry() -> dict[str, dict]:
    """Re-read the framework registry from disk, replacing the cached one."""
    global _REGISTRY
    fresh = _load_registry()
    _REGISTRY = fresh
    return fresh
# ---------------------------------------------------------------------------
# Framework alias index (built from registry)
# ---------------------------------------------------------------------------
def _build_alias_index(registry: dict[str, dict]) -> dict[str, str]:
"""Build a lowercase alias → framework_id lookup."""
idx: dict[str, str] = {}
for fw_id, fw in registry.items():
# Framework-level aliases
idx[fw_id.lower()] = fw_id
name = fw.get("display_name", "")
if name:
idx[name.lower()] = fw_id
# Common short forms
for part in fw_id.lower().replace("_", " ").split():
if len(part) >= 3:
idx[part] = fw_id
return idx
# ---------------------------------------------------------------------------
# Routing — classify obligation type
# ---------------------------------------------------------------------------
# Extended patterns for framework detection (beyond the simple _COMPOSITE_RE
# in decomposition_pass.py — here we also capture the framework name).
# Shape: <collection noun> <preposition> <reference>, in German or English.
# Group 1 is the reference text; it is matched lazily and ends at the first
# modal/auxiliary verb from the list, a period, a comma, or end of string.
_FRAMEWORK_PATTERN = re.compile(
    r"(?:praktiken|kontrollen|ma(?:ss|ß)nahmen|anforderungen|vorgaben|controls|practices|measures|requirements)"
    r"\s+(?:f(?:ue|ü)r|aus|gem(?:ae|ä)(?:ss|ß)|nach|from|of|for|per)\s+"
    r"(.+?)(?:\s+(?:m(?:ue|ü)ssen|sollen|sind|werden|implementieren|umsetzen|einf(?:ue|ü)hren)|\.|,|$)",
    re.IGNORECASE,
)
# Direct framework name references: a fixed list of well-known framework
# names, matched case-insensitively as whole words with optional whitespace
# between their parts.
_DIRECT_FRAMEWORK_RE = re.compile(
    r"\b(?:CSA\s*CCM|NIST\s*(?:SP\s*)?800-53|OWASP\s*(?:ASVS|SAMM|Top\s*10)"
    r"|CIS\s*Controls|BSI\s*(?:IT-)?Grundschutz|ENISA|ISO\s*2700[12]"
    r"|COBIT|SOX|PCI\s*DSS|HITRUST|SOC\s*2|KRITIS)\b",
    re.IGNORECASE,
)
# Coordinating conjunctions (German and English). The pattern matches the
# conjunction itself, not the verbs; it is used as the indicator that an
# action may contain several main verbs.
_COMPOUND_VERB_RE = re.compile(
    r"\b(?:und|sowie|als\s+auch|or|and)\b",
    re.IGNORECASE,
)
# No-split phrases that look compound but aren't: fixed verb pairs that are
# treated as a single action. Compared by substring against the lowercased
# action text, so every entry must be lowercase.
_NO_SPLIT_PHRASES = [
    "pflegen und aufrechterhalten",
    "dokumentieren und pflegen",
    "definieren und dokumentieren",
    "erstellen und freigeben",
    "pruefen und genehmigen",
    "identifizieren und bewerten",
    "erkennen und melden",
    "define and maintain",
    "create and maintain",
    "establish and maintain",
    "monitor and review",
    "detect and respond",
]
@dataclass
class RoutingResult:
    """Result of obligation routing classification."""

    # Routing decision; classify_routing() returns "atomic", "compound" or
    # "framework_container".
    routing_type: str  # atomic | compound | framework_container | unknown_review
    # Registry id of the detected framework; None when a framework was
    # recognised but could not be resolved against the registry.
    framework_ref: Optional[str] = None
    # Id of the matched domain inside the framework, if any.
    framework_domain: Optional[str] = None
    # Title of the matched domain, if any.
    domain_title: Optional[str] = None
    # Heuristic confidence assigned by the detection strategy that fired.
    confidence: float = 0.0
    # Short machine-readable tag naming the rule that produced this result.
    reason: str = ""
def classify_routing(
    obligation_text: str,
    action_raw: str,
    object_raw: str,
    condition_raw: Optional[str] = None,
) -> RoutingResult:
    """Classify an obligation into atomic / compound / framework_container.

    The checks run in priority order:

    1. Framework container detection on the obligation text and object.
    2. Compound detection on the action (several main verbs).
    3. Otherwise the obligation is atomic.

    Args:
        obligation_text: Full text of the obligation.
        action_raw: Extracted action phrase.
        object_raw: Extracted object phrase.
        condition_raw: Accepted for interface compatibility; not evaluated.

    Returns:
        The routing decision with confidence and a reason tag.
    """
    # --- Step 1: Framework container detection ---
    fw_result = _detect_framework(obligation_text, object_raw)
    if fw_result.routing_type == "framework_container":
        return fw_result

    # --- Step 2: Compound verb detection ---
    if _is_compound_obligation(action_raw, obligation_text):
        return RoutingResult(
            routing_type="compound",
            confidence=0.7,
            reason="multiple_main_verbs",
        )

    # --- Step 3: Default = atomic ---
    return RoutingResult(
        routing_type="atomic",
        confidence=0.9,
        reason="single_action_single_object",
    )
def _detect_framework(
    obligation_text: str, object_raw: str,
) -> RoutingResult:
    """Decide whether an obligation refers to a framework or framework domain.

    Three strategies are tried in order, and the first that fires wins:

    1. A well-known framework name appears in the text. This always yields a
       ``framework_container`` result, even if the name cannot be resolved
       against the registry (then without ``framework_ref``).
    2. A "<practices> for <reference>" phrase whose reference resolves to a
       registry domain.
    3. The object alone contains at least two framework indicator keywords.

    If nothing fires, an ``atomic`` result with confidence 0.0 is returned.
    """
    haystack = f"{obligation_text} {object_raw}"
    registry = get_registry()
    alias_idx = _build_alias_index(registry)

    # Strategy 1: direct framework name match
    direct = _DIRECT_FRAMEWORK_RE.search(haystack)
    if direct is not None:
        fw_name = direct.group(0).strip()
        fw_id = _resolve_framework_id(fw_name, alias_idx, registry)
        if not fw_id:
            # Framework name recognized but not in registry
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=None,
                framework_domain=None,
                confidence=0.6,
                reason=f"direct_framework_match_no_registry:{fw_name}",
            )
        domain_id, domain_title = _match_domain(haystack, registry[fw_id])
        return RoutingResult(
            routing_type="framework_container",
            framework_ref=fw_id,
            framework_domain=domain_id,
            domain_title=domain_title,
            confidence=0.95 if domain_id else 0.75,
            reason=f"direct_framework_match:{fw_name}",
        )

    # Strategy 2: pattern match ("Praktiken fuer X")
    phrase = _FRAMEWORK_PATTERN.search(haystack)
    if phrase is not None:
        ref_text = phrase.group(1).strip()
        fw_id, domain_id, domain_title = _resolve_from_ref_text(
            ref_text, registry, alias_idx,
        )
        if fw_id:
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=fw_id,
                framework_domain=domain_id,
                domain_title=domain_title,
                confidence=0.85 if domain_id else 0.65,
                reason=f"pattern_match:{ref_text}",
            )

    # Strategy 3: keyword-heavy object
    if _has_framework_keywords(object_raw):
        return RoutingResult(
            routing_type="framework_container",
            framework_ref=None,
            framework_domain=None,
            confidence=0.5,
            reason="framework_keywords_in_object",
        )

    return RoutingResult(routing_type="atomic", confidence=0.0)
def _resolve_framework_id(
name: str,
alias_idx: dict[str, str],
registry: dict[str, dict],
) -> Optional[str]:
"""Resolve a framework name to its registry ID."""
normalized = re.sub(r"\s+", " ", name.strip().lower())
# Direct alias match
if normalized in alias_idx:
return alias_idx[normalized]
# Try compact form (strip spaces, hyphens, underscores)
compact = re.sub(r"[\s_\-]+", "", normalized)
for alias, fw_id in alias_idx.items():
if re.sub(r"[\s_\-]+", "", alias) == compact:
return fw_id
# Substring match in display names
for fw_id, fw in registry.items():
display = fw.get("display_name", "").lower()
if normalized in display or display in normalized:
return fw_id
# Partial match: check if normalized contains any alias (for multi-word refs)
for alias, fw_id in alias_idx.items():
if len(alias) >= 4 and alias in normalized:
return fw_id
return None
def _match_domain(
text: str, framework: dict,
) -> tuple[Optional[str], Optional[str]]:
"""Match a domain within a framework from text references."""
text_lower = text.lower()
best_id: Optional[str] = None
best_title: Optional[str] = None
best_score = 0
for domain in framework.get("domains", []):
score = 0
domain_id = domain["domain_id"]
title = domain.get("title", "")
# Exact domain ID match (e.g. "AIS")
if re.search(rf"\b{re.escape(domain_id)}\b", text, re.IGNORECASE):
score += 10
# Full title match
if title.lower() in text_lower:
score += 8
# Alias match
for alias in domain.get("aliases", []):
if alias.lower() in text_lower:
score += 6
break
# Keyword overlap
kw_hits = sum(
1 for kw in domain.get("keywords", [])
if kw.lower() in text_lower
)
score += kw_hits
if score > best_score:
best_score = score
best_id = domain_id
best_title = title
if best_score >= 3:
return best_id, best_title
return None, None
def _resolve_from_ref_text(
ref_text: str,
registry: dict[str, dict],
alias_idx: dict[str, str],
) -> tuple[Optional[str], Optional[str], Optional[str]]:
"""Resolve framework + domain from a reference text like 'AIS' or 'Application Security'."""
ref_lower = ref_text.lower()
for fw_id, fw in registry.items():
for domain in fw.get("domains", []):
# Check domain ID
if domain["domain_id"].lower() in ref_lower:
return fw_id, domain["domain_id"], domain.get("title")
# Check title
if domain.get("title", "").lower() in ref_lower:
return fw_id, domain["domain_id"], domain.get("title")
# Check aliases
for alias in domain.get("aliases", []):
if alias.lower() in ref_lower or ref_lower in alias.lower():
return fw_id, domain["domain_id"], domain.get("title")
return None, None, None
_FRAMEWORK_KW_SET = {
"praktiken", "kontrollen", "massnahmen", "maßnahmen",
"anforderungen", "vorgaben", "framework", "standard",
"baseline", "katalog", "domain", "family", "category",
"practices", "controls", "measures", "requirements",
}
def _has_framework_keywords(text: str) -> bool:
"""Check if text contains framework-indicator keywords."""
words = set(re.findall(r"[a-zäöüß]+", text.lower()))
return len(words & _FRAMEWORK_KW_SET) >= 2
def _is_compound_obligation(action_raw: str, obligation_text: str) -> bool:
    """Tell whether the action phrase joins several separate main verbs.

    The action is compound when it contains none of the fixed no-split verb
    pairs, contains a coordinating conjunction, and splitting on the
    conjunctions leaves at least two fragments of three or more characters.
    ``obligation_text`` is part of the signature but is not evaluated.
    """
    if not action_raw:
        return False

    action = action_raw.lower().strip()

    # Fixed verb pairs count as a single action.
    if any(phrase in action for phrase in _NO_SPLIT_PHRASES):
        return False

    # Without a conjunction there is nothing to split.
    if _COMPOUND_VERB_RE.search(action) is None:
        return False

    fragments = re.split(r"\b(?:und|sowie|als\s+auch|or|and)\b", action)
    substantial = sum(1 for fragment in fragments if len(fragment.strip()) >= 3)
    return substantial >= 2
# ---------------------------------------------------------------------------
# Framework Decomposition
# ---------------------------------------------------------------------------
@dataclass
class DecomposedObligation:
    """A concrete obligation derived from a framework container."""

    # Id of this derived obligation; decompose_framework_container() builds it
    # as "<source obligation id>-<subcontrol id>".
    obligation_candidate_id: str
    # Id of the control the source obligation belongs to.
    parent_control_id: str
    # Id of the framework container ("FWC-...") this obligation came from.
    parent_framework_container_id: str
    # Display name of the framework used as the source reference.
    source_ref_law: str
    # Subcontrol id used as the article-level reference.
    source_ref_article: str
    # Statement text of the subcontrol.
    obligation_text: str
    actor: str
    # Action and object phrases: the subcontrol's hints, or values inferred
    # from the statement when a hint is missing.
    action_raw: str
    object_raw: str
    condition_raw: Optional[str] = None
    trigger_raw: Optional[str] = None
    routing_type: str = "atomic"
    release_state: str = "decomposed"
    subcontrol_id: str = ""
    # Metadata copied from the registry subcontrol entry
    action_hint: str = ""
    object_hint: str = ""
    object_class: str = ""
    keywords: list[str] = field(default_factory=list)
@dataclass
class FrameworkDecompositionResult:
    """Result of framework decomposition."""

    # Generated id of the container ("FWC-" plus eight hex characters).
    framework_container_id: str
    # Id of the obligation that was decomposed.
    source_obligation_candidate_id: str
    # Resolved framework id and domain id; None when they could not be matched.
    framework_ref: Optional[str]
    framework_domain: Optional[str]
    domain_title: Optional[str]
    # Ids of the subcontrols turned into obligations, in output order.
    matched_subcontrols: list[str]
    # Score in [0.0, 1.0] from _compute_decomposition_confidence(); 0.0 for
    # unmatched results.
    decomposition_confidence: float
    release_state: str  # decomposed | unmatched | error
    decomposed_obligations: list[DecomposedObligation]
    # Collected messages, prefixed with "ERROR:" or "WARN:".
    issues: list[str]
def decompose_framework_container(
    obligation_candidate_id: str,
    parent_control_id: str,
    obligation_text: str,
    framework_ref: Optional[str],
    framework_domain: Optional[str],
    actor: str = "organization",
) -> FrameworkDecompositionResult:
    """Decompose a framework-container obligation into concrete sub-obligations.

    Steps:
    1. Resolve framework from registry
    2. Resolve domain within framework
    3. Select relevant subcontrols (keyword filter or full domain)
    4. Generate decomposed obligations

    The registry is treated as read-only: subcontrol entries are copied
    before being annotated, and keyword lists are copied into the results,
    so neither this function nor its callers can corrupt the shared cache.

    Args:
        obligation_candidate_id: Id of the obligation being decomposed.
        parent_control_id: Id of the control the obligation belongs to.
        obligation_text: Full obligation text, used for framework, domain and
            subcontrol matching.
        framework_ref: Registry id of the framework, if already known. When it
            is missing or unknown, the framework is looked up in the text and
            this value is replaced by the lookup result.
        framework_domain: Domain id inside the framework, if already known.
        actor: Actor assigned to every generated obligation.

    Returns:
        A result with ``release_state="decomposed"`` and one obligation per
        selected subcontrol, or ``release_state="unmatched"`` with an
        ``ERROR:`` issue when no framework or no subcontrol could be matched.
    """
    container_id = f"FWC-{uuid.uuid4().hex[:8]}"
    registry = get_registry()
    issues: list[str] = []

    # Step 1: Resolve framework
    fw = None
    if framework_ref and framework_ref in registry:
        fw = registry[framework_ref]
    else:
        # Try to find by name in text
        fw, framework_ref = _find_framework_in_text(obligation_text, registry)

    if not fw:
        issues.append("ERROR: framework_not_matched")
        return FrameworkDecompositionResult(
            framework_container_id=container_id,
            source_obligation_candidate_id=obligation_candidate_id,
            framework_ref=framework_ref,
            framework_domain=framework_domain,
            domain_title=None,
            matched_subcontrols=[],
            decomposition_confidence=0.0,
            release_state="unmatched",
            decomposed_obligations=[],
            issues=issues,
        )

    # Step 2: Resolve domain
    domain_data = None
    domain_title = None
    if framework_domain:
        for d in fw.get("domains", []):
            if d["domain_id"].lower() == framework_domain.lower():
                domain_data = d
                domain_title = d.get("title")
                break

    if not domain_data:
        # Try matching from text
        domain_id, domain_title = _match_domain(obligation_text, fw)
        if domain_id:
            for d in fw.get("domains", []):
                if d["domain_id"] == domain_id:
                    domain_data = d
                    framework_domain = domain_id
                    break

    if not domain_data:
        issues.append("WARN: domain_not_matched — using all domains")
        # Fall back to all subcontrols across all domains. Each entry is
        # copied before it is tagged with its domain, so the cached registry
        # documents are never modified.
        all_subcontrols = [
            {**sc, "_domain_id": d["domain_id"]}
            for d in fw.get("domains", [])
            for sc in d.get("subcontrols", [])
        ]
        subcontrols = _select_subcontrols(obligation_text, all_subcontrols)
        if not subcontrols:
            issues.append("ERROR: no_subcontrols_matched")
            return FrameworkDecompositionResult(
                framework_container_id=container_id,
                source_obligation_candidate_id=obligation_candidate_id,
                framework_ref=framework_ref,
                framework_domain=framework_domain,
                domain_title=None,
                matched_subcontrols=[],
                decomposition_confidence=0.0,
                release_state="unmatched",
                decomposed_obligations=[],
                issues=issues,
            )
    else:
        # Step 3: Select subcontrols from domain
        raw_subcontrols = domain_data.get("subcontrols", [])
        subcontrols = _select_subcontrols(obligation_text, raw_subcontrols)
        if not subcontrols:
            # Full domain decomposition
            subcontrols = raw_subcontrols

    # Quality check: too many subcontrols
    if len(subcontrols) > 25:
        issues.append(f"WARN: {len(subcontrols)} subcontrols — may be too broad")

    # Step 4: Generate decomposed obligations
    display_name = fw.get("display_name", framework_ref or "Unknown")
    decomposed: list[DecomposedObligation] = []
    matched_ids: list[str] = []

    for sc in subcontrols:
        sc_id = sc.get("subcontrol_id", "")
        matched_ids.append(sc_id)
        action_hint = sc.get("action_hint", "")
        object_hint = sc.get("object_hint", "")
        statement = sc.get("statement", "")

        # Quality warnings
        if not action_hint:
            issues.append(f"WARN: {sc_id} missing action_hint")
        if not object_hint:
            issues.append(f"WARN: {sc_id} missing object_hint")

        obl_id = f"{obligation_candidate_id}-{sc_id}"
        decomposed.append(DecomposedObligation(
            obligation_candidate_id=obl_id,
            parent_control_id=parent_control_id,
            parent_framework_container_id=container_id,
            source_ref_law=display_name,
            source_ref_article=sc_id,
            obligation_text=statement,
            actor=actor,
            action_raw=action_hint or _infer_action(statement),
            object_raw=object_hint or _infer_object(statement),
            routing_type="atomic",
            release_state="decomposed",
            subcontrol_id=sc_id,
            action_hint=action_hint,
            object_hint=object_hint,
            object_class=sc.get("object_class", ""),
            # Copy: the result must not share list objects with the registry.
            keywords=list(sc.get("keywords", [])),
        ))

    # Check if decomposed are identical to container
    for d in decomposed:
        if d.obligation_text.strip() == obligation_text.strip():
            issues.append(f"WARN: {d.subcontrol_id} identical to container text")

    confidence = _compute_decomposition_confidence(
        framework_ref, framework_domain, domain_data, len(subcontrols), issues,
    )

    return FrameworkDecompositionResult(
        framework_container_id=container_id,
        source_obligation_candidate_id=obligation_candidate_id,
        framework_ref=framework_ref,
        framework_domain=framework_domain,
        domain_title=domain_title,
        matched_subcontrols=matched_ids,
        decomposition_confidence=confidence,
        release_state="decomposed",
        decomposed_obligations=decomposed,
        issues=issues,
    )
def _find_framework_in_text(
    text: str, registry: dict[str, dict],
) -> tuple[Optional[dict], Optional[str]]:
    """Look for a well-known framework name in *text* and resolve it.

    Returns:
        ``(framework_document, framework_id)`` when a known name occurs in
        the text and resolves to a registry entry, else ``(None, None)``.
    """
    alias_idx = _build_alias_index(registry)
    hit = _DIRECT_FRAMEWORK_RE.search(text)
    if hit is None:
        return None, None

    fw_id = _resolve_framework_id(hit.group(0), alias_idx, registry)
    if not fw_id or fw_id not in registry:
        return None, None
    return registry[fw_id], fw_id
def _select_subcontrols(
obligation_text: str, subcontrols: list[dict],
) -> list[dict]:
"""Select relevant subcontrols based on keyword matching.
Returns empty list if no targeted match found (caller falls back to
full domain).
"""
text_lower = obligation_text.lower()
scored: list[tuple[int, dict]] = []
for sc in subcontrols:
score = 0
for kw in sc.get("keywords", []):
if kw.lower() in text_lower:
score += 1
# Title match
title = sc.get("title", "").lower()
if title and title in text_lower:
score += 3
# Object hint in text
obj = sc.get("object_hint", "").lower()
if obj and obj in text_lower:
score += 2
if score > 0:
scored.append((score, sc))
if not scored:
return []
# Only return those with meaningful overlap (score >= 2)
scored.sort(key=lambda x: x[0], reverse=True)
return [sc for score, sc in scored if score >= 2]
def _infer_action(statement: str) -> str:
"""Infer a basic action verb from a statement."""
s = statement.lower()
if any(w in s for w in ["definiert", "definieren", "define"]):
return "definieren"
if any(w in s for w in ["implementiert", "implementieren", "implement"]):
return "implementieren"
if any(w in s for w in ["dokumentiert", "dokumentieren", "document"]):
return "dokumentieren"
if any(w in s for w in ["ueberwacht", "ueberwachen", "monitor"]):
return "ueberwachen"
if any(w in s for w in ["getestet", "testen", "test"]):
return "testen"
if any(w in s for w in ["geschuetzt", "schuetzen", "protect"]):
return "implementieren"
if any(w in s for w in ["verwaltet", "verwalten", "manage"]):
return "pflegen"
if any(w in s for w in ["gemeldet", "melden", "report"]):
return "melden"
return "implementieren"
def _infer_object(statement: str) -> str:
"""Infer the primary object from a statement (first noun phrase)."""
# Simple heuristic: take the text after "muessen"/"muss" up to the verb
m = re.search(
r"(?:muessen|muss|m(?:ü|ue)ssen)\s+(.+?)(?:\s+werden|\s+sein|\.|,|$)",
statement,
re.IGNORECASE,
)
if m:
return m.group(1).strip()[:80]
# Fallback: first 80 chars
return statement[:80] if statement else ""
def _compute_decomposition_confidence(
framework_ref: Optional[str],
domain: Optional[str],
domain_data: Optional[dict],
num_subcontrols: int,
issues: list[str],
) -> float:
"""Compute confidence score for the decomposition."""
score = 0.3
if framework_ref:
score += 0.25
if domain:
score += 0.20
if domain_data:
score += 0.10
if 1 <= num_subcontrols <= 15:
score += 0.10
elif num_subcontrols > 15:
score += 0.05 # less confident with too many
# Penalize errors
errors = sum(1 for i in issues if i.startswith("ERROR:"))
score -= errors * 0.15
return round(max(min(score, 1.0), 0.0), 2)
# ---------------------------------------------------------------------------
# Registry statistics (for admin/debugging)
# ---------------------------------------------------------------------------
def registry_stats() -> dict:
    """Summarise the loaded registry for admin and debugging purposes.

    Returns:
        A dict with the number of frameworks, one detail entry per framework
        (id, display name, domain count, subcontrol count) and the domain and
        subcontrol totals across all frameworks.
    """
    registry = get_registry()
    details = []
    domain_total = 0
    subcontrol_total = 0

    for fw_id, fw in registry.items():
        domains = fw.get("domains", [])
        subcontrol_count = sum(len(d.get("subcontrols", [])) for d in domains)
        domain_total += len(domains)
        subcontrol_total += subcontrol_count
        details.append({
            "framework_id": fw_id,
            "display_name": fw.get("display_name", ""),
            "domains": len(domains),
            "subcontrols": subcontrol_count,
        })

    return {
        "frameworks": len(registry),
        "details": details,
        "total_domains": domain_total,
        "total_subcontrols": subcontrol_total,
    }

View File

@@ -0,0 +1,116 @@
"""
License Gate — checks whether a given source may be used for a specific purpose.
Usage types:
- analysis: Read + analyse internally (TDM under UrhG 44b)
- store_excerpt: Store verbatim excerpt in vault
- ship_embeddings: Ship embeddings in product
- ship_in_product: Ship text/content in product
Policy is driven by the canonical_control_sources table columns:
allowed_analysis, allowed_store_excerpt, allowed_ship_embeddings, allowed_ship_in_product
"""
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maps each supported usage type to the boolean permission column of
# canonical_control_sources that governs it. check_source_allowed()
# interpolates these values into SQL, so they must stay fixed column names
# and never be built from caller input.
USAGE_COLUMN_MAP = {
    "analysis": "allowed_analysis",
    "store_excerpt": "allowed_store_excerpt",
    "ship_embeddings": "allowed_ship_embeddings",
    "ship_in_product": "allowed_ship_in_product",
}
def check_source_allowed(db: Session, source_id: str, usage_type: str) -> bool:
    """Tell whether *source_id* may be used for *usage_type*.

    The usage type selects a permission column through USAGE_COLUMN_MAP; only
    that whitelisted column name is interpolated into the statement, while
    the source id is passed as a bound parameter.

    Returns:
        The truth value of the permission column. False (with a warning in
        the log) if the usage type is unknown or the source does not exist.
    """
    if usage_type not in USAGE_COLUMN_MAP:
        logger.warning("Unknown usage_type=%s", usage_type)
        return False
    col = USAGE_COLUMN_MAP[usage_type]

    statement = text(
        f"SELECT {col} FROM canonical_control_sources WHERE source_id = :sid"
    )
    row = db.execute(statement, {"sid": source_id}).fetchone()
    if row is not None:
        return bool(row[0])

    logger.warning("Source %s not found in registry", source_id)
    return False
def get_license_matrix(db: Session) -> list[dict[str, Any]]:
    """List every license with its usage terms, ordered by license id.

    Returns:
        One dict per row of canonical_control_licenses, keyed by column name.
    """
    columns = (
        "license_id",
        "name",
        "terms_url",
        "commercial_use",
        "ai_training_restriction",
        "tdm_allowed_under_44b",
        "deletion_required",
        "notes",
    )
    query = text("""
        SELECT license_id, name, terms_url, commercial_use,
        ai_training_restriction, tdm_allowed_under_44b,
        deletion_required, notes
        FROM canonical_control_licenses
        ORDER BY license_id
    """)
    matrix = []
    for row in db.execute(query).fetchall():
        matrix.append({column: getattr(row, column) for column in columns})
    return matrix
def get_source_permissions(db: Session) -> list[dict[str, Any]]:
    """List every source with its permission flags, ordered by source id.

    Sources are joined with their license, so a source whose license id has
    no row in canonical_control_licenses does not appear in the result.

    Returns:
        One dict per source with its descriptive columns, license name and
        commercial-use flag, the four permission flags and vault settings.
    """
    # Output keys in the order callers receive them.
    columns = (
        "source_id",
        "title",
        "publisher",
        "url",
        "version_label",
        "language",
        "license_id",
        "license_name",
        "commercial_use",
        "allowed_analysis",
        "allowed_store_excerpt",
        "allowed_ship_embeddings",
        "allowed_ship_in_product",
        "vault_retention_days",
        "vault_access_tier",
    )
    query = text("""
        SELECT s.source_id, s.title, s.publisher, s.url, s.version_label,
        s.language, s.license_id,
        s.allowed_analysis, s.allowed_store_excerpt,
        s.allowed_ship_embeddings, s.allowed_ship_in_product,
        s.vault_retention_days, s.vault_access_tier,
        l.name AS license_name, l.commercial_use
        FROM canonical_control_sources s
        JOIN canonical_control_licenses l ON l.license_id = s.license_id
        ORDER BY s.source_id
    """)
    permissions = []
    for row in db.execute(query).fetchall():
        permissions.append({column: getattr(row, column) for column in columns})
    return permissions

View File

@@ -0,0 +1,624 @@
"""
LLM Provider Abstraction for Compliance AI Features.
Supports:
- Anthropic Claude API (default)
- Self-Hosted LLMs (Ollama, vLLM, LocalAI, etc.)
- HashiCorp Vault integration for secure API key storage
Configuration via environment variables:
- COMPLIANCE_LLM_PROVIDER: "anthropic" or "self_hosted"
- ANTHROPIC_API_KEY: API key for Claude (or loaded from Vault)
- ANTHROPIC_MODEL: Model name (default: claude-sonnet-4-20250514)
- SELF_HOSTED_LLM_URL: Base URL for self-hosted LLM
- SELF_HOSTED_LLM_MODEL: Model name for self-hosted
- SELF_HOSTED_LLM_KEY: Optional API key for self-hosted
Vault Configuration:
- VAULT_ADDR: Vault server address (e.g., http://vault:8200)
- VAULT_TOKEN: Vault authentication token
- USE_VAULT_SECRETS: Set to "true" to enable Vault integration
- VAULT_SECRET_PATH: Path to secrets (default: secret/breakpilot/api_keys)
"""
import os
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
from enum import Enum
import httpx
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# =============================================================================
# Vault Integration
# =============================================================================
class VaultClient:
    """
    HashiCorp Vault client for retrieving secrets.

    Supports KV v2 secrets engine. Secrets that were read successfully are
    cached in memory for ``_cache_ttl`` seconds; after that they are fetched
    from Vault again, so rotated secrets are picked up.
    """

    def __init__(
        self,
        addr: Optional[str] = None,
        token: Optional[str] = None
    ):
        """
        Create a client.

        Args:
            addr: Vault base URL; defaults to $VAULT_ADDR, then
                "http://localhost:8200".
            token: Vault token; defaults to $VAULT_TOKEN.
        """
        self.addr = addr or os.getenv("VAULT_ADDR", "http://localhost:8200")
        self.token = token or os.getenv("VAULT_TOKEN")
        self._cache: Dict[str, Any] = {}
        # Monotonic-clock deadline per cache key; an entry without a deadline
        # counts as expired.
        self._cache_expiry: Dict[str, float] = {}
        self._cache_ttl = 300  # 5 minutes cache

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with Vault token."""
        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["X-Vault-Token"] = self.token
        return headers

    def get_secret(self, path: str, key: str = "value") -> Optional[str]:
        """
        Get a secret from Vault KV v2.

        Errors are logged and reported as None; this method does not raise
        for HTTP or connection failures.

        Args:
            path: Secret path (e.g., "breakpilot/api_keys/anthropic")
            key: Key within the secret data (default: "value")

        Returns:
            Secret value or None if not found
        """
        # Local import: keeps this class independent of the module's import block.
        import time

        cache_key = f"{path}:{key}"

        # Check cache first; drop the entry once its TTL has elapsed.
        if cache_key in self._cache:
            if time.monotonic() < self._cache_expiry.get(cache_key, 0.0):
                return self._cache[cache_key]
            self._cache.pop(cache_key, None)
            self._cache_expiry.pop(cache_key, None)

        try:
            # KV v2 uses /data/ in the path
            full_path = f"{self.addr}/v1/secret/data/{path}"
            response = httpx.get(
                full_path,
                headers=self._get_headers(),
                timeout=10.0
            )
            if response.status_code == 200:
                data = response.json()
                # KV v2 wraps the payload twice: {"data": {"data": {...}}}
                secret_data = data.get("data", {}).get("data", {})
                secret_value = secret_data.get(key)
                if secret_value:
                    self._cache[cache_key] = secret_value
                    self._cache_expiry[cache_key] = (
                        time.monotonic() + self._cache_ttl
                    )
                    logger.info(f"Successfully loaded secret from Vault: {path}")
                    return secret_value
            elif response.status_code == 404:
                logger.warning(f"Secret not found in Vault: {path}")
            else:
                logger.error(f"Vault error {response.status_code}: {response.text}")
        except httpx.RequestError as e:
            logger.error(f"Failed to connect to Vault at {self.addr}: {e}")
        except Exception as e:
            logger.error(f"Error retrieving secret from Vault: {e}")
        return None

    def get_anthropic_key(self) -> Optional[str]:
        """Get Anthropic API key from Vault.

        The secret path comes from $VAULT_ANTHROPIC_PATH and defaults to
        "breakpilot/api_keys/anthropic"; the key inside the secret is "value".
        """
        path = os.getenv("VAULT_ANTHROPIC_PATH", "breakpilot/api_keys/anthropic")
        return self.get_secret(path, "value")

    def is_available(self) -> bool:
        """Check whether the Vault health endpoint responds.

        Returns True when ``/v1/sys/health`` answers with one of the status
        codes 200, 429, 472, 473, 501 or 503, and False for any other status
        or any exception. This only probes the health endpoint; it does not
        read a secret, so it does not prove that the token is valid.

        NOTE(review): the accepted codes look like Vault's documented health
        states, some of which (501, 503) would mean the server is reachable
        but not yet usable — confirm that callers want "reachable" here.
        """
        try:
            response = httpx.get(
                f"{self.addr}/v1/sys/health",
                headers=self._get_headers(),
                timeout=5.0
            )
            return response.status_code in (200, 429, 472, 473, 501, 503)
        except Exception:
            return False
# Process-wide Vault client, created on first use by get_vault_client().
_vault_client: Optional[VaultClient] = None


def get_vault_client() -> VaultClient:
    """Return the shared VaultClient, creating it on the first call."""
    global _vault_client
    client = _vault_client
    if client is None:
        client = VaultClient()
        _vault_client = client
    return client
def get_secret_from_vault_or_env(
    vault_path: str,
    env_var: str,
    vault_key: str = "value"
) -> Optional[str]:
    """
    Look a secret up in Vault first and in the environment second.

    Vault is consulted only when $USE_VAULT_SECRETS is "true", "1" or "yes"
    (case-insensitive). If Vault yields nothing, the fallback is logged and
    the environment variable is used.

    Args:
        vault_path: Path in Vault (e.g., "breakpilot/api_keys/anthropic")
        env_var: Environment variable name as fallback
        vault_key: Key within Vault secret data

    Returns:
        Secret value or None
    """
    vault_flag = os.getenv("USE_VAULT_SECRETS", "").lower()
    if vault_flag in ("true", "1", "yes"):
        secret = get_vault_client().get_secret(vault_path, vault_key)
        if secret:
            return secret
        logger.info(f"Vault secret not found, falling back to env: {env_var}")
    return os.getenv(env_var)
class LLMProviderType(str, Enum):
    """Supported LLM provider types.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    ANTHROPIC = "anthropic"
    SELF_HOSTED = "self_hosted"
    OLLAMA = "ollama"  # Alias for self_hosted (Ollama-specific)
    MOCK = "mock"  # For testing
@dataclass
class LLMResponse:
    """Standard response from LLM."""

    # Generated text. The batch helpers put "Error: <message>" here when a
    # single prompt fails.
    content: str
    # Model name reported by the provider object.
    model: str
    # Provider name ("anthropic", "self_hosted" or "mock").
    provider: str
    # "usage" object of the API response, if the API returned one.
    usage: Optional[Dict[str, int]] = None
    # Complete decoded JSON body of the API response.
    raw_response: Optional[Dict[str, Any]] = None
@dataclass
class LLMConfig:
    """Configuration for LLM provider."""

    provider_type: LLMProviderType
    # Required by AnthropicProvider; optional bearer token for SelfHostedProvider.
    api_key: Optional[str] = None
    model: str = "claude-sonnet-4-20250514"
    # Required by SelfHostedProvider.
    base_url: Optional[str] = None
    # Default generation limits, used when a call passes no explicit value.
    max_tokens: int = 4096
    temperature: float = 0.3
    # Timeout passed to the HTTP client of each request.
    timeout: float = 60.0
class LLMProvider(ABC):
    """Common interface that every LLM backend implements."""

    def __init__(self, config: LLMConfig):
        # Kept for subclasses, which read their defaults from it.
        self.config = config

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Produce one completion for ``prompt``."""

    @abstractmethod
    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Produce one completion per prompt, pausing ``rate_limit`` seconds between calls."""

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short identifier of the backend."""
class AnthropicProvider(LLMProvider):
    """Claude API Provider using Anthropic's official API."""

    ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"

    def __init__(self, config: LLMConfig):
        """
        Create the provider.

        Raises:
            ValueError: If the config carries no API key.
        """
        super().__init__(config)
        if not config.api_key:
            raise ValueError("Anthropic API key is required")
        self.api_key = config.api_key
        self.model = config.model or "claude-sonnet-4-20250514"

    @property
    def provider_name(self) -> str:
        return "anthropic"

    @staticmethod
    def _extract_text(data: Dict[str, Any]) -> str:
        """Concatenate the text of all text blocks in a Messages API response.

        The response ``content`` is a list of typed blocks. Reading only the
        first block would drop the rest of a multi-block answer and would
        yield nothing when the first block is not a text block.
        """
        blocks = data.get("content") or []
        return "".join(
            block.get("text", "")
            for block in blocks
            if isinstance(block, dict) and block.get("type", "text") == "text"
        )

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Generate completion using Claude API.

        Args:
            prompt: User message.
            system_prompt: Optional system prompt.
            max_tokens: Output limit; defaults to the configured value.
            temperature: Sampling temperature; defaults to the configured
                value.

        Returns:
            The response text with usage data and the raw JSON body.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
            Exception: Any other request failure is logged and re-raised.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }
        messages = [{"role": "user", "content": prompt}]
        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens or self.config.max_tokens,
            "messages": messages
        }
        if system_prompt:
            payload["system"] = system_prompt
        # An explicit temperature (including 0.0) beats the configured default.
        if temperature is not None:
            payload["temperature"] = temperature
        elif self.config.temperature is not None:
            payload["temperature"] = self.config.temperature

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            try:
                response = await client.post(
                    self.ANTHROPIC_API_URL,
                    headers=headers,
                    json=payload
                )
                response.raise_for_status()
                data = response.json()
                return LLMResponse(
                    content=self._extract_text(data),
                    model=self.model,
                    provider=self.provider_name,
                    usage=data.get("usage"),
                    raw_response=data
                )
            except httpx.HTTPStatusError as e:
                logger.error(f"Anthropic API error: {e.response.status_code} - {e.response.text}")
                raise
            except Exception as e:
                logger.error(f"Anthropic API request failed: {e}")
                raise

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Process multiple prompts with rate limiting.

        Prompts are processed one after another with ``rate_limit`` seconds of
        sleep between calls. A failing prompt does not abort the batch: its
        slot holds a response whose content is "Error: <message>", so the
        result always has one entry per prompt, in input order.
        """
        results = []
        for i, prompt in enumerate(prompts):
            if i > 0:
                await asyncio.sleep(rate_limit)
            try:
                result = await self.complete(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens
                )
                results.append(result)
            except Exception as e:
                logger.error(f"Failed to process prompt {i}: {e}")
                # Append error response
                results.append(LLMResponse(
                    content=f"Error: {str(e)}",
                    model=self.model,
                    provider=self.provider_name
                ))
        return results
class SelfHostedProvider(LLMProvider):
"""Self-Hosted LLM Provider supporting Ollama, vLLM, LocalAI, etc."""
def __init__(self, config: LLMConfig):
super().__init__(config)
if not config.base_url:
raise ValueError("Base URL is required for self-hosted provider")
self.base_url = config.base_url.rstrip("/")
self.model = config.model
self.api_key = config.api_key
@property
def provider_name(self) -> str:
return "self_hosted"
def _detect_api_format(self) -> str:
"""Detect the API format based on URL patterns."""
if "11434" in self.base_url or "ollama" in self.base_url.lower():
return "ollama"
elif "openai" in self.base_url.lower() or "v1" in self.base_url:
return "openai"
else:
return "ollama" # Default to Ollama format
async def complete(
self,
prompt: str,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None
) -> LLMResponse:
"""Generate completion using self-hosted LLM."""
api_format = self._detect_api_format()
headers = {"content-type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
if api_format == "ollama":
# Ollama API format
endpoint = f"{self.base_url}/api/generate"
full_prompt = prompt
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
payload = {
"model": self.model,
"prompt": full_prompt,
"stream": False,
"think": False, # Disable thinking mode (qwen3.5 etc.)
"options": {}
}
if max_tokens:
payload["options"]["num_predict"] = max_tokens
if temperature is not None:
payload["options"]["temperature"] = temperature
else:
# OpenAI-compatible format (vLLM, LocalAI, etc.)
endpoint = f"{self.base_url}/v1/chat/completions"
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
payload = {
"model": self.model,
"messages": messages,
"max_tokens": max_tokens or self.config.max_tokens,
"temperature": temperature if temperature is not None else self.config.temperature
}
async with httpx.AsyncClient(timeout=self.config.timeout) as client:
try:
response = await client.post(endpoint, headers=headers, json=payload)
response.raise_for_status()
data = response.json()
# Parse response based on format
if api_format == "ollama":
content = data.get("response", "")
else:
# OpenAI format
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
return LLMResponse(
content=content,
model=self.model,
provider=self.provider_name,
usage=data.get("usage"),
raw_response=data
)
except httpx.HTTPStatusError as e:
logger.error(f"Self-hosted LLM error: {e.response.status_code} - {e.response.text}")
raise
except Exception as e:
logger.error(f"Self-hosted LLM request failed: {e}")
raise
async def batch_complete(
    self,
    prompts: List[str],
    system_prompt: Optional[str] = None,
    max_tokens: Optional[int] = None,
    rate_limit: float = 0.5  # Self-hosted can be faster
) -> List[LLMResponse]:
    """Run ``complete`` for every prompt, one after another.

    Args:
        prompts: Prompts to process, in order.
        system_prompt: Optional system instruction passed to every call.
        max_tokens: Optional token limit passed to every call.
        rate_limit: Seconds to sleep between two consecutive calls.

    Returns:
        One ``LLMResponse`` per prompt, in input order. A prompt whose call
        raised is represented by a response whose content starts with
        ``"Error: "`` instead of aborting the batch.
    """
    collected = []
    for i, prompt in enumerate(prompts):
        # Throttle between calls, never before the first one.
        if i > 0:
            await asyncio.sleep(rate_limit)
        try:
            outcome = await self.complete(
                prompt=prompt,
                system_prompt=system_prompt,
                max_tokens=max_tokens
            )
        except Exception as e:
            logger.error(f"Failed to process prompt {i}: {e}")
            outcome = LLMResponse(
                content=f"Error: {str(e)}",
                model=self.model,
                provider=self.provider_name
            )
        collected.append(outcome)
    return collected
class MockProvider(LLMProvider):
    """Provider stand-in that answers without contacting any API.

    Replies come from a list set via ``set_responses`` (cycled in order) or,
    when no list is set, from a placeholder built from the prompt.
    """

    def __init__(self, config: LLMConfig):
        super().__init__(config)
        self.responses: List[str] = []
        self.call_count = 0

    @property
    def provider_name(self) -> str:
        return "mock"

    def set_responses(self, responses: List[str]):
        """Install canned replies and restart the cycle at the first one."""
        self.responses = responses
        self.call_count = 0

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Return the next canned reply, or a placeholder echoing the prompt."""
        if not self.responses:
            content = f"Mock response for: {prompt[:50]}..."
        else:
            # Wrap around so the canned replies repeat indefinitely.
            slot = self.call_count % len(self.responses)
            content = self.responses[slot]
        self.call_count += 1

        # Character counts stand in for token counts.
        fake_usage = {"input_tokens": len(prompt), "output_tokens": len(content)}
        return LLMResponse(
            content=content,
            model="mock-model",
            provider=self.provider_name,
            usage=fake_usage
        )

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 0.0
    ) -> List[LLMResponse]:
        """Answer every prompt in order; ``rate_limit`` is accepted but ignored."""
        replies = []
        for p in prompts:
            replies.append(await self.complete(p, system_prompt, max_tokens))
        return replies
def get_llm_config() -> LLMConfig:
    """
    Build the LLM configuration from environment variables and secrets.

    The provider is chosen by COMPLIANCE_LLM_PROVIDER (default "anthropic");
    an unrecognised value is logged and replaced by the mock provider.
    API keys are resolved through get_secret_from_vault_or_env, the model
    name through ANTHROPIC_MODEL or SELF_HOSTED_LLM_MODEL, and the remaining
    settings through SELF_HOSTED_LLM_URL and the COMPLIANCE_LLM_* variables.
    """
    provider_type_str = os.getenv("COMPLIANCE_LLM_PROVIDER", "anthropic")
    try:
        provider_type = LLMProviderType(provider_type_str)
    except ValueError:
        logger.warning(f"Unknown LLM provider: {provider_type_str}, falling back to mock")
        provider_type = LLMProviderType.MOCK

    # Defaults apply to the mock provider; real providers override both.
    api_key = None
    model = "mock-model"
    if provider_type == LLMProviderType.ANTHROPIC:
        api_key = get_secret_from_vault_or_env(
            vault_path="breakpilot/api_keys/anthropic",
            env_var="ANTHROPIC_API_KEY"
        )
        model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
    elif provider_type in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA):
        api_key = get_secret_from_vault_or_env(
            vault_path="breakpilot/api_keys/self_hosted_llm",
            env_var="SELF_HOSTED_LLM_KEY"
        )
        model = os.getenv("SELF_HOSTED_LLM_MODEL", "qwen2.5:14b")

    base_url = os.getenv("SELF_HOSTED_LLM_URL")
    max_tokens = int(os.getenv("COMPLIANCE_LLM_MAX_TOKENS", "4096"))
    temperature = float(os.getenv("COMPLIANCE_LLM_TEMPERATURE", "0.3"))
    timeout = float(os.getenv("COMPLIANCE_LLM_TIMEOUT", "60.0"))
    return LLMConfig(
        provider_type=provider_type,
        api_key=api_key,
        model=model,
        base_url=base_url,
        max_tokens=max_tokens,
        temperature=temperature,
        timeout=timeout
    )
def get_llm_provider(config: Optional[LLMConfig] = None) -> LLMProvider:
    """
    Build the provider that matches ``config.provider_type``.

    When no config is given, one is created with get_llm_config(). A provider
    whose prerequisite is missing (Anthropic without API key, self-hosted
    without base URL) is replaced by a MockProvider after a warning.

    Raises:
        ValueError: The provider type is none of the supported ones.

    Usage:
        provider = get_llm_provider()
        response = await provider.complete("Analyze this requirement...")
    """
    if config is None:
        config = get_llm_config()
    kind = config.provider_type

    if kind == LLMProviderType.ANTHROPIC:
        if config.api_key:
            return AnthropicProvider(config)
        logger.warning("No Anthropic API key found, using mock provider")
        return MockProvider(config)

    if kind in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA):
        if config.base_url:
            return SelfHostedProvider(config)
        logger.warning("No self-hosted LLM URL found, using mock provider")
        return MockProvider(config)

    if kind == LLMProviderType.MOCK:
        return MockProvider(config)

    raise ValueError(f"Unsupported LLM provider type: {config.provider_type}")
# Module-level cache holding the one provider shared by all callers.
_provider_instance: Optional[LLMProvider] = None


def get_shared_provider() -> LLMProvider:
    """Return the cached provider, creating it on first use."""
    global _provider_instance
    if _provider_instance is not None:
        return _provider_instance
    _provider_instance = get_llm_provider()
    return _provider_instance
def reset_shared_provider():
    """Reset the shared provider instance (useful for testing).

    Clears the module-level cache, so the next ``get_shared_provider()`` call
    builds a new provider.
    """
    global _provider_instance
    _provider_instance = None

View File

@@ -0,0 +1,59 @@
"""Shared normative language patterns for assertion classification.
Extracted from decomposition_pass.py for reuse in the assertion engine.
"""
import re
_PFLICHT_SIGNALS = [
r"\bmüssen\b", r"\bmuss\b", r"\bhat\s+sicherzustellen\b",
r"\bhaben\s+sicherzustellen\b", r"\bsind\s+verpflichtet\b",
r"\bist\s+verpflichtet\b",
r"\bist\s+zu\s+\w+en\b", r"\bsind\s+zu\s+\w+en\b",
r"\bhat\s+zu\s+\w+en\b", r"\bhaben\s+zu\s+\w+en\b",
r"\bist\s+\w+zu\w+en\b", r"\bsind\s+\w+zu\w+en\b",
r"\bist\s+\w+\s+zu\s+\w+en\b", r"\bsind\s+\w+\s+zu\s+\w+en\b",
r"\bhat\s+\w+\s+zu\s+\w+en\b", r"\bhaben\s+\w+\s+zu\s+\w+en\b",
r"\bshall\b", r"\bmust\b", r"\brequired\b",
r"\b\w+zuteilen\b", r"\b\w+zuwenden\b", r"\b\w+zustellen\b", r"\b\w+zulegen\b",
r"\b\w+zunehmen\b", r"\b\w+zuführen\b", r"\b\w+zuhalten\b", r"\b\w+zusetzen\b",
r"\b\w+zuweisen\b", r"\b\w+zuordnen\b", r"\b\w+zufügen\b", r"\b\w+zugeben\b",
r"\bist\b.{1,80}\bzu\s+\w+en\b", r"\bsind\b.{1,80}\bzu\s+\w+en\b",
]
PFLICHT_RE = re.compile("|".join(_PFLICHT_SIGNALS), re.IGNORECASE)
_EMPFEHLUNG_SIGNALS = [
r"\bsoll\b", r"\bsollen\b", r"\bsollte\b", r"\bsollten\b",
r"\bgewährleisten\b", r"\bsicherstellen\b",
r"\bshould\b", r"\bensure\b", r"\brecommend\w*\b",
r"\bnachweisen\b", r"\beinhalten\b", r"\bunterlassen\b", r"\bwahren\b",
r"\bdokumentieren\b", r"\bimplementieren\b", r"\büberprüfen\b", r"\büberwachen\b",
r"\bprüfen,\s+ob\b", r"\bkontrollieren,\s+ob\b",
]
EMPFEHLUNG_RE = re.compile("|".join(_EMPFEHLUNG_SIGNALS), re.IGNORECASE)
_KANN_SIGNALS = [
r"\bkann\b", r"\bkönnen\b", r"\bdarf\b", r"\bdürfen\b",
r"\bmay\b", r"\boptional\b",
]
KANN_RE = re.compile("|".join(_KANN_SIGNALS), re.IGNORECASE)
NORMATIVE_RE = re.compile(
"|".join(_PFLICHT_SIGNALS + _EMPFEHLUNG_SIGNALS + _KANN_SIGNALS),
re.IGNORECASE,
)
_RATIONALE_SIGNALS = [
r"\bda\s+", r"\bweil\b", r"\bgrund\b", r"\berwägung",
r"\bbecause\b", r"\breason\b", r"\brationale\b",
r"\bkönnen\s+.*\s+verursachen\b", r"\bführt\s+zu\b",
]
RATIONALE_RE = re.compile("|".join(_RATIONALE_SIGNALS), re.IGNORECASE)
# Evidence-related keywords (for fact detection)
_EVIDENCE_KEYWORDS = [
r"\bnachweis\b", r"\bzertifikat\b", r"\baudit.report\b",
r"\bprotokoll\b", r"\bdokumentation\b", r"\bbericht\b",
r"\bcertificate\b", r"\bevidence\b", r"\bproof\b",
]
EVIDENCE_RE = re.compile("|".join(_EVIDENCE_KEYWORDS), re.IGNORECASE)

View File

@@ -0,0 +1,563 @@
"""Obligation Extractor — 3-Tier Chunk-to-Obligation Linking.
Maps RAG chunks to obligations from the v2 obligation framework using
three tiers (fastest first):
Tier 1: EXACT MATCH — regulation_code + article → obligation_id (~40%)
Tier 2: EMBEDDING — chunk text vs. obligation descriptions (~30%)
Tier 3: LLM EXTRACT — local Ollama extracts obligation text (~25%)
Part of the Multi-Layer Control Architecture (Phase 4 of 8).
"""
import json
import logging
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# Base URL of the embedding service. EMBEDDING_URL wins when set; otherwise
# EMBEDDING_SERVICE_URL (the name used by the service-wide Settings) is
# honoured, so configuring the service once also configures this module.
EMBEDDING_URL = os.getenv("EMBEDDING_URL") or os.getenv(
    "EMBEDDING_SERVICE_URL", "http://embedding-service:8087"
)
# Local Ollama endpoint, model and request timeout (seconds) for Tier 3.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))

# Embedding similarity thresholds for Tier 2
EMBEDDING_MATCH_THRESHOLD = 0.80
EMBEDDING_CANDIDATE_THRESHOLD = 0.60
# ---------------------------------------------------------------------------
# Regulation code mapping: RAG chunk codes → obligation file regulation IDs
#
# Keys are lowercase: _normalize_regulation lowercases and strips a code
# before looking it up here. Several aliases may map to one regulation ID.
# ---------------------------------------------------------------------------
_REGULATION_CODE_TO_ID = {
    # DSGVO
    "eu_2016_679": "dsgvo",
    "dsgvo": "dsgvo",
    "gdpr": "dsgvo",
    # AI Act
    "eu_2024_1689": "ai_act",
    "ai_act": "ai_act",
    "aiact": "ai_act",
    # NIS2
    "eu_2022_2555": "nis2",
    "nis2": "nis2",
    "bsig": "nis2",
    # BDSG
    "bdsg": "bdsg",
    # TTDSG
    "ttdsg": "ttdsg",
    # DSA
    "eu_2022_2065": "dsa",
    "dsa": "dsa",
    # Data Act
    "eu_2023_2854": "data_act",
    "data_act": "data_act",
    # EU Machinery
    "eu_2023_1230": "eu_machinery",
    "eu_machinery": "eu_machinery",
    # DORA
    "eu_2022_2554": "dora",
    "dora": "dora",
}
@dataclass
class ObligationMatch:
    """Outcome of one obligation extraction, whichever tier produced it."""

    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    method: str = "none"  # exact_match | embedding_match | llm_extracted | inferred
    confidence: float = 0.0
    regulation_id: Optional[str] = None  # e.g. "dsgvo"

    def to_dict(self) -> dict:
        """Return all fields as a plain dict, in declaration order."""
        exported = (
            "obligation_id",
            "obligation_title",
            "obligation_text",
            "method",
            "confidence",
            "regulation_id",
        )
        return {name: getattr(self, name) for name in exported}
@dataclass
class _ObligationEntry:
    """Internal representation of a loaded obligation."""

    id: str  # value of the obligation's "id" field in its regulation file
    title: str
    description: str
    regulation_id: str  # "id" of the manifest entry the obligation was loaded under
    articles: list[str] = field(default_factory=list)  # normalized: ["art. 30", "§ 38"]
    # NOTE(review): no code visible in this module writes this field; the
    # extractor keeps its vectors in a separate list — confirm it is still needed.
    embedding: list[float] = field(default_factory=list)
class ObligationExtractor:
    """3-tier obligation extraction from RAG chunks.

    Tiers are tried in order and the first hit wins:

    1. exact lookup of ``regulation/article`` in the loaded obligation files,
    2. embedding similarity against the obligation titles and descriptions,
    3. free-text extraction by the local LLM.

    Usage::

        extractor = ObligationExtractor()
        await extractor.initialize()  # loads obligations + embeddings
        match = await extractor.extract(
            chunk_text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            paragraph="Abs. 1",
        )
    """

    # Ranking bonus for Tier 2 candidates from the chunk's own regulation.
    _SAME_REGULATION_BONUS = 0.05

    def __init__(self):
        self._article_lookup: dict[str, list[str]] = {}  # "dsgvo/art. 30" → ["DSGVO-OBL-001"]
        self._obligations: dict[str, _ObligationEntry] = {}  # id → entry
        self._obligation_embeddings: list[list[float]] = []
        self._obligation_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load all obligations from v2 JSON files and compute embeddings.

        Does nothing when the extractor is already initialized.
        """
        if self._initialized:
            return

        self._load_obligations()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "ObligationExtractor initialized: %d obligations, %d article lookups, %d embeddings",
            len(self._obligations),
            len(self._article_lookup),
            sum(1 for e in self._obligation_embeddings if e),
        )

    async def extract(
        self,
        chunk_text: str,
        regulation_code: str,
        article: Optional[str] = None,
        paragraph: Optional[str] = None,
    ) -> ObligationMatch:
        """Extract obligation from a chunk using 3-tier strategy.

        Args:
            chunk_text: Text of the RAG chunk.
            regulation_code: Regulation code attached to the chunk.
            article: Article reference of the chunk; enables Tier 1.
            paragraph: Accepted for interface compatibility; not used.

        Returns:
            The match of the first tier that produced one. Tier 3 always
            returns a match object, possibly with confidence 0.0.
        """
        if not self._initialized:
            await self.initialize()

        reg_id = _normalize_regulation(regulation_code)

        # Tier 1: Exact match via article lookup
        if article:
            match = self._tier1_exact(reg_id, article)
            if match:
                return match

        # Tier 2: Embedding similarity
        match = await self._tier2_embedding(chunk_text, reg_id)
        if match:
            return match

        # Tier 3: LLM extraction
        return await self._tier3_llm(chunk_text, regulation_code, article)

    # -----------------------------------------------------------------------
    # Tier 1: Exact Match
    # -----------------------------------------------------------------------
    def _tier1_exact(self, reg_id: Optional[str], article: str) -> Optional[ObligationMatch]:
        """Look up obligation by regulation + article.

        Returns ``None`` when the regulation is unknown or no obligation is
        registered for the normalized article.
        """
        if not reg_id:
            return None

        key = f"{reg_id}/{_normalize_article(article)}"
        obl_ids = self._article_lookup.get(key)
        if not obl_ids:
            return None

        # Take the first match (highest priority)
        entry = self._obligations.get(obl_ids[0])
        if not entry:
            return None

        return ObligationMatch(
            obligation_id=entry.id,
            obligation_title=entry.title,
            obligation_text=entry.description,
            method="exact_match",
            confidence=1.0,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------
    async def _tier2_embedding(
        self, chunk_text: str, reg_id: Optional[str]
    ) -> Optional[ObligationMatch]:
        """Find nearest obligation by embedding similarity.

        Candidates from the chunk's own regulation get a small bonus while
        ranking, but the acceptance threshold and the reported confidence use
        the unboosted cosine similarity.
        """
        if not self._obligation_embeddings:
            return None

        chunk_embedding = await _get_embedding(chunk_text[:2000])
        if not chunk_embedding:
            return None

        best_id: Optional[str] = None
        best_ranked = 0.0
        best_raw = 0.0
        for obl_id, obl_emb in zip(self._obligation_ids, self._obligation_embeddings):
            if not obl_emb:
                continue

            raw = _cosine_sim(chunk_embedding, obl_emb)
            ranked = raw
            entry = self._obligations.get(obl_id)
            if entry and reg_id and entry.regulation_id == reg_id:
                ranked += self._SAME_REGULATION_BONUS

            if ranked > best_ranked:
                # Keep the raw score next to the ranked one: adding and later
                # subtracting the bonus is not exact in floating point and
                # can push a boundary score below the threshold.
                best_id, best_ranked, best_raw = obl_id, ranked, raw

        if best_id is None:
            return None

        if best_raw >= EMBEDDING_MATCH_THRESHOLD:
            entry = self._obligations.get(best_id)
            return ObligationMatch(
                obligation_id=entry.id if entry else best_id,
                obligation_title=entry.title if entry else None,
                obligation_text=entry.description if entry else None,
                method="embedding_match",
                confidence=round(min(best_raw, 1.0), 3),
                regulation_id=entry.regulation_id if entry else reg_id,
            )

        return None

    # -----------------------------------------------------------------------
    # Tier 3: LLM Extraction
    # -----------------------------------------------------------------------
    async def _tier3_llm(
        self, chunk_text: str, regulation_code: str, article: Optional[str]
    ) -> ObligationMatch:
        """Use local LLM to extract the obligation from the chunk.

        Returns a match with confidence 0.0 and no text when the LLM gave no
        answer, otherwise a match with confidence 0.60 whose text is the
        ``obligation_text`` field of the answer (or the first 500 characters
        of the raw answer when that field is absent).
        """
        prompt = f"""Analysiere den folgenden Gesetzestext und extrahiere die zentrale rechtliche Pflicht.
Text:
{chunk_text[:3000]}
Quelle: {regulation_code} {article or ''}
Antworte NUR als JSON:
{{
"obligation_text": "Die zentrale Pflicht in einem Satz",
"actor": "Wer muss handeln (z.B. Verantwortlicher, Auftragsverarbeiter)",
"action": "Was muss getan werden",
"normative_strength": "muss|soll|kann"
}}"""

        system_prompt = (
            "Du bist ein Rechtsexperte fuer EU-Datenschutz- und Digitalrecht. "
            "Extrahiere die zentrale rechtliche Pflicht aus Gesetzestexten. "
            "Antworte ausschliesslich als JSON."
        )

        result_text = await _llm_ollama(prompt, system_prompt)
        if not result_text:
            return ObligationMatch(
                method="llm_extracted",
                confidence=0.0,
                regulation_id=_normalize_regulation(regulation_code),
            )

        parsed = _parse_json(result_text)
        if not isinstance(parsed, dict):
            # The model may answer with a JSON list or scalar; treat that
            # like an answer without the expected field.
            parsed = {}
        obligation_text = parsed.get("obligation_text", result_text[:500])

        return ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text=obligation_text,
            method="llm_extracted",
            confidence=0.60,
            regulation_id=_normalize_regulation(regulation_code),
        )

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------
    def _load_obligations(self) -> None:
        """Load all obligation files from v2 framework.

        Fills ``self._obligations`` and ``self._article_lookup``. Missing
        directory, manifest or regulation files are logged and skipped.
        """
        v2_dir = _find_obligations_dir()
        if not v2_dir:
            logger.warning("Obligations v2 directory not found — Tier 1 disabled")
            return

        manifest_path = v2_dir / "_manifest.json"
        if not manifest_path.exists():
            logger.warning("Manifest not found at %s", manifest_path)
            return

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(manifest_path, encoding="utf-8") as f:
            manifest = json.load(f)

        for reg_info in manifest.get("regulations", []):
            reg_id = reg_info["id"]
            reg_file = v2_dir / reg_info["file"]
            if not reg_file.exists():
                logger.warning("Regulation file not found: %s", reg_file)
                continue

            with open(reg_file, encoding="utf-8") as f:
                data = json.load(f)

            for obl in data.get("obligations", []):
                obl_id = obl["id"]
                entry = _ObligationEntry(
                    id=obl_id,
                    title=obl.get("title", ""),
                    description=obl.get("description", ""),
                    regulation_id=reg_id,
                )

                # Build article lookup from legal_basis
                for basis in obl.get("legal_basis", []):
                    article_raw = basis.get("article", "")
                    if not article_raw:
                        continue
                    norm_art = _normalize_article(article_raw)
                    key = f"{reg_id}/{norm_art}"
                    self._article_lookup.setdefault(key, []).append(obl_id)
                    entry.articles.append(norm_art)

                self._obligations[obl_id] = entry

        logger.info(
            "Loaded %d obligations from %d regulations",
            len(self._obligations),
            len(manifest.get("regulations", [])),
        )

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all obligation descriptions.

        ``self._obligation_ids[i]`` and ``self._obligation_embeddings[i]``
        belong together; when the embedding helper returns a different number
        of vectors than texts, all vectors are discarded because their
        assignment to obligations can no longer be trusted.
        """
        if not self._obligations:
            return

        self._obligation_ids = list(self._obligations.keys())
        texts = [
            f"{self._obligations[oid].title}: {self._obligations[oid].description}"
            for oid in self._obligation_ids
        ]

        logger.info("Computing embeddings for %d obligations...", len(texts))
        embeddings = await _get_embeddings_batch(texts)
        if len(embeddings) != len(texts):
            logger.warning(
                "Embedding count mismatch (%d texts, %d embeddings) — Tier 2 disabled",
                len(texts),
                len(embeddings),
            )
            embeddings = []
        self._obligation_embeddings = embeddings

        valid = sum(1 for e in self._obligation_embeddings if e)
        logger.info("Got %d/%d valid embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Stats
    # -----------------------------------------------------------------------
    def stats(self) -> dict:
        """Return initialization statistics."""
        return {
            "total_obligations": len(self._obligations),
            "article_lookups": len(self._article_lookup),
            "embeddings_valid": sum(1 for e in self._obligation_embeddings if e),
            "regulations": list(
                {e.regulation_id for e in self._obligations.values()}
            ),
            "initialized": self._initialized,
        }
# ---------------------------------------------------------------------------
# Module-level helpers (reusable by other modules)
# ---------------------------------------------------------------------------
def _normalize_regulation(regulation_code: str) -> Optional[str]:
"""Map a RAG regulation_code to obligation framework regulation ID."""
if not regulation_code:
return None
code = regulation_code.lower().strip()
# Direct lookup
if code in _REGULATION_CODE_TO_ID:
return _REGULATION_CODE_TO_ID[code]
# Prefix matching for families
for prefix, reg_id in [
("eu_2016_679", "dsgvo"),
("eu_2024_1689", "ai_act"),
("eu_2022_2555", "nis2"),
("eu_2022_2065", "dsa"),
("eu_2023_2854", "data_act"),
("eu_2023_1230", "eu_machinery"),
("eu_2022_2554", "dora"),
]:
if code.startswith(prefix):
return reg_id
return None
def _normalize_article(article: str) -> str:
"""Normalize article references for consistent lookup.
Examples:
"Art. 30""art. 30"
"§ 38 BDSG""§ 38"
"Article 10""art. 10"
"Art. 30 Abs. 1""art. 30"
"Artikel 35""art. 35"
"""
if not article:
return ""
s = article.strip()
# Remove trailing law name: "§ 38 BDSG" → "§ 38"
s = re.sub(r"\s+(DSGVO|BDSG|TTDSG|DSA|NIS2|DORA|AI.?Act)\s*$", "", s, flags=re.IGNORECASE)
# Remove paragraph references: "Art. 30 Abs. 1" → "Art. 30"
s = re.sub(r"\s+(Abs|Absatz|para|paragraph|lit|Satz)\.?\s+.*$", "", s, flags=re.IGNORECASE)
# Normalize "Article" / "Artikel" → "Art."
s = re.sub(r"^(Article|Artikel)\s+", "Art. ", s, flags=re.IGNORECASE)
return s.lower().strip()
def _cosine_sim(a: list[float], b: list[float]) -> float:
"""Compute cosine similarity between two vectors."""
if not a or not b or len(a) != len(b):
return 0.0
dot = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
def _find_obligations_dir() -> Optional[Path]:
"""Locate the obligations v2 directory."""
candidates = [
Path(__file__).resolve().parent.parent.parent.parent
/ "ai-compliance-sdk" / "policies" / "obligations" / "v2",
Path("/app/ai-compliance-sdk/policies/obligations/v2"),
Path("ai-compliance-sdk/policies/obligations/v2"),
]
for p in candidates:
if p.is_dir() and (p / "_manifest.json").exists():
return p
return None
async def _get_embedding(text: str) -> list[float]:
    """Get embedding vector for a single text.

    Best-effort: any failure (transport error, error status, malformed
    payload) is logged and reported as an empty list, which callers treat as
    "no embedding available".
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            return embeddings[0] if embeddings else []
    except Exception as e:
        # Keep the fallback, but leave a trace: a silently failing embedding
        # service otherwise looks like "nothing matched".
        logger.warning("Embedding request failed: %s", e)
        return []
async def _get_embeddings_batch(
    texts: list[str], batch_size: int = 32
) -> list[list[float]]:
    """Get embeddings for multiple texts in batches.

    Args:
        texts: Texts to embed.
        batch_size: Number of texts sent per request.

    Returns:
        One entry per input text, in input order. Texts of a batch that
        failed, or whose response did not contain exactly one vector per
        text, are represented by empty lists so positions stay aligned.
    """
    all_embeddings: list[list[float]] = []
    # One client for all batches so the connection can be reused.
    async with httpx.AsyncClient(timeout=30.0) as client:
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                resp = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                # A short or long answer would shift every later vector onto
                # the wrong text; treat it as a failed batch instead.
                if len(embeddings) != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got {len(embeddings)}"
                    )
                all_embeddings.extend(embeddings)
            except Exception as e:
                logger.warning("Batch embedding failed for %d texts: %s", len(batch), e)
                all_embeddings.extend([[] for _ in batch])
    return all_embeddings
async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Send a chat request to the local Ollama server and return the reply text.

    The request asks for JSON output, disables streaming and thinking, and
    limits generation to 512 tokens. Returns an empty string when the server
    answers with a status other than 200 or the call fails.
    """
    conversation = [{"role": "user", "content": prompt}]
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    request_body = {
        "model": OLLAMA_MODEL,
        "messages": conversation,
        "stream": False,
        "format": "json",
        "options": {"num_predict": 512},
        "think": False,
    }

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=request_body)
            if resp.status_code == 200:
                reply = resp.json()
                return reply.get("message", {}).get("content", "")
            logger.error(
                "Ollama chat failed %d: %s", resp.status_code, resp.text[:300]
            )
            return ""
    except Exception as e:
        logger.warning("Ollama call failed: %s", e)
        return ""
def _parse_json(text: str) -> dict:
"""Extract JSON from LLM response text."""
# Try direct parse
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try extracting JSON block
match = re.search(r"\{[^{}]*\}", text, re.DOTALL)
if match:
try:
return json.loads(match.group())
except json.JSONDecodeError:
pass
return {}

View File

@@ -0,0 +1,532 @@
"""Pattern Matcher — Obligation-to-Control-Pattern Linking.
Maps obligations (from the ObligationExtractor) to control patterns
using two tiers:
Tier 1: KEYWORD MATCH — obligation_match_keywords from patterns (~70%)
Tier 2: EMBEDDING — cosine similarity with domain bonus (~25%)
Part of the Multi-Layer Control Architecture (Phase 5 of 8).
"""
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
from services.obligation_extractor import (
_cosine_sim,
_get_embedding,
_get_embeddings_batch,
)
logger = logging.getLogger(__name__)
# Minimum number of keyword hits a pattern needs for a Tier 1 match
# (at least 2 keyword hits)
KEYWORD_MATCH_MIN_HITS = 2

# Embedding threshold for Tier 2: best scores below it are not a match
EMBEDDING_PATTERN_THRESHOLD = 0.75

# Domain bonus when regulation maps to the pattern's domain
# (added to the keyword confidence and to the embedding score)
DOMAIN_BONUS = 0.10

# Map regulation IDs to pattern domains that are likely relevant.
# Keys are regulation IDs, values are lists of pattern domain codes;
# a regulation missing here never receives the domain bonus.
_REGULATION_DOMAIN_AFFINITY = {
    "dsgvo": ["DATA", "COMP", "GOV"],
    "bdsg": ["DATA", "COMP"],
    "ttdsg": ["DATA"],
    "ai_act": ["AI", "COMP", "DATA"],
    "nis2": ["SEC", "INC", "NET", "LOG", "CRYP"],
    "dsa": ["DATA", "COMP"],
    "data_act": ["DATA", "COMP"],
    "eu_machinery": ["SEC", "COMP"],
    "dora": ["SEC", "INC", "FIN", "COMP"],
}
@dataclass
class ControlPattern:
    """Python representation of a control pattern from YAML."""

    id: str  # pattern identifier, e.g. "CP-COMP-001"
    name: str
    name_de: str  # presumably the German display name — confirm against the YAML files
    domain: str  # domain code compared against _REGULATION_DOMAIN_AFFINITY values
    category: str
    description: str
    objective_template: str
    rationale_template: str
    requirements_template: list[str] = field(default_factory=list)
    test_procedure_template: list[str] = field(default_factory=list)
    evidence_template: list[str] = field(default_factory=list)
    severity_default: str = "medium"
    implementation_effort_default: str = "m"
    # Keywords counted against obligation text in the keyword tier; their
    # number is the denominator of the keyword confidence.
    obligation_match_keywords: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    # IDs of other patterns; matchers report those that are actually loaded.
    composable_with: list[str] = field(default_factory=list)
    open_anchor_refs: list[dict] = field(default_factory=list)
@dataclass
class PatternMatchResult:
    """Outcome of matching one obligation against the control patterns."""

    pattern: Optional[ControlPattern] = None
    pattern_id: Optional[str] = None
    method: str = "none"  # keyword | embedding | combined | none
    confidence: float = 0.0
    keyword_hits: int = 0
    total_keywords: int = 0
    embedding_score: float = 0.0
    domain_bonus_applied: bool = False
    composable_patterns: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a serializable summary; the pattern object itself is omitted.

        Confidence and embedding score are rounded to three decimals.
        """
        summary = {}
        summary["pattern_id"] = self.pattern_id
        summary["method"] = self.method
        summary["confidence"] = round(self.confidence, 3)
        summary["keyword_hits"] = self.keyword_hits
        summary["total_keywords"] = self.total_keywords
        summary["embedding_score"] = round(self.embedding_score, 3)
        summary["domain_bonus_applied"] = self.domain_bonus_applied
        summary["composable_patterns"] = self.composable_patterns
        return summary
class PatternMatcher:
"""Links obligations to control patterns using keyword + embedding matching.
Usage::
matcher = PatternMatcher()
await matcher.initialize()
result = await matcher.match(
obligation_text="Fuehrung eines Verarbeitungsverzeichnisses...",
regulation_id="dsgvo",
)
print(result.pattern_id) # e.g. "CP-COMP-001"
print(result.confidence) # e.g. 0.85
"""
def __init__(self):
    """Create an empty matcher; ``initialize`` fills it."""
    self._patterns: list[ControlPattern] = []
    self._by_id: dict[str, ControlPattern] = {}
    self._by_domain: dict[str, list[ControlPattern]] = {}
    self._keyword_index: dict[str, list[str]] = {}  # keyword → [pattern_ids]
    # Parallel lists: _pattern_embeddings[i] belongs to _pattern_ids[i].
    self._pattern_embeddings: list[list[float]] = []
    self._pattern_ids: list[str] = []
    self._initialized = False
async def initialize(self) -> None:
    """Load the patterns, build the keyword index and compute embeddings.

    Repeated calls after a completed initialization do nothing.
    """
    if not self._initialized:
        self._load_patterns()
        self._build_keyword_index()
        await self._compute_embeddings()
        self._initialized = True

        usable_embeddings = sum(1 for e in self._pattern_embeddings if e)
        logger.info(
            "PatternMatcher initialized: %d patterns, %d keywords, %d embeddings",
            len(self._patterns),
            len(self._keyword_index),
            usable_embeddings,
        )
async def match(
    self,
    obligation_text: str,
    regulation_id: Optional[str] = None,
    top_n: int = 1,
) -> PatternMatchResult:
    """Find the single best control pattern for an obligation text.

    Args:
        obligation_text: The obligation description to match against.
        regulation_id: Source regulation (for domain bonus).
        top_n: Accepted for interface compatibility; this method does not
            read it.

    Returns:
        The best match with its loaded composable patterns attached, or an
        empty ``PatternMatchResult`` when the text is empty, no patterns are
        loaded, or neither tier matched.
    """
    if not self._initialized:
        await self.initialize()

    if not obligation_text or not self._patterns:
        return PatternMatchResult()

    # Tier 1 (keywords) runs first, then Tier 2 (embeddings); both results
    # are merged into one.
    from_keywords = self._tier1_keyword(obligation_text, regulation_id)
    from_embedding = await self._tier2_embedding(obligation_text, regulation_id)
    best = self._combine_results(from_keywords, from_embedding)

    # Attach composable patterns, keeping only IDs that are actually loaded.
    if best.pattern:
        loaded = []
        for pid in best.pattern.composable_with:
            if pid in self._by_id:
                loaded.append(pid)
        best.composable_patterns = loaded

    return best
async def match_top_n(
    self,
    obligation_text: str,
    regulation_id: Optional[str] = None,
    n: int = 3,
) -> list[PatternMatchResult]:
    """Return the ``n`` best pattern matches, highest confidence first.

    Every pattern with a keyword or embedding score is a candidate. A
    pattern scored by both tiers gets the higher of the two confidences
    plus 0.05; confidences are capped at 1.0.
    """
    if not self._initialized:
        await self.initialize()

    if not obligation_text or not self._patterns:
        return []

    keyword_scores = self._keyword_scores(obligation_text, regulation_id)
    embedding_scores = await self._embedding_scores(obligation_text, regulation_id)

    no_keyword_score = (0, 0, 0.0)  # (hits, total, score)
    no_embedding_score = (0.0, False)  # (score, bonus_applied)

    ranked: list[PatternMatchResult] = []
    for pid in set(keyword_scores.keys()) | set(embedding_scores.keys()):
        pattern = self._by_id.get(pid)
        if not pattern:
            continue

        kw_hits, kw_total, kw_confidence = keyword_scores.get(pid, no_keyword_score)
        emb_confidence, bonus_applied = embedding_scores.get(pid, no_embedding_score)

        has_keyword = kw_confidence > 0
        has_embedding = emb_confidence > 0
        if has_keyword and has_embedding:
            # Agreement of both tiers earns a small boost.
            method = "combined"
            combined = max(kw_confidence, emb_confidence) + 0.05
        elif has_keyword:
            method = "keyword"
            combined = kw_confidence
        else:
            method = "embedding"
            combined = emb_confidence

        composable = [p for p in pattern.composable_with if p in self._by_id]
        ranked.append(PatternMatchResult(
            pattern=pattern,
            pattern_id=pid,
            method=method,
            confidence=min(combined, 1.0),
            keyword_hits=kw_hits,
            total_keywords=kw_total,
            embedding_score=emb_confidence,
            domain_bonus_applied=bonus_applied,
            composable_patterns=composable,
        ))

    # Sort by confidence descending
    return sorted(ranked, key=lambda r: r.confidence, reverse=True)[:n]
# -----------------------------------------------------------------------
# Tier 1: Keyword Match
# -----------------------------------------------------------------------
def _tier1_keyword(
    self, obligation_text: str, regulation_id: Optional[str]
) -> Optional[PatternMatchResult]:
    """Match by counting keyword hits in the obligation text.

    Only patterns with at least ``KEYWORD_MATCH_MIN_HITS`` hits are
    eligible; among those, the one with the highest hit ratio wins. The
    domain bonus is added to the winner's confidence (capped at 1.0) when
    the regulation has affinity with the pattern's domain.

    Returns:
        The keyword match, or ``None`` when no pattern is eligible.
    """
    scores = self._keyword_scores(obligation_text, regulation_id)

    # Apply the minimum-hit rule before ranking. Ranking first lets a
    # pattern with a single, matching keyword (ratio 1.0) win and then be
    # rejected, hiding patterns that do satisfy the rule.
    eligible = {
        pid: score
        for pid, score in scores.items()
        if score[0] >= KEYWORD_MATCH_MIN_HITS
    }
    if not eligible:
        return None

    # Find best match
    best_pid = max(eligible, key=lambda pid: eligible[pid][2])
    hits, total, confidence = eligible[best_pid]

    pattern = self._by_id.get(best_pid)
    if not pattern:
        return None

    # Check domain bonus
    bonus_applied = False
    if regulation_id and self._domain_matches(pattern.domain, regulation_id):
        confidence = min(confidence + DOMAIN_BONUS, 1.0)
        bonus_applied = True

    return PatternMatchResult(
        pattern=pattern,
        pattern_id=best_pid,
        method="keyword",
        confidence=confidence,
        keyword_hits=hits,
        total_keywords=total,
        domain_bonus_applied=bonus_applied,
    )
def _keyword_scores(
self, text: str, regulation_id: Optional[str]
) -> dict[str, tuple[int, int, float]]:
"""Compute keyword match scores for all patterns.
Returns dict: pattern_id → (hits, total_keywords, confidence).
"""
text_lower = text.lower()
hits_by_pattern: dict[str, int] = {}
for keyword, pattern_ids in self._keyword_index.items():
if keyword in text_lower:
for pid in pattern_ids:
hits_by_pattern[pid] = hits_by_pattern.get(pid, 0) + 1
result: dict[str, tuple[int, int, float]] = {}
for pid, hits in hits_by_pattern.items():
pattern = self._by_id.get(pid)
if not pattern:
continue
total = len(pattern.obligation_match_keywords)
confidence = hits / total if total > 0 else 0.0
result[pid] = (hits, total, confidence)
return result
# -----------------------------------------------------------------------
# Tier 2: Embedding Match
# -----------------------------------------------------------------------
async def _tier2_embedding(
    self, obligation_text: str, regulation_id: Optional[str]
) -> Optional[PatternMatchResult]:
    """Pick the pattern whose embedding is most similar to the obligation text.

    Scores come from ``_embedding_scores`` and already include any domain
    bonus. Returns ``None`` when there are no scores, the best score is below
    ``EMBEDDING_PATTERN_THRESHOLD``, or the winning pattern is not loaded.
    """
    scores = await self._embedding_scores(obligation_text, regulation_id)
    if not scores:
        return None

    best_pid, (emb_score, bonus_applied) = max(
        scores.items(), key=lambda item: item[1][0]
    )
    if emb_score < EMBEDDING_PATTERN_THRESHOLD:
        return None

    pattern = self._by_id.get(best_pid)
    if not pattern:
        return None

    # The bonus can lift a score above 1.0; only the confidence is capped.
    return PatternMatchResult(
        pattern=pattern,
        pattern_id=best_pid,
        method="embedding",
        confidence=min(emb_score, 1.0),
        embedding_score=emb_score,
        domain_bonus_applied=bonus_applied,
    )
async def _embedding_scores(
    self, obligation_text: str, regulation_id: Optional[str]
) -> dict[str, tuple[float, bool]]:
    """Score every embedded pattern against the obligation text.

    Only the first 2000 characters of the text are embedded. Patterns
    without an embedding are skipped.

    Returns dict: pattern_id → (score, domain_bonus_applied), where score is
    the cosine similarity plus ``DOMAIN_BONUS`` when the bonus applies.
    Empty when no pattern embeddings exist or the text could not be embedded.
    """
    if not self._pattern_embeddings:
        return {}

    query_vector = await _get_embedding(obligation_text[:2000])
    if not query_vector:
        return {}

    scored: dict[str, tuple[float, bool]] = {}
    for idx, pattern_vector in enumerate(self._pattern_embeddings):
        if not pattern_vector:
            continue
        pid = self._pattern_ids[idx]
        pattern = self._by_id.get(pid)
        if not pattern:
            continue

        similarity = _cosine_sim(query_vector, pattern_vector)

        # Domain bonus
        if regulation_id and self._domain_matches(pattern.domain, regulation_id):
            scored[pid] = (similarity + DOMAIN_BONUS, True)
        else:
            scored[pid] = (similarity, False)
    return scored
# -----------------------------------------------------------------------
# Score combination
# -----------------------------------------------------------------------
def _combine_results(
self,
keyword_result: Optional[PatternMatchResult],
embedding_result: Optional[PatternMatchResult],
) -> PatternMatchResult:
"""Combine keyword and embedding results into the best match."""
if not keyword_result and not embedding_result:
return PatternMatchResult()
if not keyword_result:
return embedding_result
if not embedding_result:
return keyword_result
# Both matched — check if they agree
if keyword_result.pattern_id == embedding_result.pattern_id:
# Same pattern: boost confidence
combined_confidence = min(
max(keyword_result.confidence, embedding_result.confidence) + 0.05,
1.0,
)
return PatternMatchResult(
pattern=keyword_result.pattern,
pattern_id=keyword_result.pattern_id,
method="combined",
confidence=combined_confidence,
keyword_hits=keyword_result.keyword_hits,
total_keywords=keyword_result.total_keywords,
embedding_score=embedding_result.embedding_score,
domain_bonus_applied=(
keyword_result.domain_bonus_applied
or embedding_result.domain_bonus_applied
),
)
# Different patterns: pick the one with higher confidence
if keyword_result.confidence >= embedding_result.confidence:
return keyword_result
return embedding_result
# -----------------------------------------------------------------------
# Domain affinity
# -----------------------------------------------------------------------
@staticmethod
def _domain_matches(pattern_domain: str, regulation_id: str) -> bool:
    """Return True when ``pattern_domain`` is listed as affine to ``regulation_id``.

    Regulations absent from ``_REGULATION_DOMAIN_AFFINITY`` have no affine
    domains, so the result is False for them.
    """
    return pattern_domain in _REGULATION_DOMAIN_AFFINITY.get(regulation_id, [])
# -----------------------------------------------------------------------
# Initialization helpers
# -----------------------------------------------------------------------
def _load_patterns(self) -> None:
    """Load control patterns from the YAML files in the patterns directory.

    Files whose name starts with ``_`` are skipped, as are files without a
    top-level ``patterns`` key. Each pattern is appended to ``self._patterns``
    and indexed in ``self._by_id`` and ``self._by_domain``.

    Loading is best-effort per file: any error while reading or parsing a
    file is logged and the remaining files are still processed. Patterns
    registered from a file before the error occurred are kept.
    """
    patterns_dir = _find_patterns_dir()
    if not patterns_dir:
        logger.warning("Control patterns directory not found")
        return

    for yaml_file in sorted(patterns_dir.glob("*.yaml")):
        if yaml_file.name.startswith("_"):
            continue
        try:
            # Read as UTF-8 explicitly: the patterns carry non-ASCII text
            # (e.g. name_de) and must not depend on the process locale.
            with open(yaml_file, encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if not data or "patterns" not in data:
                continue
            for p in data["patterns"]:
                pattern = ControlPattern(
                    id=p["id"],
                    name=p["name"],
                    name_de=p["name_de"],
                    domain=p["domain"],
                    category=p["category"],
                    description=p["description"],
                    objective_template=p["objective_template"],
                    rationale_template=p["rationale_template"],
                    requirements_template=p.get("requirements_template", []),
                    test_procedure_template=p.get("test_procedure_template", []),
                    evidence_template=p.get("evidence_template", []),
                    severity_default=p.get("severity_default", "medium"),
                    implementation_effort_default=p.get(
                        "implementation_effort_default", "m"
                    ),
                    obligation_match_keywords=p.get("obligation_match_keywords", []),
                    tags=p.get("tags", []),
                    composable_with=p.get("composable_with", []),
                    open_anchor_refs=p.get("open_anchor_refs", []),
                )
                self._patterns.append(pattern)
                self._by_id[pattern.id] = pattern
                self._by_domain.setdefault(pattern.domain, []).append(pattern)
        except Exception as e:
            # Deliberately broad: one malformed file must not abort the load.
            logger.error("Failed to load %s: %s", yaml_file.name, e)

    logger.info("Loaded %d patterns from %s", len(self._patterns), patterns_dir)
def _build_keyword_index(self) -> None:
"""Build reverse index: keyword → [pattern_ids]."""
for pattern in self._patterns:
for kw in pattern.obligation_match_keywords:
lower_kw = kw.lower()
if lower_kw not in self._keyword_index:
self._keyword_index[lower_kw] = []
self._keyword_index[lower_kw].append(pattern.id)
async def _compute_embeddings(self) -> None:
"""Compute embeddings for all pattern objective templates."""
if not self._patterns:
return
self._pattern_ids = [p.id for p in self._patterns]
texts = [
f"{p.name_de}: {p.objective_template}"
for p in self._patterns
]
logger.info("Computing embeddings for %d patterns...", len(texts))
self._pattern_embeddings = await _get_embeddings_batch(texts)
valid = sum(1 for e in self._pattern_embeddings if e)
logger.info("Got %d/%d valid pattern embeddings", valid, len(texts))
# -----------------------------------------------------------------------
# Public helpers
# -----------------------------------------------------------------------
def get_pattern(self, pattern_id: str) -> Optional[ControlPattern]:
    """Look up a pattern by ID; the ID is upper-cased before the lookup.

    Returns ``None`` when no pattern is registered under that ID.
    """
    normalized_id = pattern_id.upper()
    return self._by_id.get(normalized_id)
def get_patterns_by_domain(self, domain: str) -> list[ControlPattern]:
    """Return the patterns registered for ``domain`` (upper-cased for lookup).

    Returns an empty list for an unknown domain.
    NOTE(review): the lookup key is upper-cased, so this presumably relies on
    domains being stored upper-case in ``_by_domain`` — confirm against the
    pattern YAML files.
    """
    domain_key = domain.upper()
    return self._by_domain.get(domain_key, [])
def stats(self) -> dict:
    """Summarize the matcher state: pattern, domain, keyword and embedding counts."""
    valid_embeddings = sum(1 for vector in self._pattern_embeddings if vector)
    summary = {
        "total_patterns": len(self._patterns),
        "domains": list(self._by_domain.keys()),
        "keywords": len(self._keyword_index),
        "embeddings_valid": valid_embeddings,
        "initialized": self._initialized,
    }
    return summary
def _find_patterns_dir() -> Optional[Path]:
"""Locate the control_patterns directory."""
candidates = [
Path(__file__).resolve().parent.parent.parent.parent
/ "ai-compliance-sdk" / "policies" / "control_patterns",
Path("/app/ai-compliance-sdk/policies/control_patterns"),
Path("ai-compliance-sdk/policies/control_patterns"),
]
for p in candidates:
if p.is_dir():
return p
return None

View File

@@ -0,0 +1,670 @@
"""Pipeline Adapter — New 10-Stage Pipeline Integration.
Bridges the existing 7-stage control_generator pipeline with the new
multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer).
New pipeline flow:
chunk → license_classify
→ obligation_extract (Stage 4 — NEW)
→ pattern_match (Stage 5 — NEW)
→ control_compose (Stage 6 — replaces old Stage 3)
→ harmonize → anchor → store + crosswalk → mark processed
Can be used in two modes:
1. INLINE: Called from _process_batch() to enrich the pipeline
2. STANDALONE: Process chunks directly through new stages
Part of the Multi-Layer Control Architecture (Phase 7 of 8).
"""
import hashlib
import json
import logging
from dataclasses import dataclass, field
from typing import Optional
from sqlalchemy import text
from sqlalchemy.orm import Session
from services.control_composer import ComposedControl, ControlComposer
from services.obligation_extractor import ObligationExtractor, ObligationMatch
from services.pattern_matcher import PatternMatcher, PatternMatchResult
logger = logging.getLogger(__name__)
@dataclass
class PipelineChunk:
    """One text chunk handed to the obligation/pattern/compose stages.

    ``chunk_hash`` starts empty and is filled in by :meth:`compute_hash`.
    """

    text: str
    collection: str = ""
    regulation_code: str = ""
    article: Optional[str] = None
    paragraph: Optional[str] = None
    license_rule: int = 3
    license_info: dict = field(default_factory=dict)
    source_citation: Optional[dict] = None
    chunk_hash: str = ""

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of ``text``, caching it in ``chunk_hash``.

        An already-set (non-empty) ``chunk_hash`` is returned as-is and is not
        recomputed.
        """
        if self.chunk_hash:
            return self.chunk_hash
        digest = hashlib.sha256(self.text.encode()).hexdigest()
        self.chunk_hash = digest
        return digest
@dataclass
class PipelineResult:
    """Outcome of running one chunk through the new pipeline stages.

    ``obligation`` and ``pattern_result`` default to empty instances;
    ``control`` stays ``None`` until composition succeeds, and ``error`` holds
    the failure message when processing raised.
    """

    chunk: PipelineChunk
    obligation: ObligationMatch = field(default_factory=ObligationMatch)
    pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult)
    control: Optional[ComposedControl] = None
    crosswalk_written: bool = False
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize the result; falsy sub-objects are emitted as ``None``."""

        def _dump(part):
            return part.to_dict() if part else None

        payload = {"chunk_hash": self.chunk.chunk_hash}
        payload["obligation"] = _dump(self.obligation)
        payload["pattern"] = _dump(self.pattern_result)
        payload["control"] = _dump(self.control)
        payload["crosswalk_written"] = self.crosswalk_written
        payload["error"] = self.error
        return payload
class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()
        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        """Create the adapter.

        Args:
            db: Optional SQLAlchemy session. Without one, ``write_crosswalk``
                is a no-op that returns False.
        """
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the extractor and the matcher (idempotent)."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Process a single chunk through the three new pipeline stages.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Never raises for stage failures: the exception is logged and its
        message is stored in ``result.error``; whatever stages completed
        before the failure remain on the result.
        """
        if not self._initialized:
            await self.initialize()

        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)

        try:
            # Stage 4: Obligation Extract
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )

            # Stage 5: Pattern Match. Fall back from the obligation text to
            # its title, and finally to the first 500 chars of the raw chunk.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )

            # Stage 6: Control Compose. The raw chunk text is only handed to
            # the composer for license rules 1 and 2.
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )
        except Exception as e:
            logger.error("Pipeline processing failed: %s", e)
            result.error = str(e)

        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process chunks one after another and return results in input order."""
        return [await self.process_chunk(chunk) for chunk in chunks]

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extraction + crosswalk_matrix rows for a processed chunk.

        Called AFTER the control is stored in canonical_controls. All three
        statements run in one transaction: they are committed together, and
        rolled back together on any error.

        Args:
            result: Pipeline result whose ``control`` has been stored.
            control_uuid: UUID (as string) of the stored canonical control.

        Returns:
            True when the rows were committed (``result.crosswalk_written`` is
            then set); False when there is no session, no control, or the
            write failed.
        """
        if not self.db or not result.control:
            return False

        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result
        control = result.control

        try:
            # 1. Write obligation_extraction row
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation.confidence if obligation else 0,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "pattern_match_score": pattern.confidence if pattern else 0,
                    "control_uuid": control_uuid,
                },
            )

            # 2. Write crosswalk_matrix row
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "master_control_id": control.control_id,
                    "master_control_uuid": control_uuid,
                    # The link is only as strong as its weakest stage.
                    "confidence": min(
                        obligation.confidence if obligation else 0,
                        pattern.confidence if pattern else 0,
                    ),
                    "source": "auto",
                },
            )

            # 3. Update canonical_controls with pattern_id + obligation_ids
            if control.pattern_id or control.obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        # Bind NULL for missing values so COALESCE keeps what
                        # is stored. json.dumps([]) / json.dumps(None) would
                        # bind '[]' / 'null' and wipe existing obligation
                        # links; an empty pattern_id would do the same.
                        "pattern_id": control.pattern_id or None,
                        "obligation_ids": (
                            json.dumps(control.obligation_ids)
                            if control.obligation_ids
                            else None
                        ),
                        "control_uuid": control_uuid,
                    },
                )

            self.db.commit()
            result.crosswalk_written = True
            return True
        except Exception as e:
            logger.error("Failed to write crosswalk: %s", e)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }
# ---------------------------------------------------------------------------
# Migration Passes — Backfill existing 4,800+ controls
# ---------------------------------------------------------------------------
class MigrationPasses:
    """Non-destructive migration passes for existing controls.

    Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
    Pass 2: Pattern Classification (keyword-based matching)
    Pass 3: Quality Triage (categorize by linkage completeness)
    Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
    Pass 5: Deduplication (mark duplicate controls)

    Usage::

        migration = MigrationPasses(db)
        await migration.initialize()
        result = await migration.run_pass1_obligation_linkage(limit=100)
        result = await migration.run_pass2_pattern_classification(limit=100)
        result = migration.run_pass3_quality_triage()
        result = migration.run_pass4_crosswalk_backfill()
        result = migration.run_pass5_deduplication()
    """

    def __init__(self, db: Session):
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._initialized = False

    async def initialize(self) -> None:
        """Load obligations and patterns and build the keyword index (idempotent).

        Only the three loader calls below are made; the components' own
        ``initialize()`` methods are not invoked here.
        """
        if self._initialized:
            return
        self._extractor._load_obligations()
        self._matcher._load_patterns()
        self._matcher._build_keyword_index()
        self._initialized = True

    def _fetch_rows(self, query: str, limit: int) -> list:
        """Run ``query`` and return all rows, capped by a bound LIMIT if ``limit`` > 0."""
        params: dict = {}
        if limit > 0:
            # Bound parameter instead of string interpolation into the SQL.
            query += " LIMIT :limit"
            params["limit"] = int(limit)
        return self.db.execute(text(query), params).fetchall()

    # -------------------------------------------------------------------
    # Pass 1: Obligation Linkage (deterministic)
    # -------------------------------------------------------------------

    async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
        """Link existing controls to obligations via source_citation article.

        For each non-deprecated control without obligation_ids: extract
        regulation + article from citation/metadata, look the pair up with the
        extractor's exact-match tier, and store the matched obligation ID.

        Args:
            limit: Maximum number of controls to process; 0 means no limit.

        Returns:
            Counters: ``total``, ``linked``, ``no_match``, ``no_citation``.
        """
        if not self._initialized:
            await self.initialize()

        rows = self._fetch_rows(
            """
            SELECT id, control_id, source_citation, generation_metadata
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
            """,
            limit,
        )

        stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}

        for row in rows:
            control_uuid = str(row[0])
            citation = row[2]
            metadata = row[3]

            # Extract regulation + article from citation or metadata
            reg_code, article = _extract_regulation_article(citation, metadata)
            if not reg_code:
                stats["no_citation"] += 1
                continue

            # Tier 1: Exact match
            match = self._extractor._tier1_exact(reg_code, article or "")
            if match and match.obligation_id:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET obligation_ids = :obl_ids
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "obl_ids": json.dumps([match.obligation_id]),
                        "uuid": control_uuid,
                    },
                )
                stats["linked"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 1: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 2: Pattern Classification (keyword-based)
    # -------------------------------------------------------------------

    async def run_pass2_pattern_classification(self, limit: int = 0) -> dict:
        """Classify existing controls into patterns via keyword matching.

        For each non-deprecated control without pattern_id: keyword-match
        "title objective" against the pattern library and assign the match
        when it has at least two keyword hits.

        Args:
            limit: Maximum number of controls to process; 0 means no limit.

        Returns:
            Counters: ``total``, ``classified``, ``no_match``.
        """
        if not self._initialized:
            await self.initialize()

        rows = self._fetch_rows(
            """
            SELECT id, control_id, title, objective
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (pattern_id IS NULL OR pattern_id = '')
            """,
            limit,
        )

        stats = {"total": len(rows), "classified": 0, "no_match": 0}

        for row in rows:
            control_uuid = str(row[0])
            title = row[2] or ""
            objective = row[3] or ""

            # Keyword match
            match_text = f"{title} {objective}"
            result = self._matcher._tier1_keyword(match_text, None)

            # A single keyword hit is not accepted as a classification.
            if result and result.pattern_id and result.keyword_hits >= 2:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = :pattern_id
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.pattern_id,
                        "uuid": control_uuid,
                    },
                )
                stats["classified"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 2: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 3: Quality Triage
    # -------------------------------------------------------------------

    def run_pass3_quality_triage(self) -> dict:
        """Categorize controls by linkage completeness.

        Sets generation_metadata.triage_status on non-deprecated controls:
        - "review": has both obligation_id + pattern_id
        - "needs_obligation": has pattern_id but no obligation_id
        - "needs_pattern": has obligation_id but no pattern_id
        - "legacy_unlinked": has neither

        Returns:
            Mapping of triage status to the number of rows updated.
        """
        has_obligation = "obligation_ids IS NOT NULL AND obligation_ids != '[]'"
        no_obligation = "(obligation_ids IS NULL OR obligation_ids = '[]')"
        has_pattern = "pattern_id IS NOT NULL AND pattern_id != ''"
        no_pattern = "(pattern_id IS NULL OR pattern_id = '')"

        # status -> (obligation condition, pattern condition)
        categories = {
            "review": (has_obligation, has_pattern),
            "needs_obligation": (no_obligation, has_pattern),
            "needs_pattern": (has_obligation, no_pattern),
            "legacy_unlinked": (no_obligation, no_pattern),
        }

        stats = {}
        for status, (obligation_cond, pattern_cond) in categories.items():
            # Only the fixed constants above are interpolated (no user input).
            sql = f"""
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                        COALESCE(generation_metadata::jsonb, '{{}}'::jsonb),
                        '{{triage_status}}', '"{status}"'
                    )
                WHERE release_state NOT IN ('deprecated')
                  AND {obligation_cond}
                  AND {pattern_cond}
            """
            result = self.db.execute(text(sql))
            stats[status] = result.rowcount

        self.db.commit()
        logger.info("Pass 3: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 4: Crosswalk Backfill
    # -------------------------------------------------------------------

    def run_pass4_crosswalk_backfill(self) -> dict:
        """Create crosswalk_matrix rows for controls with obligation + pattern.

        One row is inserted per (control, obligation_id) pair that does not
        already exist in crosswalk_matrix, with confidence 0.80 and source
        'migrated'.

        Returns:
            ``{"rows_inserted": <count>}``.
        """
        result = self.db.execute(text("""
            INSERT INTO crosswalk_matrix (
                regulation_code, obligation_id, pattern_id,
                master_control_id, master_control_uuid,
                confidence, source
            )
            SELECT
                COALESCE(
                    (generation_metadata::jsonb->>'source_regulation'),
                    ''
                ) AS regulation_code,
                obl.value::text AS obligation_id,
                cc.pattern_id,
                cc.control_id,
                cc.id,
                0.80,
                'migrated'
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
              AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
              AND NOT EXISTS (
                  SELECT 1 FROM crosswalk_matrix cw
                  WHERE cw.master_control_uuid = cc.id
                    AND cw.obligation_id = obl.value::text
              )
        """))
        rows_inserted = result.rowcount
        self.db.commit()
        logger.info("Pass 4: %d crosswalk rows inserted", rows_inserted)
        return {"rows_inserted": rows_inserted}

    # -------------------------------------------------------------------
    # Pass 5: Deduplication
    # -------------------------------------------------------------------

    def run_pass5_deduplication(self) -> dict:
        """Mark duplicate controls (same obligation + same pattern).

        Groups controls by (obligation_id, pattern_id), keeps the one with
        highest evidence_confidence (newest first on ties), marks the rest as
        deprecated.

        Returns:
            ``groups_found`` and ``controls_deprecated``. The latter counts
            rows actually updated, so a control that appears in several
            duplicate groups is counted once.
        """
        # Find groups with duplicates
        groups = self.db.execute(text("""
            SELECT cc.pattern_id,
                   obl.value::text AS obligation_id,
                   array_agg(cc.id ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC) AS ids,
                   count(*) AS cnt
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
            GROUP BY cc.pattern_id, obl.value::text
            HAVING count(*) > 1
        """)).fetchall()

        stats = {"groups_found": len(groups), "controls_deprecated": 0}

        for group in groups:
            ids = group[2]  # Array of UUIDs, first is the keeper
            if len(ids) <= 1:
                continue

            # Keep first (highest confidence), deprecate rest
            for dep_id in ids[1:]:
                result = self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET release_state = 'deprecated',
                            generation_metadata = jsonb_set(
                                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
                            )
                        WHERE id = CAST(:uuid AS uuid)
                          AND release_state != 'deprecated'
                    """),
                    {"uuid": str(dep_id)},
                )
                # The UPDATE skips rows an earlier group already deprecated,
                # so count affected rows rather than statements issued.
                stats["controls_deprecated"] += max(result.rowcount, 0)

        self.db.commit()
        logger.info("Pass 5: %s", stats)
        return stats

    def migration_status(self) -> dict:
        """Return overall migration progress.

        Counts cover every row in canonical_controls (deprecated included).
        Percentages are relative to the total and rounded to one decimal; an
        empty table yields 0.0 rather than a division error.
        """
        row = self.db.execute(text("""
            SELECT
                count(*) AS total,
                count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
                count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
                count(*) FILTER (
                    WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                      AND pattern_id IS NOT NULL AND pattern_id != ''
                ) AS fully_linked,
                count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
            FROM canonical_controls
        """)).fetchone()

        total, has_obligation, has_pattern, fully_linked, deprecated = (
            row[0], row[1], row[2], row[3], row[4]
        )
        denominator = max(total, 1)

        return {
            "total_controls": total,
            "has_obligation": has_obligation,
            "has_pattern": has_pattern,
            "fully_linked": fully_linked,
            "deprecated": deprecated,
            "coverage_obligation_pct": round(has_obligation / denominator * 100, 1),
            "coverage_pattern_pct": round(has_pattern / denominator * 100, 1),
            "coverage_full_pct": round(fully_linked / denominator * 100, 1),
        }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_regulation_article(
    citation: Optional[str], metadata: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Derive ``(regulation_code, article)`` for a control.

    The citation is consulted first: its ``article`` (or ``source_article``)
    and its normalized ``source``. The metadata is only consulted when the
    citation yielded no regulation; it can supply ``source_regulation`` and,
    if still missing, ``source_article``. Both inputs may be JSON strings or
    already-decoded objects; undecodable input is ignored.
    """
    from services.obligation_extractor import _normalize_regulation

    regulation: Optional[str] = None
    art: Optional[str] = None

    # Citation takes precedence (JSON string or dict).
    if citation:
        try:
            parsed = citation if not isinstance(citation, str) else json.loads(citation)
            if isinstance(parsed, dict):
                art = parsed.get("article") or parsed.get("source_article")
                origin = parsed.get("source", "")
                if origin:
                    regulation = _normalize_regulation(origin)
        except (json.JSONDecodeError, TypeError):
            pass

    # Metadata is the fallback for the regulation only.
    if metadata and not regulation:
        try:
            meta = metadata if not isinstance(metadata, str) else json.loads(metadata)
            if isinstance(meta, dict):
                meta_regulation = meta.get("source_regulation", "")
                if meta_regulation:
                    regulation = _normalize_regulation(meta_regulation)
                if not art:
                    art = meta.get("source_article")
        except (json.JSONDecodeError, TypeError):
            pass

    return regulation, art

View File

@@ -0,0 +1,213 @@
"""
Compliance RAG Client — Proxy to Go SDK RAG Search.
Lightweight HTTP client that queries the Go AI Compliance SDK's
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
models or direct Qdrant access in Python.
Error-tolerant: RAG failures never break the calling function.
"""
import logging
import os
from dataclasses import dataclass
from typing import List, Optional
import httpx
logger = logging.getLogger(__name__)
SDK_URL = os.getenv("SDK_URL", "http://ai-compliance-sdk:8090")
RAG_SEARCH_TIMEOUT = 15.0 # seconds
@dataclass
class RAGSearchResult:
    """A single search result from the compliance corpus."""

    # Fields mirror the same-named keys of one entry in the SDK response;
    # the client substitutes "" for any key that is missing.
    text: str
    regulation_code: str
    regulation_name: str
    regulation_short: str
    category: str
    article: str
    paragraph: str
    source_url: str
    # Relevance score from the search response; scroll results carry 0.0.
    score: float
    # Name of the collection that was queried (set by the client).
    collection: str = ""
class ComplianceRAGClient:
    """
    RAG client that proxies search requests to the Go SDK.

    Usage:
        client = get_rag_client()
        results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
        context_str = client.format_for_prompt(results)
    """

    def __init__(self, base_url: str = SDK_URL):
        """Build the endpoint URLs from ``base_url`` (no trailing slash expected)."""
        self._search_url = f"{base_url}/sdk/v1/rag/search"
        # Built directly rather than by rewriting the search URL, so a
        # "/search" inside base_url cannot be altered by accident.
        self._scroll_url = f"{base_url}/sdk/v1/rag/scroll"

    @staticmethod
    def _to_result(row: dict, collection: str, score: float) -> RAGSearchResult:
        """Map one response entry to a ``RAGSearchResult``; missing keys become ""."""
        return RAGSearchResult(
            text=row.get("text", ""),
            regulation_code=row.get("regulation_code", ""),
            regulation_name=row.get("regulation_name", ""),
            regulation_short=row.get("regulation_short", ""),
            category=row.get("category", ""),
            article=row.get("article", ""),
            paragraph=row.get("paragraph", ""),
            source_url=row.get("source_url", ""),
            score=score,
            collection=collection,
        )

    async def search(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search the RAG corpus via Go SDK.

        Args:
            query: Search text.
            collection: Collection to search.
            regulations: Optional regulation filter; omitted from the request
                when empty.
            top_k: Number of results requested.

        Returns an empty list on any error (never raises).
        """
        payload = {
            "query": query,
            "collection": collection,
            "top_k": top_k,
        }
        if regulations:
            payload["regulations"] = regulations

        try:
            async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
                resp = await client.post(self._search_url, json=payload)

            if resp.status_code != 200:
                logger.warning(
                    "RAG search returned %d: %s", resp.status_code, resp.text[:200]
                )
                return []

            data = resp.json()
            return [
                self._to_result(row, collection, row.get("score", 0.0))
                for row in data.get("results", [])
            ]
        except Exception as e:
            # Deliberate best-effort: RAG failures must not break the caller.
            logger.warning("RAG search failed: %s", e)
            return []

    async def search_with_rerank(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search with optional cross-encoder re-ranking.

        Fetches max(top_k*4, 20) candidates from RAG, then re-ranks with the
        cross-encoder and returns top_k. Falls back to regular search if the
        reranker is disabled, and to the unranked top_k if re-ranking fails.
        """
        from .reranker import get_reranker

        reranker = get_reranker()
        if reranker is None:
            return await self.search(query, collection, regulations, top_k)

        # Fetch more candidates for re-ranking
        candidates = await self.search(
            query, collection, regulations, top_k=max(top_k * 4, 20)
        )
        if not candidates:
            return []

        texts = [c.text for c in candidates]
        try:
            ranked_indices = reranker.rerank(query, texts, top_k=top_k)
            return [candidates[i] for i in ranked_indices]
        except Exception as e:
            logger.warning("Reranking failed, returning unranked: %s", e)
            return candidates[:top_k]

    async def scroll(
        self,
        collection: str,
        offset: Optional[str] = None,
        limit: int = 100,
    ) -> tuple[List[RAGSearchResult], Optional[str]]:
        """
        Scroll through ALL chunks in a collection (paginated).

        Returns (chunks, next_offset). next_offset is None when done.
        On any error the result is ``([], None)``, which callers cannot
        distinguish from the end of the collection.
        """
        params = {"collection": collection, "limit": str(limit)}
        if offset:
            params["offset"] = offset

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.get(self._scroll_url, params=params)

            if resp.status_code != 200:
                logger.warning(
                    "RAG scroll returned %d: %s", resp.status_code, resp.text[:200]
                )
                return [], None

            data = resp.json()
            # Scroll entries carry no relevance score.
            results = [
                self._to_result(row, collection, 0.0)
                for row in data.get("chunks", [])
            ]
            # Normalize an empty/falsy offset to None ("no more pages").
            next_offset = data.get("next_offset") or None
            return results, next_offset
        except Exception as e:
            logger.warning("RAG scroll failed: %s", e)
            return [], None

    def format_for_prompt(
        self, results: List[RAGSearchResult], max_results: int = 5
    ) -> str:
        """Format search results as Markdown for inclusion in an LLM prompt.

        At most ``max_results`` results are rendered and each excerpt is cut
        to 400 characters (with a trailing "..."). Returns "" for no results.
        """
        if not results:
            return ""

        lines = ["## Relevanter Rechtskontext\n"]
        for i, r in enumerate(results[:max_results]):
            header = f"{i + 1}. **{r.regulation_short}** ({r.regulation_code})"
            if r.article:
                # Separate the article from the closing parenthesis.
                header += f" {r.article}"
            lines.append(header)
            text = r.text[:400] + "..." if len(r.text) > 400 else r.text
            lines.append(f" > {text}\n")
        return "\n".join(lines)
# Singleton
_rag_client: Optional[ComplianceRAGClient] = None
def get_rag_client() -> ComplianceRAGClient:
    """Return the module-wide RAG client, creating it on first call."""
    global _rag_client
    client = _rag_client
    if client is None:
        client = _rag_client = ComplianceRAGClient()
    return client

View File

@@ -0,0 +1,85 @@
"""
Cross-Encoder Re-Ranking for RAG Search Results.
Uses BGE Reranker v2 (BAAI/bge-reranker-v2-m3, MIT license) to re-rank
search results from Qdrant for improved retrieval quality.
Lazy-loads the model on first use. Disabled by default (RERANK_ENABLED=false).
"""
import logging
import os
from typing import Optional
logger = logging.getLogger(__name__)
RERANK_ENABLED = os.getenv("RERANK_ENABLED", "false").lower() == "true"
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
class Reranker:
    """Cross-encoder reranker backed by sentence-transformers.

    The model is not loaded at construction time; it is loaded by the first
    ``rerank`` call that has texts to rank.
    """

    def __init__(self, model_name: str = RERANK_MODEL):
        self._model_name = model_name
        self._model = None  # populated by _ensure_model()

    def _ensure_model(self) -> None:
        """Load the cross-encoder once; later calls are no-ops.

        Raises:
            ImportError: sentence-transformers is not installed.
            Exception: any other failure while loading the model (re-raised
                after logging).
        """
        if self._model is not None:
            return
        try:
            from sentence_transformers import CrossEncoder

            logger.info("Loading reranker model: %s", self._model_name)
            self._model = CrossEncoder(self._model_name)
            logger.info("Reranker model loaded successfully")
        except ImportError:
            logger.error(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
            raise
        except Exception as e:
            logger.error("Failed to load reranker model: %s", e)
            raise

    def rerank(
        self, query: str, texts: list[str], top_k: int = 5
    ) -> list[int]:
        """
        Rank ``texts`` by relevance to ``query`` and return the best indices.

        Args:
            query: The search query.
            texts: Candidate texts to re-rank.
            top_k: Number of top results to return.

        Returns:
            Up to ``top_k`` indices into ``texts``, most relevant first.
            Empty when ``texts`` is empty (the model is then not loaded).
        """
        if not texts:
            return []

        self._ensure_model()
        query_text_pairs = [[query, candidate] for candidate in texts]
        scores = self._model.predict(query_text_pairs)

        # Stable descending sort on the score; keep only the positions.
        by_relevance = sorted(
            range(len(scores)), key=lambda position: scores[position], reverse=True
        )
        return by_relevance[:top_k]
# Module-level singleton
_reranker: Optional[Reranker] = None
def get_reranker() -> Optional[Reranker]:
    """Return the shared reranker, or ``None`` when re-ranking is disabled.

    The instance is created on the first enabled call and reused afterwards.
    """
    global _reranker
    if RERANK_ENABLED:
        if _reranker is None:
            _reranker = Reranker()
        return _reranker
    return None

View File

@@ -0,0 +1,223 @@
"""
Too-Close Similarity Detector — checks whether a candidate text is too similar
to a protected source text (copyright / license compliance).
Five metrics:
1. Exact-phrase — longest identical token sequence
2. Token overlap — Jaccard similarity of token sets
3. 3-gram Jaccard — Jaccard similarity of character 3-grams
4. Embedding cosine — via bge-m3 (Ollama or embedding-service)
5. LCS ratio — Longest Common Subsequence / max(len_a, len_b)
Decision:
PASS — no fail + max 1 warn
WARN — max 2 warn, no fail → human review
FAIL — any fail threshold → block, rewrite required
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Thresholds
# ---------------------------------------------------------------------------
# Per-metric limits. A value at or above "fail" is classified FAIL, at or
# above "warn" is classified WARN, anything lower is PASS.
THRESHOLDS = {
    # Longest identical token run (token count).
    "max_exact_run": {"warn": 8, "fail": 12},
    # Jaccard similarity of token sets.
    "token_overlap": {"warn": 0.20, "fail": 0.30},
    # Jaccard similarity of character n-grams.
    "ngram_jaccard": {"warn": 0.10, "fail": 0.18},
    # Cosine similarity of the two text embeddings.
    "embedding_cosine": {"warn": 0.86, "fail": 0.92},
    # Longest common subsequence length / longer token count.
    "lcs_ratio": {"warn": 0.35, "fail": 0.50},
}
# ---------------------------------------------------------------------------
# Tokenisation helpers
# ---------------------------------------------------------------------------
# One token = one maximal run of word characters.
_WORD_RE = re.compile(r"\w+", re.UNICODE)


def _tokenize(text: str) -> list[str]:
    """Split ``text`` into lower-cased word tokens, dropping everything else."""
    words = _WORD_RE.findall(text)
    # Lower-case per token (after matching), not the whole text up front.
    return list(map(str.lower, words))
def _char_ngrams(text: str, n: int = 3) -> set[str]:
    """Return the set of distinct character ``n``-grams of lower-cased ``text``.

    Returns an empty set when the lower-cased text is shorter than ``n``.
    """
    lowered = text.lower()
    if len(lowered) < n:
        return set()
    last_start = len(lowered) - n
    return {lowered[start : start + n] for start in range(last_start + 1)}
# ---------------------------------------------------------------------------
# Metric implementations
# ---------------------------------------------------------------------------
def max_exact_run(tokens_a: list[str], tokens_b: list[str]) -> int:
    """Length of the longest contiguous token sequence shared by a and b.

    Args:
        tokens_a: Tokens of the first text.
        tokens_b: Tokens of the second text.

    Returns:
        The number of tokens in the longest run that appears verbatim in both
        lists; 0 when either list is empty or nothing matches.
    """
    if not tokens_a or not tokens_b:
        return 0

    # Where each token occurs in b, so only real matches are visited.
    positions: dict[str, list[int]] = {}
    for idx_b, token in enumerate(tokens_b):
        positions.setdefault(token, []).append(idx_b)

    best = 0
    # prev_runs[j] = length of the common run ending at the previous token of
    # a and at tokens_b[j]. Extending runs incrementally avoids re-walking
    # the same run from every matching start pair.
    prev_runs: dict[int, int] = {}
    for token in tokens_a:
        curr_runs: dict[int, int] = {}
        for idx_b in positions.get(token, ()):
            run = prev_runs.get(idx_b - 1, 0) + 1
            curr_runs[idx_b] = run
            if run > best:
                best = run
        prev_runs = curr_runs
    return best
def token_overlap_jaccard(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Jaccard index of the two token sets: |A ∩ B| / |A ∪ B|.

    Returns 0.0 when both inputs are empty.
    """
    unique_a = set(tokens_a)
    unique_b = set(tokens_b)
    combined = unique_a | unique_b
    if not combined:
        return 0.0
    shared = unique_a & unique_b
    return len(shared) / len(combined)
def ngram_jaccard(text_a: str, text_b: str, n: int = 3) -> float:
    """Jaccard index of the character ``n``-gram sets of the two texts.

    Returns 0.0 when neither text yields any n-gram.
    """
    grams_first = _char_ngrams(text_a, n)
    grams_second = _char_ngrams(text_b, n)
    all_grams = grams_first | grams_second
    if not all_grams:
        return 0.0
    common = grams_first & grams_second
    return len(common) / len(all_grams)
def lcs_ratio(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Longest-common-subsequence length divided by the longer token count.

    Returns 0.0 when either token list is empty.
    """
    len_a = len(tokens_a)
    len_b = len(tokens_b)
    if not len_a or not len_b:
        return 0.0

    # Only the previous DP row is needed to compute the next one.
    above = [0] * (len_b + 1)
    for token_a in tokens_a:
        row = [0]
        for col, token_b in enumerate(tokens_b, start=1):
            if token_a == token_b:
                row.append(above[col - 1] + 1)
            else:
                row.append(max(above[col], row[col - 1]))
        above = row
    return above[len_b] / max(len_a, len_b)
async def embedding_cosine(text_a: str, text_b: str, embedding_url: str | None = None) -> float:
    """Cosine similarity of two texts via the embedding service (bge-m3).

    Posts both texts to ``{embedding_url}/embed`` and compares the first two
    vectors of the response's ``"embeddings"`` list.

    Args:
        text_a: First text.
        text_b: Second text.
        embedding_url: Base URL of the embedding service; defaults to
            ``http://embedding-service:8087`` when falsy.

    Returns:
        The cosine similarity. Best-effort: returns 0.0 (and logs a warning
        with the cause) when the request fails, the response is malformed,
        or fewer than two embeddings are returned.
    """
    url = embedding_url or "http://embedding-service:8087"
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{url}/embed",
                json={"texts": [text_a, text_b]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            if len(embeddings) < 2:
                logger.warning(
                    "Embedding service returned %d embeddings (expected 2), "
                    "skipping cosine check",
                    len(embeddings),
                )
                return 0.0
            return _cosine(embeddings[0], embeddings[1])
    except Exception as exc:
        # Deliberate best-effort fallback. The cause is logged because the
        # failure is not necessarily a connectivity problem (HTTP error,
        # invalid JSON, unexpected payload shape, ...).
        logger.warning(
            "Embedding service unreachable or returned an invalid response, "
            "skipping cosine check: %s",
            exc,
        )
        return 0.0
def _cosine(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between vectors ``a`` and ``b``.

    Returns 0.0 when either vector has zero length (norm of 0).
    """
    dot_product = sum(left * right for left, right in zip(a, b))
    length_a = sum(component * component for component in a) ** 0.5
    length_b = sum(component * component for component in b) ** 0.5
    if length_a == 0 or length_b == 0:
        return 0.0
    return dot_product / (length_a * length_b)
# ---------------------------------------------------------------------------
# Decision engine
# ---------------------------------------------------------------------------
@dataclass
class SimilarityReport:
    """Aggregate result of one source/candidate similarity check."""

    # Raw metric values.
    max_exact_run: int
    token_overlap: float
    ngram_jaccard: float
    embedding_cosine: float
    lcs_ratio: float
    # Overall verdict: "PASS", "WARN" or "FAIL".
    status: str
    # Verdict per metric, keyed by metric name.
    details: dict
def _classify(value: float | int, metric: str) -> str:
    """Map a metric value to "FAIL", "WARN" or "PASS" using ``THRESHOLDS``.

    Limits are inclusive: a value equal to a limit already triggers it. The
    stricter "fail" limit is checked first.
    """
    limits = THRESHOLDS[metric]
    for limit_key, verdict in (("fail", "FAIL"), ("warn", "WARN")):
        if value >= limits[limit_key]:
            return verdict
    return "PASS"
async def check_similarity(
    source_text: str,
    candidate_text: str,
    embedding_url: str | None = None,
) -> SimilarityReport:
    """Compute all five similarity metrics and derive an overall verdict.

    Verdict rules: any metric classified FAIL, or more than two classified
    WARN, gives "FAIL"; exactly two WARN gives "WARN"; otherwise "PASS".

    Args:
        source_text: The protected reference text.
        candidate_text: The text being checked against it.
        embedding_url: Optional base URL passed through to
            ``embedding_cosine``.

    Returns:
        A ``SimilarityReport`` with the metric values (floats rounded to
        four decimals), the overall status and the per-metric verdicts.
    """
    src_tokens = _tokenize(source_text)
    cand_tokens = _tokenize(candidate_text)

    # Keys double as THRESHOLDS keys and SimilarityReport field names.
    measured = {
        "max_exact_run": max_exact_run(src_tokens, cand_tokens),
        "token_overlap": token_overlap_jaccard(src_tokens, cand_tokens),
        "ngram_jaccard": ngram_jaccard(source_text, candidate_text),
        "embedding_cosine": await embedding_cosine(
            source_text, candidate_text, embedding_url
        ),
        "lcs_ratio": lcs_ratio(src_tokens, cand_tokens),
    }
    details = {
        metric: _classify(value, metric) for metric, value in measured.items()
    }

    verdicts = list(details.values())
    failures = verdicts.count("FAIL")
    warnings = verdicts.count("WARN")

    if failures > 0 or warnings > 2:
        status = "FAIL"
    elif warnings > 1:
        status = "WARN"
    else:
        status = "PASS"

    return SimilarityReport(
        max_exact_run=measured["max_exact_run"],
        token_overlap=round(measured["token_overlap"], 4),
        ngram_jaccard=round(measured["ngram_jaccard"], 4),
        embedding_cosine=round(measured["embedding_cosine"], 4),
        lcs_ratio=round(measured["lcs_ratio"], 4),
        status=status,
        details=details,
    )

View File

@@ -0,0 +1,331 @@
"""V1 Control Enrichment Service — Match Eigenentwicklung controls to regulations.
Finds regulatory coverage for v1 controls (generation_strategy='ungrouped',
pipeline_version=1, no source_citation) by embedding similarity search.
Reuses embedding + Qdrant helpers from control_dedup.py.
"""
import logging
from typing import Optional
from sqlalchemy import text
from db.session import SessionLocal
from services.control_dedup import (
get_embedding,
qdrant_search_cross_regulation,
)
logger = logging.getLogger(__name__)
# Similarity threshold — lower than dedup (0.85) since we want informational matches
# Typical top scores for v1 controls are 0.70-0.77
V1_MATCH_THRESHOLD = 0.70
V1_MAX_MATCHES = 5
def _is_eigenentwicklung_query() -> str:
    """Return the SQL WHERE-clause body that selects v1 Eigenentwicklung controls.

    The fragment uses unqualified ``canonical_controls`` column names and is
    meant to be interpolated directly after ``WHERE``.
    """
    conditions = (
        "generation_strategy = 'ungrouped'",
        "(pipeline_version = '1' OR pipeline_version IS NULL)",
        "source_citation IS NULL",
        "parent_control_uuid IS NULL",
        "release_state NOT IN ('rejected', 'merged', 'deprecated')",
    )
    return "\n" + "\nAND ".join(conditions) + "\n"
async def count_v1_controls() -> int:
    """Return the number of v1 Eigenentwicklung controls in canonical_controls."""
    count_sql = text(f"""
        SELECT COUNT(*) AS cnt
        FROM canonical_controls
        WHERE {_is_eigenentwicklung_query()}
    """)
    with SessionLocal() as db:
        row = db.execute(count_sql).fetchone()
    if not row:
        return 0
    return row.cnt
def _build_v1_search_text(title: Optional[str], objective: Optional[str]) -> str:
    """Join the non-empty title/objective parts with a single space."""
    parts = (str(part).strip() for part in (title, objective) if part)
    return " ".join(part for part in parts if part)


def _resolve_regulatory_row(db, matched_uuid: str):
    """Return the control row that carries a source_citation for a hit.

    Looks up the matched control; if it has no source_citation but has a
    parent, the parent is used instead. Returns None when neither provides a
    (truthy) source_citation or the matched control does not exist.
    """
    matched_row = db.execute(text("""
        SELECT c.id, c.control_id, c.title, c.source_citation,
               c.severity, c.category, c.parent_control_uuid
        FROM canonical_controls c
        WHERE c.id = CAST(:uuid AS uuid)
    """), {"uuid": matched_uuid}).fetchone()
    if not matched_row:
        return None

    reg_row = matched_row
    if not reg_row.source_citation and reg_row.parent_control_uuid:
        # Atomic controls usually inherit their citation from the parent.
        parent_row = db.execute(text("""
            SELECT id, control_id, title, source_citation,
                   severity, category, parent_control_uuid
            FROM canonical_controls
            WHERE id = CAST(:uuid AS uuid)
              AND source_citation IS NOT NULL
        """), {"uuid": str(reg_row.parent_control_uuid)}).fetchone()
        if parent_row:
            reg_row = parent_row

    return reg_row if reg_row.source_citation else None


async def enrich_v1_matches(
    dry_run: bool = True,
    batch_size: int = 100,
    offset: int = 0,
) -> dict:
    """Find regulatory matches for v1 Eigenentwicklung controls.

    For each control of the requested page, the title/objective text is
    embedded, similar atomic controls are searched in Qdrant, each hit is
    resolved to a control with a source_citation, and up to
    ``V1_MAX_MATCHES`` distinct matches are upserted into
    ``v1_control_matches``.

    Each control is written inside its own SAVEPOINT: if anything fails for
    one control, only that control's rows are rolled back and the rest of the
    batch is still committed. Counters and samples only include matches that
    were actually persisted.

    Args:
        dry_run: If True, only report the page of controls — write nothing.
        batch_size: Number of v1 controls to process per call.
        offset: Pagination offset (v1 control index, ordered by control_id).

    Returns:
        Stats dict with counts, sample matches, the first 10 errors, and
        pagination info (``next_offset`` is None after the last page).
    """
    sample_limit = 20

    with SessionLocal() as db:
        # 1. Load v1 controls (paginated)
        v1_controls = db.execute(text(f"""
            SELECT id, control_id, title, objective, category
            FROM canonical_controls
            WHERE {_is_eigenentwicklung_query()}
            ORDER BY control_id
            LIMIT :limit OFFSET :offset
        """), {"limit": batch_size, "offset": offset}).fetchall()

        # Count total for pagination
        total_row = db.execute(text(f"""
            SELECT COUNT(*) AS cnt
            FROM canonical_controls
            WHERE {_is_eigenentwicklung_query()}
        """)).fetchone()
        total_v1 = total_row.cnt if total_row else 0

        if not v1_controls:
            return {
                "dry_run": dry_run,
                "processed": 0,
                "total_v1": total_v1,
                "message": "Kein weiterer Batch — alle v1 Controls verarbeitet.",
            }

        if dry_run:
            return {
                "dry_run": True,
                "total_v1": total_v1,
                "offset": offset,
                "batch_size": batch_size,
                "sample_controls": [
                    {
                        "control_id": r.control_id,
                        "title": r.title,
                        "category": r.category,
                    }
                    for r in v1_controls[:sample_limit]
                ],
            }

        # 2. Process each v1 control
        processed = 0
        matches_inserted = 0
        errors = []
        sample_matches = []

        for v1 in v1_controls:
            try:
                # A NULL objective must not end up as the literal "None" in
                # the embedded text, and the two parts need a separator.
                search_text = _build_v1_search_text(v1.title, v1.objective)

                embedding = await get_embedding(search_text)
                if not embedding:
                    errors.append({
                        "control_id": v1.control_id,
                        "error": "Embedding fehlgeschlagen",
                    })
                    continue

                # Search Qdrant (cross-regulation, no pattern filter) in the
                # atomic_controls_dedup collection.
                results = await qdrant_search_cross_regulation(
                    embedding, top_k=20,
                    collection="atomic_controls_dedup",
                )

                control_matches = []
                # SAVEPOINT per control: a failing statement would otherwise
                # leave the whole transaction aborted and silently lose every
                # match of the batch at commit time.
                with db.begin_nested():
                    rank = 0
                    # Deduplicate by resolved control so the same regulation
                    # is not listed several times.
                    seen_parents: set[str] = set()

                    for hit in results:
                        score = hit.get("score", 0)
                        if score < V1_MATCH_THRESHOLD:
                            continue

                        payload = hit.get("payload") or {}
                        matched_uuid = payload.get("control_uuid")
                        if not matched_uuid or matched_uuid == str(v1.id):
                            continue

                        reg_row = _resolve_regulatory_row(db, matched_uuid)
                        if reg_row is None:
                            continue

                        parent_key = str(reg_row.id)
                        if parent_key in seen_parents:
                            continue
                        seen_parents.add(parent_key)

                        rank += 1
                        if rank > V1_MAX_MATCHES:
                            break

                        # Extract source info (only when the citation is a dict)
                        source_citation = reg_row.source_citation or {}
                        is_mapping = isinstance(source_citation, dict)
                        matched_source = source_citation.get("source") if is_mapping else None
                        matched_article = source_citation.get("article") if is_mapping else None

                        # Link to the resolved regulatory control, not the
                        # atomic child that Qdrant returned.
                        db.execute(text("""
                            INSERT INTO v1_control_matches
                                (v1_control_uuid, matched_control_uuid, similarity_score,
                                 match_rank, matched_source, matched_article, match_method)
                            VALUES
                                (CAST(:v1_uuid AS uuid), CAST(:matched_uuid AS uuid), :score,
                                 :rank, :source, :article, 'embedding')
                            ON CONFLICT (v1_control_uuid, matched_control_uuid) DO UPDATE
                            SET similarity_score = EXCLUDED.similarity_score,
                                match_rank = EXCLUDED.match_rank
                        """), {
                            "v1_uuid": str(v1.id),
                            "matched_uuid": str(reg_row.id),
                            "score": round(score, 3),
                            "rank": rank,
                            "source": matched_source,
                            "article": matched_article,
                        })

                        control_matches.append({
                            "v1_control_id": v1.control_id,
                            "v1_title": v1.title,
                            "matched_control_id": reg_row.control_id,
                            "matched_title": reg_row.title,
                            "matched_source": matched_source,
                            "matched_article": matched_article,
                            "similarity_score": round(score, 3),
                            "match_rank": rank,
                        })

                # Savepoint released: these rows will be part of the commit.
                matches_inserted += len(control_matches)
                free_slots = sample_limit - len(sample_matches)
                if free_slots > 0:
                    sample_matches.extend(control_matches[:free_slots])
                processed += 1

            except Exception as e:
                logger.warning("V1 enrichment error for %s: %s", v1.control_id, e)
                errors.append({
                    "control_id": v1.control_id,
                    "error": str(e),
                })

        db.commit()

    # Pagination: a full page implies there may be more.
    next_offset = offset + batch_size if len(v1_controls) == batch_size else None

    return {
        "dry_run": False,
        "offset": offset,
        "batch_size": batch_size,
        "next_offset": next_offset,
        "total_v1": total_v1,
        "processed": processed,
        "matches_inserted": matches_inserted,
        "errors": errors[:10],
        "sample_matches": sample_matches,
    }
async def get_v1_matches(control_uuid: str) -> list[dict]:
    """Fetch the stored regulatory matches of one v1 control.

    Args:
        control_uuid: UUID of the v1 control.

    Returns:
        One dict per match, ordered by ``match_rank``, combining the match
        record with details of the matched canonical control.
    """
    query = text("""
        SELECT
            m.similarity_score,
            m.match_rank,
            m.matched_source,
            m.matched_article,
            m.match_method,
            c.control_id AS matched_control_id,
            c.title AS matched_title,
            c.objective AS matched_objective,
            c.severity AS matched_severity,
            c.category AS matched_category,
            c.source_citation AS matched_source_citation
        FROM v1_control_matches m
        JOIN canonical_controls c ON c.id = m.matched_control_uuid
        WHERE m.v1_control_uuid = CAST(:uuid AS uuid)
        ORDER BY m.match_rank
    """)

    def _as_match(row) -> dict:
        return {
            "matched_control_id": row.matched_control_id,
            "matched_title": row.matched_title,
            "matched_objective": row.matched_objective,
            "matched_severity": row.matched_severity,
            "matched_category": row.matched_category,
            "matched_source": row.matched_source,
            "matched_article": row.matched_article,
            "matched_source_citation": row.matched_source_citation,
            # Converted so the value is a plain float in the result.
            "similarity_score": float(row.similarity_score),
            "match_rank": row.match_rank,
            "match_method": row.match_method,
        }

    with SessionLocal() as db:
        rows = db.execute(query, {"uuid": control_uuid}).fetchall()

    return list(map(_as_match, rows))
async def get_v1_enrichment_stats() -> dict:
    """Return overview statistics for the v1 enrichment.

    Returns:
        Dict with the number of v1 controls, how many of them have / lack at
        least one stored match, the total number of match rows, and the
        average similarity score (None when there are no matches).
    """
    v1_filter = _is_eigenentwicklung_query()

    with SessionLocal() as db:
        total_v1 = db.execute(text(f"""
            SELECT COUNT(*) AS cnt FROM canonical_controls
            WHERE {v1_filter}
        """)).fetchone()

        # The v1 filter is applied inside a subquery, where its unqualified
        # column names resolve against canonical_controls. This avoids
        # rewriting the SQL fragment with string replacements to add aliases.
        matched_v1 = db.execute(text(f"""
            SELECT COUNT(DISTINCT m.v1_control_uuid) AS cnt
            FROM v1_control_matches m
            WHERE m.v1_control_uuid IN (
                SELECT id FROM canonical_controls
                WHERE {v1_filter}
            )
        """)).fetchone()

        total_matches = db.execute(text("""
            SELECT COUNT(*) AS cnt FROM v1_control_matches
        """)).fetchone()

        avg_row = db.execute(text("""
            SELECT AVG(similarity_score) AS avg_score FROM v1_control_matches
        """)).fetchone()

    total_count = total_v1.cnt if total_v1 else 0
    matched_count = matched_v1.cnt if matched_v1 else 0
    # AVG() is NULL only when there are no rows; a genuine 0.0 average must
    # not be reported as "no data".
    avg_value = avg_row.avg_score if avg_row else None

    return {
        "total_v1_controls": total_count,
        "v1_with_matches": matched_count,
        "v1_without_matches": total_count - matched_count,
        "total_matches": total_matches.cnt if total_matches else 0,
        "avg_similarity_score": round(float(avg_value), 3) if avg_value is not None else None,
    }

View File

View File

@@ -56,6 +56,7 @@ services:
- "8091:8091" # Voice Service (WSS)
- "8093:8093" # AI Compliance SDK
- "8097:8097" # RAG Service (NEU)
- "8098:8098" # Control Pipeline
- "8443:8443" # Jitsi Meet
- "3008:3008" # Admin Core
- "3010:3010" # Portal Dashboard
@@ -376,6 +377,50 @@ services:
networks:
- breakpilot-network
# =========================================================
# CONTROL PIPELINE (developer-only, not customer-facing)
# =========================================================
# Built from ./control-pipeline; the service listens on port 8098 and needs
# postgres, qdrant and embedding-service to be healthy before it starts.
control-pipeline:
build:
context: ./control-pipeline
dockerfile: Dockerfile
container_name: bp-core-control-pipeline
# Container platform is pinned to ARM64.
platform: linux/arm64
# Exposed to other containers only; this service defines no host port mapping.
expose:
- "8098"
environment:
PORT: 8098
# Falls back to the default credentials when POSTGRES_* are not set.
DATABASE_URL: postgresql://${POSTGRES_USER:-breakpilot}:${POSTGRES_PASSWORD:-breakpilot123}@postgres:5432/${POSTGRES_DB:-breakpilot_db}
SCHEMA_SEARCH_PATH: compliance,core,public
QDRANT_URL: http://qdrant:6333
EMBEDDING_SERVICE_URL: http://embedding-service:8087
# Empty by default; must be provided via the environment to use Anthropic models.
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
CONTROL_GEN_ANTHROPIC_MODEL: ${CONTROL_GEN_ANTHROPIC_MODEL:-claude-sonnet-4-6}
DECOMPOSITION_LLM_MODEL: ${DECOMPOSITION_LLM_MODEL:-claude-haiku-4-5-20251001}
# Default targets Ollama on the Docker host (see extra_hosts below).
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
CONTROL_GEN_OLLAMA_MODEL: ${CONTROL_GEN_OLLAMA_MODEL:-qwen3.5:35b-a3b}
SDK_URL: http://ai-compliance-sdk:8090
# NOTE: the fallback is a placeholder secret — set JWT_SECRET explicitly
# for anything other than local development.
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
ENVIRONMENT: ${ENVIRONMENT:-development}
# Makes host.docker.internal resolve to the host gateway inside the container.
extra_hosts:
- "host.docker.internal:host-gateway"
depends_on:
postgres:
condition: service_healthy
qdrant:
condition: service_healthy
embedding-service:
condition: service_healthy
# Same probe and timings as the HEALTHCHECK in the service's Dockerfile.
healthcheck:
test: ["CMD", "curl", "-f", "http://127.0.0.1:8098/health"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
restart: unless-stopped
networks:
- breakpilot-network
embedding-service:
build:
context: ./embedding-service

View File

@@ -578,6 +578,33 @@ server {
}
}
# =========================================================
# CORE: Control Pipeline on port 8098 (developer-only)
# =========================================================
# TLS-terminating reverse proxy: accepts HTTPS on 8098 and forwards to the
# control-pipeline container over plain HTTP.
server {
listen 8098 ssl;
http2 on;
server_name macmini localhost;
ssl_certificate /etc/nginx/certs/macmini.crt;
ssl_certificate_key /etc/nginx/certs/macmini.key;
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
ssl_prefer_server_ciphers off;
location / {
# Upstream is held in a variable.
# NOTE(review): with a variable in proxy_pass nginx resolves the name at
# request time and needs a `resolver` directive — presumably configured
# outside this block; confirm.
set $upstream_pipeline bp-core-control-pipeline:8098;
proxy_pass http://$upstream_pipeline;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
# Hard-coded because this server only listens with ssl.
proxy_set_header X-Forwarded-Proto https;
# 30-minute timeouts — presumably for long-running pipeline requests; confirm.
proxy_read_timeout 1800s;
proxy_send_timeout 1800s;
}
}
# =========================================================
# CORE: Edu-Search on port 8089
# =========================================================