feat: control-pipeline Service aus Compliance-Repo migriert
Control-Pipeline (Pass 0a/0b, BatchDedup, Generator) als eigenstaendiger Service in Core, damit Compliance-Repo unabhaengig refaktoriert werden kann. Schreibt weiterhin ins compliance-Schema der shared PostgreSQL. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
19
control-pipeline/Dockerfile
Normal file
19
control-pipeline/Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
# Runtime image for the control-pipeline service (started via main.py).
# NOTE(review): no USER instruction is set, so the process runs as root —
# consider adding a dedicated non-root user.
FROM python:3.11-slim

WORKDIR /app

# curl is installed for the HEALTHCHECK below; the apt lists are removed in
# the same layer so they do not end up in the image.
RUN apt-get update \
    && apt-get install -y --no-install-recommends curl \
    && rm -rf /var/lib/apt/lists/*

# Dependencies are installed before the sources are copied so this layer can
# be reused when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8098

# The probe is hard-wired to port 8098.
# NOTE(review): presumably must stay in sync with the port the app listens
# on — confirm if the port is made configurable.
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
    CMD curl -f http://127.0.0.1:8098/health || exit 1

CMD ["python", "main.py"]
|
||||
8
control-pipeline/api/__init__.py
Normal file
8
control-pipeline/api/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from api.control_generator_routes import router as generator_router
|
||||
from api.canonical_control_routes import router as canonical_router
|
||||
|
||||
# Aggregate router: bundles the generator and canonical-control routers into
# one object that callers can include in a single step.
router = APIRouter()
for _sub_router in (generator_router, canonical_router):
    router.include_router(_sub_router)
|
||||
2012
control-pipeline/api/canonical_control_routes.py
Normal file
2012
control-pipeline/api/canonical_control_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
1100
control-pipeline/api/control_generator_routes.py
Normal file
1100
control-pipeline/api/control_generator_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
67
control-pipeline/config.py
Normal file
67
control-pipeline/config.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
|
||||
class Settings:
    """Environment-driven configuration for control-pipeline.

    Every value is a class attribute resolved once, when the class body is
    executed; later changes to the environment are not picked up. Integer
    settings go through ``int()``, so a non-numeric value raises
    ``ValueError`` at that point.
    """

    # --- Database (compliance schema) ---
    # NOTE(review): the fallback URL embeds credentials — presumably for local
    # development only; confirm it is always overridden in deployments.
    DATABASE_URL: str = os.environ.get(
        "DATABASE_URL",
        "postgresql://breakpilot:breakpilot123@localhost:5432/breakpilot_db",
    )
    SCHEMA_SEARCH_PATH: str = os.environ.get(
        "SCHEMA_SEARCH_PATH", "compliance,core,public"
    )

    # --- Qdrant (vector search for dedup) ---
    QDRANT_URL: str = os.environ.get("QDRANT_URL", "http://localhost:6333")
    QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY", "")

    # --- Embedding service ---
    EMBEDDING_SERVICE_URL: str = os.environ.get(
        "EMBEDDING_SERVICE_URL", "http://embedding-service:8087"
    )

    # --- LLM: Anthropic ---
    ANTHROPIC_API_KEY: str = os.environ.get("ANTHROPIC_API_KEY", "")
    CONTROL_GEN_ANTHROPIC_MODEL: str = os.environ.get(
        "CONTROL_GEN_ANTHROPIC_MODEL", "claude-sonnet-4-6"
    )
    DECOMPOSITION_LLM_MODEL: str = os.environ.get(
        "DECOMPOSITION_LLM_MODEL", "claude-haiku-4-5-20251001"
    )
    CONTROL_GEN_LLM_TIMEOUT: int = int(
        os.environ.get("CONTROL_GEN_LLM_TIMEOUT", "180")
    )

    # --- LLM: Ollama (fallback) ---
    OLLAMA_URL: str = os.environ.get(
        "OLLAMA_URL", "http://host.docker.internal:11434"
    )
    CONTROL_GEN_OLLAMA_MODEL: str = os.environ.get(
        "CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b"
    )

    # --- SDK service (for RAG search proxy) ---
    SDK_URL: str = os.environ.get("SDK_URL", "http://ai-compliance-sdk:8090")

    # --- Auth ---
    # NOTE(review): defaults to an empty secret — verify that consumers reject
    # an empty value instead of signing/verifying with it.
    JWT_SECRET: str = os.environ.get("JWT_SECRET", "")

    # --- Server ---
    PORT: int = int(os.environ.get("PORT", "8098"))
    LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
    ENVIRONMENT: str = os.environ.get("ENVIRONMENT", "development")

    # --- Pipeline ---
    DECOMPOSITION_BATCH_SIZE: int = int(
        os.environ.get("DECOMPOSITION_BATCH_SIZE", "5")
    )
    DECOMPOSITION_LLM_TIMEOUT: int = int(
        os.environ.get("DECOMPOSITION_LLM_TIMEOUT", "120")
    )


# Shared instance imported by the rest of the service.
settings = Settings()
|
||||
0
control-pipeline/data/__init__.py
Normal file
0
control-pipeline/data/__init__.py
Normal file
205
control-pipeline/data/source_type_classification.py
Normal file
205
control-pipeline/data/source_type_classification.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Source-Type-Klassifikation fuer Regulierungen und Frameworks.
|
||||
|
||||
Dreistufiges Modell der normativen Verbindlichkeit:
|
||||
|
||||
Stufe 1 — GESETZ (law):
|
||||
Rechtlich bindend. Bussgeld bei Verstoss.
|
||||
Beispiele: DSGVO, NIS2, AI Act, CRA
|
||||
|
||||
Stufe 2 — LEITLINIE (guideline):
|
||||
Offizielle Auslegungshilfe von Aufsichtsbehoerden.
|
||||
Beweislastumkehr: Wer abweicht, muss begruenden warum.
|
||||
Beispiele: EDPB-Leitlinien, BSI-Standards, WP29-Dokumente
|
||||
|
||||
Stufe 3 — FRAMEWORK (framework):
|
||||
Freiwillige Best Practices, nicht rechtsverbindlich.
|
||||
Aber: Koennen als "Stand der Technik" herangezogen werden.
|
||||
Beispiele: ENISA, NIST, OWASP, OECD, CISA
|
||||
|
||||
Mapping: source_regulation (aus control_parent_links) -> source_type
|
||||
"""
|
||||
|
||||
# --- Type definitions ---
SOURCE_TYPE_LAW = "law"              # law / regulation / directive — normative_strength is kept
SOURCE_TYPE_GUIDELINE = "guideline"  # guideline / standard — at most "should"
SOURCE_TYPE_FRAMEWORK = "framework"  # framework / best practice — at most "may"

# Maximum normative_strength allowed per source_type.
# The DB constraint allows: must, should, may (NOT "can").
NORMATIVE_STRENGTH_CAP: dict[str, str] = {
    SOURCE_TYPE_LAW: "must",          # no limit
    SOURCE_TYPE_GUIDELINE: "should",  # at most "should"
    SOURCE_TYPE_FRAMEWORK: "may",     # at most "may"
}

# Ordering used for comparisons (higher = stronger).
STRENGTH_ORDER: dict[str, int] = {
    "may": 1,     # the value stored in the DB
    "can": 1,     # alias — normalised to "may" by cap_normative_strength
    "should": 2,
    "must": 3,
}

# Spellings that must never be returned as-is because the DB rejects them.
_STRENGTH_ALIASES: dict[str, str] = {"can": "may"}


def cap_normative_strength(original: str, source_type: str) -> str:
    """Limit a normative_strength to what the given source_type permits.

    The alias "can" is first normalised to "may", because only must/should/may
    are valid stored values. The result is then the weaker of the normalised
    value and the cap configured for ``source_type``.

    Args:
        original: Requested strength ("must", "should", "may" or "can").
            A value missing from STRENGTH_ORDER is ranked like "must"; it is
            therefore replaced by the cap for guideline/framework sources and
            returned unchanged when the cap is "must".
        source_type: "law", "guideline" or "framework". An unknown source
            type is not capped (treated like "law").

    Returns:
        The capped strength.

    Examples:
        cap_normative_strength("must", "framework") -> "may"
        cap_normative_strength("should", "law")     -> "should"
        cap_normative_strength("must", "guideline") -> "should"
        cap_normative_strength("can", "law")        -> "may"
    """
    normalized = _STRENGTH_ALIASES.get(original, original)
    cap = NORMATIVE_STRENGTH_CAP.get(source_type, "must")
    cap_level = STRENGTH_ORDER.get(cap, 3)
    requested_level = STRENGTH_ORDER.get(normalized, 3)
    if requested_level > cap_level:
        return cap
    return normalized
|
||||
|
||||
|
||||
def get_highest_source_type(source_types: list[str]) -> str:
    """Return the most binding source_type contained in ``source_types``.

    Ranking is framework < guideline < law, so a law outranks everything
    else. Values outside these three rank below "framework". An empty (or
    otherwise falsy) input yields "framework".

    Examples:
        get_highest_source_type(["framework", "law"])       -> "law"
        get_highest_source_type(["framework", "guideline"]) -> "guideline"
    """
    if not source_types:
        return SOURCE_TYPE_FRAMEWORK

    rank_by_type = {
        SOURCE_TYPE_FRAMEWORK: 1,
        SOURCE_TYPE_GUIDELINE: 2,
        SOURCE_TYPE_LAW: 3,
    }

    def _rank(source_type):
        return rank_by_type.get(source_type, 0)

    return max(source_types, key=_rank)
|
||||
|
||||
|
||||
# ============================================================================
# Classification: source_regulation -> source_type
#
# This map is used for the backfill and for future pipeline runs.
# Register new regulations here!
# ============================================================================

# Legally binding sources. EU directives are treated like laws for
# compliance purposes (binding after national transposition).
_LAW_SOURCES = (
    # --- EU regulations and directives ---
    "DSGVO (EU) 2016/679",
    "KI-Verordnung (EU) 2024/1689",
    "Cyber Resilience Act (CRA)",
    "NIS2-Richtlinie (EU) 2022/2555",
    "Data Act",
    "Data Governance Act (DGA)",
    "Markets in Crypto-Assets (MiCA)",
    "Maschinenverordnung (EU) 2023/1230",
    "Batterieverordnung (EU) 2023/1542",
    "AML-Verordnung",
    # --- National laws ---
    "Bundesdatenschutzgesetz (BDSG)",
    "Telekommunikationsgesetz",
    "Telekommunikationsgesetz Oesterreich",
    "Gewerbeordnung (GewO)",
    "Handelsgesetzbuch (HGB)",
    "Abgabenordnung (AO)",
    "IFRS-Übernahmeverordnung",
    "Österreichisches Datenschutzgesetz (DSG)",
    "LOPDGDD - Ley Orgánica de Protección de Datos (Spanien)",
    "Loi Informatique et Libertés (Frankreich)",
    "Információs önrendelkezési jog törvény (Ungarn)",
    # NOTE(review): the Blue Guide is a Commission guidance notice rather
    # than a legal act — confirm that "law" is the intended classification.
    "EU Blue Guide 2022",
)

# Official interpretation aids issued by supervisory authorities.
_GUIDELINE_SOURCES = (
    # --- EDPB / WP29 guidelines ---
    "EDPB Leitlinien 01/2019 (Zertifizierung)",
    "EDPB Leitlinien 01/2020 (Datentransfers)",
    "EDPB Leitlinien 01/2020 (Vernetzte Fahrzeuge)",
    "EDPB Leitlinien 01/2022 (BCR)",
    "EDPB Leitlinien 01/2024 (Berechtigtes Interesse)",
    "EDPB Leitlinien 04/2019 (Data Protection by Design)",
    "EDPB Leitlinien 05/2020 - Einwilligung",
    "EDPB Leitlinien 07/2020 (Datentransfers)",
    "EDPB Leitlinien 08/2020 (Social Media)",
    "EDPB Leitlinien 09/2022 (Data Breach)",
    "EDPB Leitlinien 09/2022 - Meldung von Datenschutzverletzungen",
    "EDPB Empfehlungen 01/2020 - Ergaenzende Massnahmen fuer Datentransfers",
    "EDPB Leitlinien - Berechtigtes Interesse (Art. 6(1)(f))",
    "WP244 Leitlinien (Profiling)",
    "WP251 Leitlinien (Profiling)",
    "WP260 Leitlinien (Transparenz)",
    # --- BSI standards (technical guidelines issued by an authority) ---
    "BSI-TR-03161-1",
    "BSI-TR-03161-2",
    "BSI-TR-03161-3",
)

# Voluntary best practices, not legally binding.
_FRAMEWORK_SOURCES = (
    # --- ENISA (EU agency, but its recommendations are not binding) ---
    "ENISA Cybersecurity State 2024",
    "ENISA ICS/SCADA Dependencies",
    "ENISA Supply Chain Good Practices",
    "ENISA Threat Landscape Supply Chain",
    # --- NIST (US standards, used internationally as best practice) ---
    "NIST AI Risk Management Framework",
    "NIST Cybersecurity Framework 2.0",
    "NIST SP 800-207 (Zero Trust)",
    "NIST SP 800-218 (SSDF)",
    "NIST SP 800-53 Rev. 5",
    "NIST SP 800-63-3",
    # --- OWASP (community standards) ---
    "OWASP API Security Top 10 (2023)",
    "OWASP ASVS 4.0",
    "OWASP MASVS 2.0",
    "OWASP SAMM 2.0",
    "OWASP Top 10 (2021)",
    # --- Other frameworks ---
    "OECD KI-Empfehlung",
    "CISA Secure by Design",
)

# Exact-name lookup table, assembled from the three groups above.
SOURCE_REGULATION_CLASSIFICATION: dict[str, str] = {
    **dict.fromkeys(_LAW_SOURCES, SOURCE_TYPE_LAW),
    **dict.fromkeys(_GUIDELINE_SOURCES, SOURCE_TYPE_GUIDELINE),
    **dict.fromkeys(_FRAMEWORK_SOURCES, SOURCE_TYPE_FRAMEWORK),
}
|
||||
|
||||
|
||||
def classify_source_regulation(source_regulation: str) -> str:
    """Classify a source_regulation as law, guideline or framework.

    An exact match against SOURCE_REGULATION_CLASSIFICATION wins. Unknown
    sources are guessed from keywords, checked in the order law, guideline,
    framework. If nothing matches, the result is "framework", the least
    binding category.

    Long, compound-forming indicators ("gesetz", "verordnung", ...) are
    matched as substrings so that compounds such as "Datenschutzgesetz" are
    recognised. The short indicators "act", "ley" and "loi" are matched as
    whole words only; as raw substrings they also hit unrelated words such as
    "Practices", "Impact" or "Exploit" and would wrongly promote a framework
    to binding law.

    Args:
        source_regulation: Display name of the source; may be empty or None.

    Returns:
        One of SOURCE_TYPE_LAW, SOURCE_TYPE_GUIDELINE, SOURCE_TYPE_FRAMEWORK.
    """
    import re  # local import: only the keyword heuristic needs it

    if not source_regulation:
        return SOURCE_TYPE_FRAMEWORK

    # Exact match
    if source_regulation in SOURCE_REGULATION_CLASSIFICATION:
        return SOURCE_REGULATION_CLASSIFICATION[source_regulation]

    # Heuristic for unknown sources
    lower = source_regulation.lower()

    def _contains_word(words) -> bool:
        # Whole-word match: the indicator must not touch another word character.
        return any(
            re.search(rf"(?<!\w){re.escape(word)}(?!\w)", lower) for word in words
        )

    # Detect laws
    law_substrings = (
        "verordnung", "richtlinie", "gesetz", "directive", "regulation",
        "(eu)", "(eg)", "törvény", "código",
    )
    law_words = ("act", "acts", "ley", "loi")
    if any(ind in lower for ind in law_substrings) or _contains_word(law_words):
        return SOURCE_TYPE_LAW

    # Detect guidelines
    guideline_indicators = (
        "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung",
    )
    if any(ind in lower for ind in guideline_indicators):
        return SOURCE_TYPE_GUIDELINE

    # Detect frameworks (same result as the fallback; kept so the known
    # framework publishers are listed explicitly)
    framework_indicators = (
        "enisa", "nist", "owasp", "oecd", "cisa", "framework", "iso",
    )
    if any(ind in lower for ind in framework_indicators):
        return SOURCE_TYPE_FRAMEWORK

    # Conservative: unknown = framework (least binding)
    return SOURCE_TYPE_FRAMEWORK
|
||||
0
control-pipeline/db/__init__.py
Normal file
0
control-pipeline/db/__init__.py
Normal file
37
control-pipeline/db/session.py
Normal file
37
control-pipeline/db/session.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Database session factory for control-pipeline.
|
||||
|
||||
Connects to the shared PostgreSQL with search_path set to compliance schema.
|
||||
"""
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from config import settings
|
||||
|
||||
# Engine for the shared PostgreSQL instance. pool_pre_ping makes the pool
# test a connection before handing it out.
engine = create_engine(
    settings.DATABASE_URL,
    pool_pre_ping=True,
    pool_size=5,
    max_overflow=10,
    echo=False,
)


@event.listens_for(engine, "connect")
def set_search_path(dbapi_connection, connection_record):
    """Set the schema search_path on every newly opened DBAPI connection.

    Registered for the engine's "connect" event.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: Pool record for the connection (unused).
    """
    cursor = dbapi_connection.cursor()
    try:
        # SCHEMA_SEARCH_PATH is interpolated into the statement, so it must
        # only ever come from trusted configuration, never from user input.
        cursor.execute(f"SET search_path TO {settings.SCHEMA_SEARCH_PATH}")
    finally:
        # Release the cursor even if the SET statement fails.
        cursor.close()
    dbapi_connection.commit()


SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db():
    """FastAPI dependency for DB sessions.

    Yields a session created by SessionLocal and closes it once the
    consumer is done with it, whether or not an error occurred.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
88
control-pipeline/main.py
Normal file
88
control-pipeline/main.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from config import settings
|
||||
from db.session import engine
|
||||
|
||||
# Resolve the configured level name to its numeric value. The name is
# upper-cased first: a lowercase value such as "debug" would otherwise resolve
# to the *function* logging.debug, which basicConfig rejects. Anything that
# does not resolve to an integer level falls back to INFO.
_configured_level = getattr(logging, str(settings.LOG_LEVEL).upper(), None)
logging.basicConfig(
    level=_configured_level if isinstance(_configured_level, int) else logging.INFO,
    format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
logger = logging.getLogger("control-pipeline")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: check database connectivity at startup.

    Runs ``SELECT 1`` once before the application starts serving. A failure
    is logged but does not stop startup. Only the database is checked; no
    other dependency is probed here.

    Args:
        app: The FastAPI application (unused).
    """
    # Imported locally so the query helper does not rely on a module-level
    # import being present in this file.
    from sqlalchemy import text

    logger.info("Control-Pipeline starting up ...")

    # Verify database connection
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
            logger.info("Database connection OK")
    except Exception as exc:
        # Best effort: the service still starts when the DB is unreachable.
        logger.error("Database connection failed: %s", exc)

    yield

    logger.info("Control-Pipeline shutting down ...")
|
||||
|
||||
|
||||
# ASGI application; startup/shutdown handling is delegated to lifespan().
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="BreakPilot Control Pipeline",
    description=(
        "Control generation, decomposition, and deduplication pipeline "
        "for the BreakPilot compliance platform."
    ),
)

# CORS
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm whether credentialed cross-origin requests are really
# needed, or restrict allow_origins to the known frontends.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Routers — imported after the app is constructed, hence the E402 waiver.
from api import router as api_router  # noqa: E402

app.include_router(api_router)
|
||||
|
||||
|
||||
# Health
|
||||
@app.get("/health")
async def health():
    """Health probe that also reports database reachability.

    Runs ``SELECT 1`` against the database. The endpoint always returns a
    payload: "healthy" when the query succeeds, "degraded" otherwise.

    Returns:
        Dict with ``status``, ``service``, ``version`` and a ``dependencies``
        map holding the postgres state ("ok" or "unavailable").
    """
    # Imported locally so the query helper does not rely on a module-level
    # import being present in this file.
    from sqlalchemy import text

    db_ok = False
    try:
        # NOTE(review): this is a blocking call inside an async handler —
        # acceptable for a quick probe, but confirm it is intended.
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
            db_ok = True
    except Exception as exc:
        # Reported as "degraded" below; logged so the cause is not lost.
        logger.warning("Health check: database unavailable: %s", exc)

    status = "healthy" if db_ok else "degraded"
    return {
        "status": status,
        "service": "control-pipeline",
        "version": "1.0.0",
        "dependencies": {
            "postgres": "ok" if db_ok else "unavailable",
        },
    }
|
||||
|
||||
|
||||
# Script entry point (`python main.py`): serve the app on all interfaces
# using the configured port.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=settings.PORT, reload=False, log_level="info")
|
||||
22
control-pipeline/requirements.txt
Normal file
22
control-pipeline/requirements.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
# Web Framework
|
||||
fastapi>=0.123.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
|
||||
# Database
|
||||
SQLAlchemy>=2.0.36
|
||||
psycopg2-binary>=2.9.10
|
||||
|
||||
# HTTP Client
|
||||
httpx>=0.28.0
|
||||
|
||||
# Validation
|
||||
pydantic>=2.5.0
|
||||
|
||||
# AI - Anthropic Claude
|
||||
anthropic>=0.75.0
|
||||
|
||||
# Vector DB (dedup)
|
||||
qdrant-client>=1.7.0
|
||||
|
||||
# Auth
|
||||
python-jose[cryptography]>=3.3.0
|
||||
0
control-pipeline/services/__init__.py
Normal file
0
control-pipeline/services/__init__.py
Normal file
187
control-pipeline/services/anchor_finder.py
Normal file
187
control-pipeline/services/anchor_finder.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Anchor Finder — finds open-source references (OWASP, NIST, ENISA) for controls.
|
||||
|
||||
Two-stage search:
|
||||
Stage A: RAG-internal search for open-source chunks matching the control topic
|
||||
Stage B: Web search via DuckDuckGo Instant Answer API (no API key needed)
|
||||
|
||||
Only open-source references (Rule 1+2) are accepted as anchors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .rag_client import ComplianceRAGClient, get_rag_client
|
||||
from .control_generator import (
|
||||
GeneratedControl,
|
||||
REGULATION_LICENSE_MAP,
|
||||
_RULE2_PREFIXES,
|
||||
_RULE3_PREFIXES,
|
||||
_classify_regulation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regulation codes that are safe to reference as open anchors (Rule 1+2)
|
||||
_OPEN_SOURCE_RULES = {1, 2}
|
||||
|
||||
|
||||
@dataclass
class OpenAnchor:
    """A single open-source reference that anchors a generated control."""

    # Name of the framework or source the reference belongs to
    framework: str
    # Reference inside the framework (article, category, or a text snippet)
    ref: str
    # Link to the reference; may be an empty string
    url: str
|
||||
|
||||
|
||||
class AnchorFinder:
    """Finds open-source references to anchor generated controls.

    Stage A queries the RAG index and keeps only hits whose regulation is
    classified under an open-source rule. Stage B queries the DuckDuckGo
    Instant Answer API and keeps only results hosted on a known set of
    domains. Stage B runs only when stage A found too few anchors.
    """

    def __init__(self, rag_client: Optional[ComplianceRAGClient] = None):
        """Store the RAG client, falling back to the shared one when none is given."""
        self.rag = rag_client or get_rag_client()

    async def find_anchors(
        self,
        control: GeneratedControl,
        skip_web: bool = False,
        min_anchors: int = 2,
    ) -> List[OpenAnchor]:
        """Find open-source anchors for a control.

        Args:
            control: The control to find references for.
            skip_web: If True, never run the web search stage.
            min_anchors: The web search only runs when the RAG stage returned
                fewer anchors than this.

        Returns:
            RAG anchors first, followed by web anchors whose
            (framework, ref) pair was not already present.
        """
        # Stage A: RAG-internal search
        anchors = await self._search_rag_for_open_anchors(control)

        # Stage B: Web search if not enough anchors
        if len(anchors) < min_anchors and not skip_web:
            web_anchors = await self._search_web(control)
            # Deduplicate by framework+ref; the key set is extended while
            # appending so web results cannot duplicate each other either.
            existing_keys = {(a.framework, a.ref) for a in anchors}
            for wa in web_anchors:
                key = (wa.framework, wa.ref)
                if key in existing_keys:
                    continue
                existing_keys.add(key)
                anchors.append(wa)

        return anchors

    async def _search_rag_for_open_anchors(self, control: GeneratedControl) -> List[OpenAnchor]:
        """Search RAG for chunks from open sources matching the control topic.

        Returns at most five anchors, de-duplicated by regulation code and
        reference.
        """
        # Build search query from control title + first 3 tags
        tags_str = " ".join(control.tags[:3]) if control.tags else ""
        query = f"{control.title} {tags_str}".strip()

        results = await self.rag.search_with_rerank(
            query=query,
            collection="bp_compliance_ce",
            top_k=15,
        )

        anchors: List[OpenAnchor] = []
        seen: set[str] = set()

        for r in results:
            if not r.regulation_code:
                continue

            # Only accept open-source references
            license_info = _classify_regulation(r.regulation_code)
            if license_info.get("rule") not in _OPEN_SOURCE_RULES:
                continue

            # Build reference key for dedup
            ref = r.article or r.category or ""
            key = f"{r.regulation_code}:{ref}"
            if key in seen:
                continue
            seen.add(key)

            framework_name = license_info.get(
                "name", r.regulation_name or r.regulation_short or r.regulation_code
            )
            url = r.source_url or self._build_reference_url(r.regulation_code, ref)

            anchors.append(OpenAnchor(
                framework=framework_name,
                ref=ref,
                url=url,
            ))

            if len(anchors) >= 5:
                break

        return anchors

    async def _search_web(self, control: GeneratedControl) -> List[OpenAnchor]:
        """Search the DuckDuckGo Instant Answer API for open references.

        Best effort: a non-200 response or any error yields an empty list
        (errors are logged as warnings). Returns at most three anchors.
        """
        keywords = f"{control.title} security control OWASP NIST"
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                resp = await client.get(
                    "https://api.duckduckgo.com/",
                    params={
                        "q": keywords,
                        "format": "json",
                        "no_html": "1",
                        "skip_disambig": "1",
                    },
                )
                if resp.status_code != 200:
                    return []

                data = resp.json()
                anchors: List[OpenAnchor] = []

                # Parse RelatedTopics (only the first ten entries are considered)
                for topic in data.get("RelatedTopics", [])[:10]:
                    url = topic.get("FirstURL", "")
                    text = topic.get("Text", "")

                    if not url:
                        continue

                    # Only accept known open-source domains
                    framework = self._identify_framework_from_url(url)
                    if framework:
                        anchors.append(OpenAnchor(
                            framework=framework,
                            ref=text[:100] if text else url,
                            url=url,
                        ))

                    if len(anchors) >= 3:
                        break

                return anchors

        except Exception as e:
            logger.warning("Web anchor search failed: %s", e)
            return []

    @staticmethod
    def _identify_framework_from_url(url: str) -> Optional[str]:
        """Return the framework a URL belongs to, judged by its host name.

        The host must equal one of the known domains or be a subdomain of
        it. A domain that merely appears in the path or query string (for
        example ``https://evil.example/?next=owasp.org``) does not count.

        Args:
            url: Absolute URL; a missing scheme is tolerated.

        Returns:
            "OWASP", "NIST", "ENISA", "CISA" or "EU Law"; None for any other
            host, an empty value, or an unparsable URL.
        """
        from urllib.parse import urlparse  # local import: only needed here

        known_domains = (
            ("owasp.org", "OWASP"),
            ("nist.gov", "NIST"),
            ("enisa.europa.eu", "ENISA"),
            ("cisa.gov", "CISA"),
            ("eur-lex.europa.eu", "EU Law"),
        )

        if not url:
            return None

        candidate = url.strip()
        if "://" not in candidate:
            # Without a scheme urlparse would treat the host as part of the path.
            candidate = "//" + candidate.lstrip("/")

        try:
            host = (urlparse(candidate).hostname or "").lower().rstrip(".")
        except ValueError:
            return None

        for domain, framework in known_domains:
            if host == domain or host.endswith("." + domain):
                return framework
        return None

    @staticmethod
    def _build_reference_url(regulation_code: str, ref: str) -> str:
        """Build a generic landing-page URL for a known framework.

        Args:
            regulation_code: Regulation identifier; matched case-insensitively.
            ref: Reference inside the framework (currently unused).

        Returns:
            The landing page for the framework, or "" when the code is unknown.
        """
        code = regulation_code.lower()
        if code.startswith("owasp"):
            return "https://owasp.org/www-project-application-security-verification-standard/"
        if code.startswith("nist"):
            return "https://csrc.nist.gov/publications"
        if code.startswith("enisa"):
            return "https://www.enisa.europa.eu/publications"
        if code.startswith("eu_"):
            return "https://eur-lex.europa.eu/"
        if code == "cisa_secure_by_design":
            return "https://www.cisa.gov/securebydesign"
        return ""
|
||||
618
control-pipeline/services/batch_dedup_runner.py
Normal file
618
control-pipeline/services/batch_dedup_runner.py
Normal file
@@ -0,0 +1,618 @@
|
||||
"""Batch Dedup Runner — Orchestrates deduplication of ~85k atomare Controls.
|
||||
|
||||
Reduces Pass 0b controls from ~85k to ~18-25k unique Master Controls via:
|
||||
Phase 1: Intra-Group Dedup — same merge_group_hint → pick best, link rest
|
||||
(85k → ~52k, mostly title-identical short-circuit, no embeddings)
|
||||
Phase 2: Cross-Group Dedup — embed masters, search Qdrant for similar
|
||||
masters with different hints (52k → ~18-25k)
|
||||
|
||||
All Pass 0b controls have pattern_id=NULL. The primary grouping key is
|
||||
merge_group_hint (format: "action_type:norm_obj:trigger_key"), which
|
||||
encodes the normalized action, object, and trigger.
|
||||
|
||||
Usage:
|
||||
runner = BatchDedupRunner(db)
|
||||
stats = await runner.run(dry_run=True) # preview
|
||||
stats = await runner.run(dry_run=False) # execute
|
||||
stats = await runner.run(hint_filter="implement:multi_factor_auth:none")
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from services.control_dedup import (
|
||||
canonicalize_text,
|
||||
ensure_qdrant_collection,
|
||||
get_embedding,
|
||||
normalize_action,
|
||||
normalize_object,
|
||||
qdrant_search_cross_regulation,
|
||||
qdrant_upsert,
|
||||
LINK_THRESHOLD,
|
||||
REVIEW_THRESHOLD,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEDUP_COLLECTION = "atomic_controls_dedup"
|
||||
|
||||
|
||||
# ── Quality Score ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _count_json_items(value) -> int:
    """Count the entries of a JSON array/object given as text or as a parsed value.

    Anything that is not (or does not parse to) a list, tuple or dict counts
    as zero entries. This covers None, empty strings, invalid JSON, JSON
    ``null`` and JSON scalars.
    """
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except (json.JSONDecodeError, TypeError):
            return 0
    if isinstance(value, (list, tuple, dict)):
        return len(value)
    return 0


def quality_score(control: dict) -> float:
    """Score a control by richness of requirements, tests, evidence, and objective.

    Higher score = better candidate for master control.

    Weights: 2.0 per requirement, 1.5 per test-procedure entry, 1.0 per
    evidence entry, plus objective length / 200, capped at 3.0.

    Args:
        control: Mapping with optional keys "requirements", "test_procedure",
            "evidence" (each a list or its JSON text) and "objective" (text).
            Missing, empty, or non-collection values contribute nothing.

    Returns:
        The non-negative score.
    """
    score = 0.0
    score += _count_json_items(control.get("requirements")) * 2.0
    score += _count_json_items(control.get("test_procedure")) * 1.5
    score += _count_json_items(control.get("evidence")) * 1.0

    objective = control.get("objective") or ""
    score += min(len(objective) / 200, 3.0)

    return score
|
||||
|
||||
|
||||
# ── Batch Dedup Runner ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class BatchDedupRunner:
|
||||
"""Batch dedup orchestrator for existing Pass 0b atomic controls."""
|
||||
|
||||
def __init__(self, db, collection: str = DEDUP_COLLECTION):
    """Bind the DB handle and target collection and zero all counters.

    Args:
        db: Database handle used for queries and rollbacks by the runner.
        collection: Name of the collection used for dedup.
    """
    self.db = db
    self.collection = collection

    # Every counter starts at zero; the order below is the order in which
    # the keys appear in the stats dict.
    counter_names = (
        "total_controls",
        "unique_hints",
        "phase1_groups_processed",
        "masters",
        "linked",
        "review",
        "new_controls",
        "parent_links_transferred",
        "cross_group_linked",
        "cross_group_review",
        "errors",
        "skipped_title_identical",
    )
    self.stats = dict.fromkeys(counter_names, 0)

    # Progress tracking
    self._progress_phase = ""
    self._progress_count = 0
    self._progress_total = 0
|
||||
|
||||
async def run(
    self,
    dry_run: bool = False,
    hint_filter: str = None,
) -> dict:
    """Run the full batch dedup pipeline.

    Phase 1 walks every hint group; a failing group is logged, counted in
    ``errors`` and followed by a best-effort DB rollback, then processing
    continues with the next group. Phase 2 (cross-group pass) and the
    collection setup only happen when ``dry_run`` is False.

    Args:
        dry_run: If True, compute stats but don't modify DB/Qdrant.
        hint_filter: If set, only process groups matching this hint prefix.

    Returns:
        Stats dict with counts, plus ``elapsed_seconds``.
    """
    started_at = time.monotonic()
    logger.info("BatchDedup starting (dry_run=%s, hint_filter=%s)",
                dry_run, hint_filter)

    if not dry_run:
        await ensure_qdrant_collection(collection=self.collection)

    # Phase 1: Intra-group dedup (same merge_group_hint)
    self._progress_phase = "phase1"
    hint_groups = self._load_merge_groups(hint_filter)
    self._progress_total = self.stats["total_controls"]

    for hint, members in hint_groups:
        try:
            await self._process_hint_group(hint, members, dry_run)
        except Exception as exc:
            logger.error("BatchDedup Phase 1 error on hint %s: %s", hint, exc)
            self.stats["errors"] += 1
            # Best effort: a failing rollback must not abort the whole run.
            try:
                self.db.rollback()
            except Exception:
                pass
        else:
            self.stats["phase1_groups_processed"] += 1

    logger.info(
        "BatchDedup Phase 1 done: %d masters, %d linked, %d review",
        self.stats["masters"], self.stats["linked"], self.stats["review"],
    )

    # Phase 2: Cross-group dedup via embeddings
    if not dry_run:
        self._progress_phase = "phase2"
        await self._run_cross_group_pass()

    duration = time.monotonic() - started_at
    self.stats["elapsed_seconds"] = round(duration, 1)
    logger.info("BatchDedup completed in %.1fs: %s", duration, self.stats)
    return self.stats
|
||||
|
||||
def _load_merge_groups(self, hint_filter: str = None) -> list:
    """Load all Pass 0b controls grouped by merge_group_hint, largest first.

    Controls whose release_state is 'deprecated' or 'duplicate' are
    excluded. Controls without a hint are collected under the key "".
    Also records ``total_controls`` and ``unique_hints`` in ``self.stats``.

    Args:
        hint_filter: Optional literal prefix of the merge_group_hint. LIKE
            metacharacters in it are escaped, so the underscores that hints
            usually contain match only themselves.

    Returns:
        List of ``(hint, controls)`` tuples sorted by group size, descending;
        each control is a dict of the selected columns.
    """
    conditions = [
        "decomposition_method = 'pass0b'",
        "release_state != 'deprecated'",
        "release_state != 'duplicate'",
    ]
    params = {}

    if hint_filter:
        # Backslash is PostgreSQL's default LIKE escape character; escape it
        # first so the escapes added for % and _ are not doubled.
        escaped = (
            hint_filter.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        conditions.append("generation_metadata->>'merge_group_hint' LIKE :hf")
        params["hf"] = f"{escaped}%"

    # Only the fixed condition strings above are interpolated; the filter
    # value itself is passed as a bound parameter.
    where = " AND ".join(conditions)
    rows = self.db.execute(text(f"""
        SELECT id::text, control_id, title, objective,
               pattern_id, requirements::text, test_procedure::text,
               evidence::text, release_state,
               generation_metadata->>'merge_group_hint' as merge_group_hint,
               generation_metadata->>'action_object_class' as action_object_class
        FROM canonical_controls
        WHERE {where}
        ORDER BY control_id
    """), params).fetchall()

    by_hint = defaultdict(list)
    for r in rows:
        hint = r[9] or ""
        by_hint[hint].append({
            "uuid": r[0],
            "control_id": r[1],
            "title": r[2],
            "objective": r[3],
            "pattern_id": r[4],
            "requirements": r[5],
            "test_procedure": r[6],
            "evidence": r[7],
            "release_state": r[8],
            "merge_group_hint": hint,
            "action_object_class": r[10] or "",
        })

    self.stats["total_controls"] = len(rows)
    self.stats["unique_hints"] = len(by_hint)

    sorted_groups = sorted(by_hint.items(), key=lambda x: len(x[1]), reverse=True)
    logger.info("BatchDedup loaded %d controls in %d hint groups",
                len(rows), len(sorted_groups))
    return sorted_groups
|
||||
|
||||
def _sub_group_by_merge_hint(self, controls: list) -> dict:
|
||||
"""Group controls by merge_group_hint composite key."""
|
||||
groups = defaultdict(list)
|
||||
for c in controls:
|
||||
hint = c["merge_group_hint"]
|
||||
if hint:
|
||||
groups[hint].append(c)
|
||||
else:
|
||||
groups[f"__no_hint_{c['uuid']}"].append(c)
|
||||
return dict(groups)
|
||||
|
||||
async def _process_hint_group(
|
||||
self,
|
||||
hint: str,
|
||||
controls: list,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""Process all controls sharing the same merge_group_hint.
|
||||
|
||||
Within a hint group, all controls share action+object+trigger.
|
||||
The best-quality control becomes master, rest are linked as duplicates.
|
||||
"""
|
||||
if len(controls) < 2:
|
||||
# Singleton → always master
|
||||
self.stats["masters"] += 1
|
||||
if not dry_run:
|
||||
await self._embed_and_index(controls[0])
|
||||
self._progress_count += 1
|
||||
self._log_progress(hint)
|
||||
return
|
||||
|
||||
# Sort by quality score (best first)
|
||||
sorted_group = sorted(controls, key=quality_score, reverse=True)
|
||||
master = sorted_group[0]
|
||||
self.stats["masters"] += 1
|
||||
|
||||
if not dry_run:
|
||||
await self._embed_and_index(master)
|
||||
|
||||
for candidate in sorted_group[1:]:
|
||||
# All share the same hint → check title similarity
|
||||
if candidate["title"].strip().lower() == master["title"].strip().lower():
|
||||
# Identical title → direct link (no embedding needed)
|
||||
self.stats["linked"] += 1
|
||||
self.stats["skipped_title_identical"] += 1
|
||||
if not dry_run:
|
||||
await self._mark_duplicate(master, candidate, confidence=1.0)
|
||||
else:
|
||||
# Different title within same hint → still likely duplicate
|
||||
# Use embedding to verify
|
||||
await self._check_and_link_within_group(master, candidate, dry_run)
|
||||
|
||||
self._progress_count += 1
|
||||
self._log_progress(hint)
|
||||
|
||||
async def _check_and_link_within_group(
    self,
    master: dict,
    candidate: dict,
    dry_run: bool,
):
    """Decide via embedding whether a same-hint candidate duplicates the master.

    The candidate's canonical text is embedded and the three nearest points
    of the dedup collection are fetched. Hits that are the candidate itself
    are ignored: a candidate indexed by an earlier run would otherwise be
    its own best match and be marked as a duplicate of itself.

    Outcome by best remaining score:
      * above ``LINK_THRESHOLD``: linked to the matched control,
      * above ``REVIEW_THRESHOLD``: a review entry is written,
      * otherwise: the candidate is indexed as an additional master.
    If no embedding is available, or no other control is found, the
    candidate is linked to ``master`` with confidence 0.90.

    Args:
        master: The elected master of the hint group.
        candidate: The control to classify.
        dry_run: When true, only statistics are updated.
    """
    parts = candidate["merge_group_hint"].split(":", 2)
    action = parts[0]
    obj = parts[1] if len(parts) > 1 else ""

    canonical = canonicalize_text(action, obj, candidate["title"])
    embedding = await get_embedding(canonical)

    if not embedding:
        # Can't embed → link anyway (same hint = same action+object)
        self.stats["linked"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=0.90)
        return

    # Search the dedup collection (unfiltered — pattern_id is NULL)
    results = await qdrant_search_cross_regulation(
        embedding, top_k=3, collection=self.collection,
    )

    # Drop self-matches, mirroring the self-match skip in the cross-group pass.
    candidate_uuid = candidate["uuid"]
    others = [
        hit for hit in (results or [])
        if (hit.get("payload") or {}).get("control_uuid") != candidate_uuid
    ]

    if not others:
        # No other control found (master might not be indexed yet) → link to master
        self.stats["linked"] += 1
        if not dry_run:
            await self._mark_duplicate(master, candidate, confidence=0.90)
        return

    best = others[0]
    best_score = best.get("score", 0.0)
    best_payload = best.get("payload") or {}
    best_uuid = best_payload.get("control_uuid", "")

    if best_score > LINK_THRESHOLD:
        self.stats["linked"] += 1
        if not dry_run:
            # A payload without control_uuid cannot be cast to uuid; fall back
            # to the group's master instead of failing the whole batch.
            await self._mark_duplicate_to(
                best_uuid or master["uuid"], candidate, confidence=best_score,
            )
    elif best_score > REVIEW_THRESHOLD:
        self.stats["review"] += 1
        if not dry_run:
            self._write_review(candidate, best_payload, best_score)
    else:
        # Very different despite same hint → new master
        self.stats["new_controls"] += 1
        if not dry_run:
            await self._index_with_embedding(candidate, embedding)
|
||||
|
||||
async def _run_cross_group_pass(self):
    """Phase 2: Find cross-group duplicates among surviving masters.

    After Phase 1, ~52k masters remain. Many have similar semantics
    despite different merge_group_hints (e.g. different German spellings).
    This pass embeds all masters and finds near-duplicates via Qdrant.

    For each remaining control the five nearest points are inspected in
    order. The first usable hit above ``LINK_THRESHOLD`` makes the current
    control a duplicate of the hit; the first one above ``REVIEW_THRESHOLD``
    produces a review entry. A hit is unusable when it has no
    ``control_uuid``, is the control itself, or was itself merged earlier in
    this pass. Without the last rule, A→B followed by B→A would leave both
    controls as duplicates with no master.

    Results are stored in ``stats["cross_group_linked"]`` and
    ``stats["cross_group_review"]``.
    """
    logger.info("BatchDedup Phase 2: Cross-group pass starting...")

    rows = self.db.execute(text("""
        SELECT id::text, control_id, title,
               generation_metadata->>'merge_group_hint' as merge_group_hint
        FROM canonical_controls
        WHERE decomposition_method = 'pass0b'
          AND release_state != 'duplicate'
          AND release_state != 'deprecated'
        ORDER BY control_id
    """)).fetchall()

    self._progress_total = len(rows)
    self._progress_count = 0
    logger.info("BatchDedup Cross-group: %d masters to check", len(rows))
    cross_linked = 0
    cross_review = 0
    # UUIDs turned into duplicates during this pass; they stay in Qdrant but
    # must no longer be chosen as a merge target.
    merged_in_pass = set()

    for i, r in enumerate(rows):
        uuid = r[0]
        hint = r[3] or ""
        parts = hint.split(":", 2)
        action = parts[0]
        obj = parts[1] if len(parts) > 1 else ""

        canonical = canonicalize_text(action, obj, r[2])
        embedding = await get_embedding(canonical)

        results = None
        if embedding:
            results = await qdrant_search_cross_regulation(
                embedding, top_k=5, collection=self.collection,
            )

        for match in results or ():
            match_score = match.get("score", 0.0)
            match_payload = match.get("payload") or {}
            match_uuid = match_payload.get("control_uuid", "")

            # Skip unusable hits: no uuid, self-match, or already merged away.
            if not match_uuid or match_uuid == uuid or match_uuid in merged_in_pass:
                continue

            if match_score > LINK_THRESHOLD:
                # The control being processed becomes the duplicate of the hit.
                try:
                    self.db.execute(text("""
                        UPDATE canonical_controls
                        SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
                        WHERE id = CAST(:dup AS uuid)
                          AND release_state != 'duplicate'
                    """), {"master": match_uuid, "dup": uuid})

                    self.db.execute(text("""
                        INSERT INTO control_parent_links
                            (control_uuid, parent_control_uuid, link_type, confidence)
                        VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), 'cross_regulation', :conf)
                        ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
                    """), {"cu": match_uuid, "pu": uuid, "conf": match_score})

                    # Transfer parent links
                    transferred = self._transfer_parent_links(match_uuid, uuid)

                    self.db.commit()
                    # Only count what was actually committed.
                    self.stats["parent_links_transferred"] += transferred
                    cross_linked += 1
                    merged_in_pass.add(uuid)
                except Exception as e:
                    logger.error("BatchDedup cross-group link error %s→%s: %s",
                                 uuid, match_uuid, e)
                    self.db.rollback()
                    self.stats["errors"] += 1
                break  # Only one cross-link per control
            elif match_score > REVIEW_THRESHOLD:
                self._write_review(
                    {"control_id": r[1], "title": r[2], "objective": "",
                     "merge_group_hint": hint, "pattern_id": None},
                    match_payload, match_score,
                )
                cross_review += 1
                break

        # Progress advances for every row, including those without an
        # embedding or without search results.
        self._progress_count = i + 1
        if (i + 1) % 500 == 0:
            logger.info("BatchDedup Cross-group: %d/%d checked, %d linked, %d review",
                        i + 1, len(rows), cross_linked, cross_review)

    self.stats["cross_group_linked"] = cross_linked
    self.stats["cross_group_review"] = cross_review
    logger.info("BatchDedup Cross-group complete: %d linked, %d review",
                cross_linked, cross_review)
|
||||
|
||||
# ── Qdrant Helpers ───────────────────────────────────────────────────
|
||||
|
||||
async def _embed_and_index(self, control: dict):
    """Compute embedding and index a control in the dedup Qdrant collection.

    The text to embed is built by ``canonicalize_text`` from the first two
    ``:``-separated parts of ``merge_group_hint`` (action and object) and
    the control title. The upsert itself is delegated to
    ``_index_with_embedding`` so both indexing paths write the same payload.

    If no embedding is returned the control is not indexed; this is logged
    because later candidates of the group will then not find it in Qdrant.

    Args:
        control: Control dict with ``uuid``, ``control_id``, ``title`` and
            ``merge_group_hint``.
    """
    parts = control["merge_group_hint"].split(":", 2)
    action = parts[0]
    obj = parts[1] if len(parts) > 1 else ""

    embedding = await get_embedding(canonicalize_text(action, obj, control["title"]))
    if not embedding:
        logger.warning(
            "BatchDedup: no embedding for control %s — not indexed",
            control.get("control_id"),
        )
        return

    await self._index_with_embedding(control, embedding)
|
||||
|
||||
async def _index_with_embedding(self, control: dict, embedding: list):
    """Upsert a control into the dedup collection using a ready-made embedding.

    The point id is the control's uuid. The payload carries the identifying
    fields plus the normalized action/object and the canonical text derived
    from ``merge_group_hint`` and the title.

    Args:
        control: Control dict with ``uuid``, ``control_id``, ``title`` and
            ``merge_group_hint`` (``pattern_id`` is optional).
        embedding: Pre-computed embedding vector for this control.
    """
    # Hint layout is "<action>:<object>[:<rest>]"; missing parts become "".
    action, _, remainder = control["merge_group_hint"].partition(":")
    obj = remainder.split(":", 1)[0]

    norm_action = normalize_action(action)
    norm_object = normalize_object(obj)
    canonical = canonicalize_text(action, obj, control["title"])

    point_payload = {
        "control_uuid": control["uuid"],
        "control_id": control["control_id"],
        "title": control["title"],
        "pattern_id": control.get("pattern_id"),
        "action_normalized": norm_action,
        "object_normalized": norm_object,
        "canonical_text": canonical,
        "merge_group_hint": control["merge_group_hint"],
    }
    await qdrant_upsert(
        point_id=control["uuid"],
        embedding=embedding,
        payload=point_payload,
        collection=self.collection,
    )
|
||||
|
||||
# ── DB Write Helpers ─────────────────────────────────────────────────
|
||||
|
||||
async def _mark_duplicate(self, master: dict, candidate: dict, confidence: float):
    """Flag ``candidate`` as a duplicate of ``master`` and commit.

    In one transaction: set the candidate's release_state to 'duplicate'
    with ``merged_into_uuid`` pointing at the master, insert a
    'dedup_merge' link (master → candidate) and copy the candidate's
    parent links to the master.

    Args:
        master: Control dict of the surviving control (needs ``uuid``).
        candidate: Control dict of the duplicate (needs ``uuid``).
        confidence: Confidence stored on the 'dedup_merge' link.

    Raises:
        Exception: Any error is logged, the transaction is rolled back and
            the error is re-raised.
    """
    flag_sql = text("""
        UPDATE canonical_controls
        SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
        WHERE id = CAST(:cand AS uuid)
    """)
    link_sql = text("""
        INSERT INTO control_parent_links
            (control_uuid, parent_control_uuid, link_type, confidence)
        VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
        ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
    """)
    try:
        self.db.execute(
            flag_sql,
            {"master": master["uuid"], "cand": candidate["uuid"]},
        )
        self.db.execute(
            link_sql,
            {
                "master": master["uuid"],
                "cand_parent": candidate["uuid"],
                "conf": confidence,
            },
        )

        moved = self._transfer_parent_links(master["uuid"], candidate["uuid"])
        self.stats["parent_links_transferred"] += moved

        self.db.commit()
    except Exception as e:
        logger.error("BatchDedup _mark_duplicate error %s→%s: %s",
                     candidate["uuid"], master["uuid"], e)
        self.db.rollback()
        raise
|
||||
|
||||
async def _mark_duplicate_to(self, master_uuid: str, candidate: dict, confidence: float):
    """Flag ``candidate`` as a duplicate of the control with ``master_uuid``.

    Same writes as ``_mark_duplicate``, but the master is given by uuid
    only (as returned in a Qdrant payload): the candidate row is set to
    'duplicate', a 'dedup_merge' link is inserted and the candidate's
    parent links are copied to the master, then everything is committed.

    Args:
        master_uuid: UUID (as text) of the surviving control.
        candidate: Control dict of the duplicate (needs ``uuid``).
        confidence: Confidence stored on the 'dedup_merge' link.

    Raises:
        Exception: Any error is logged, the transaction is rolled back and
            the error is re-raised.
    """
    try:
        update_params = {"master": master_uuid, "cand": candidate["uuid"]}
        self.db.execute(text("""
            UPDATE canonical_controls
            SET release_state = 'duplicate', merged_into_uuid = CAST(:master AS uuid)
            WHERE id = CAST(:cand AS uuid)
        """), update_params)

        link_params = {
            "master": master_uuid,
            "cand_parent": candidate["uuid"],
            "conf": confidence,
        }
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence)
            VALUES (CAST(:master AS uuid), CAST(:cand_parent AS uuid), 'dedup_merge', :conf)
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), link_params)

        copied = self._transfer_parent_links(master_uuid, candidate["uuid"])
        self.stats["parent_links_transferred"] += copied

        self.db.commit()
    except Exception as e:
        logger.error("BatchDedup _mark_duplicate_to error %s→%s: %s",
                     candidate["uuid"], master_uuid, e)
        self.db.rollback()
        raise
|
||||
|
||||
def _transfer_parent_links(self, master_uuid: str, duplicate_uuid: str) -> int:
    """Copy the duplicate's 'decomposition' parent links onto the master.

    The duplicate's own rows are left in place; for each of them an
    equivalent row is inserted for the master (existing master links win
    via ``ON CONFLICT DO NOTHING``). A link whose parent is the master
    itself is skipped so the master never becomes its own parent.

    Nothing is committed here; the caller owns the transaction.

    Args:
        master_uuid: UUID (as text) of the control receiving the links.
        duplicate_uuid: UUID (as text) of the control whose links are read.

    Returns:
        Number of insert statements issued. This can exceed the number of
        rows actually created, because conflicting inserts are no-ops.
    """
    rows = self.db.execute(text("""
        SELECT parent_control_uuid::text, link_type, confidence,
               source_regulation, source_article, obligation_candidate_id::text
        FROM control_parent_links
        WHERE control_uuid = CAST(:dup AS uuid)
          AND link_type = 'decomposition'
    """), {"dup": duplicate_uuid}).fetchall()

    transferred = 0
    for r in rows:
        parent_uuid = r[0]
        if parent_uuid == master_uuid:
            continue
        confidence = r[2]
        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence,
                 source_regulation, source_article, obligation_candidate_id)
            VALUES (CAST(:cu AS uuid), CAST(:pu AS uuid), :lt, :conf,
                    :sr, :sa, CAST(:oci AS uuid))
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {
            "cu": master_uuid,
            "pu": parent_uuid,
            "lt": r[1],
            # Only a missing confidence defaults to 1.0; a stored 0.0 must
            # stay 0.0 (a plain truthiness test would turn it into 1.0).
            "conf": float(confidence) if confidence is not None else 1.0,
            "sr": r[3],
            "sa": r[4],
            "oci": r[5],
        })
        transferred += 1

    return transferred
|
||||
|
||||
def _write_review(self, candidate: dict, matched_payload: dict, score: float):
    """Persist a borderline match in ``control_dedup_reviews`` and commit.

    Args:
        candidate: Control dict; ``control_id`` and ``title`` are required,
            ``objective``, ``merge_group_hint`` and ``pattern_id`` are optional.
        matched_payload: Qdrant payload of the matched control
            (``control_uuid`` and ``control_id`` are read from it).
        score: Similarity score of the match.

    Raises:
        Exception: Any error is logged, the transaction is rolled back and
            the error is re-raised.
    """
    try:
        details = json.dumps({
            "merge_group_hint": candidate.get("merge_group_hint", ""),
            "pattern_id": candidate.get("pattern_id"),
        })
        review_row = {
            "ccid": candidate["control_id"],
            "ct": candidate["title"],
            "co": candidate.get("objective", ""),
            "mcu": matched_payload.get("control_uuid"),
            "mci": matched_payload.get("control_id"),
            "ss": score,
            "dd": details,
        }
        self.db.execute(text("""
            INSERT INTO control_dedup_reviews
                (candidate_control_id, candidate_title, candidate_objective,
                 matched_control_uuid, matched_control_id,
                 similarity_score, dedup_stage, dedup_details)
            VALUES (:ccid, :ct, :co, CAST(:mcu AS uuid), :mci,
                    :ss, 'batch_dedup', CAST(:dd AS jsonb))
        """), review_row)
        self.db.commit()
    except Exception as e:
        logger.error("BatchDedup _write_review error: %s", e)
        self.db.rollback()
        raise
|
||||
|
||||
# ── Progress ─────────────────────────────────────────────────────────
|
||||
|
||||
def _log_progress(self, hint: str):
|
||||
"""Log progress every 500 controls."""
|
||||
if self._progress_count > 0 and self._progress_count % 500 == 0:
|
||||
logger.info(
|
||||
"BatchDedup [%s] %d/%d — masters=%d, linked=%d, review=%d",
|
||||
self._progress_phase, self._progress_count, self._progress_total,
|
||||
self.stats["masters"], self.stats["linked"], self.stats["review"],
|
||||
)
|
||||
|
||||
def get_status(self) -> dict:
    """Snapshot the current phase, progress counters and all stats.

    Returns:
        A new dict with ``phase``, ``progress`` and ``total`` followed by
        every entry of ``self.stats`` (a stats key of the same name
        overrides the progress field).
    """
    status = {
        "phase": self._progress_phase,
        "progress": self._progress_count,
        "total": self._progress_total,
    }
    status.update(self.stats)
    return status
|
||||
438
control-pipeline/services/citation_backfill.py
Normal file
438
control-pipeline/services/citation_backfill.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
Citation Backfill Service — enrich existing controls with article/paragraph provenance.
|
||||
|
||||
3-tier matching strategy:
|
||||
Tier 1 — Hash match: sha256(source_original_text) → RAG chunk lookup
|
||||
Tier 2 — Regex parse: split concatenated "DSGVO Art. 35" → regulation + article
|
||||
Tier 3 — Ollama LLM: ask local LLM to identify article/paragraph from text
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .rag_client import ComplianceRAGClient, RAGSearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
|
||||
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
|
||||
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))
|
||||
|
||||
ALL_COLLECTIONS = [
|
||||
"bp_compliance_ce",
|
||||
"bp_compliance_gesetze",
|
||||
"bp_compliance_datenschutz",
|
||||
"bp_dsfa_corpus",
|
||||
"bp_legal_templates",
|
||||
]
|
||||
|
||||
BACKFILL_SYSTEM_PROMPT = (
|
||||
"Du bist ein Rechtsexperte. Deine Aufgabe ist es, aus einem Gesetzestext "
|
||||
"den genauen Artikel und Absatz zu bestimmen. Antworte NUR mit validem JSON."
|
||||
)
|
||||
|
||||
# Regex to split concatenated source like "DSGVO Art. 35" or "NIS2 Artikel 21 Abs. 2"
|
||||
_SOURCE_ARTICLE_RE = re.compile(
|
||||
r"^(.+?)\s+(Art(?:ikel)?\.?\s*\d+.*)$", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MatchResult:
    """Article/paragraph provenance found for a single control."""

    # Article reference such as "Art. 35" or "§ 42"; may be "".
    article: str
    # Paragraph reference such as "Abs. 2"; "" when none was determined.
    paragraph: str
    # Tier that produced the match.
    method: str  # "hash", "regex", "llm"
|
||||
|
||||
|
||||
@dataclass
class BackfillResult:
    """Counters and error messages collected over one backfill run."""

    # Number of controls loaded for backfilling.
    total_controls: int = 0
    # Matches per tier (MatchResult.method == "hash" / "regex" / "llm").
    matched_hash: int = 0
    matched_regex: int = 0
    matched_llm: int = 0
    # Controls for which no tier produced a match.
    unmatched: int = 0
    # Controls for which an UPDATE was issued (stays 0 in dry-run mode).
    updated: int = 0
    # Human-readable messages for per-control and commit failures.
    errors: list = field(default_factory=list)
|
||||
|
||||
|
||||
class CitationBackfill:
    """Backfill article/paragraph into existing control source_citations.

    Each control is matched in up to three tiers (see ``_match_control``):
    a hash lookup against scrolled RAG chunks, a regex split of a
    concatenated source string, and finally a local Ollama LLM.
    """

    def __init__(self, db: "Session", rag_client: "ComplianceRAGClient"):
        """Store the DB session and RAG client; the hash index starts empty."""
        self.db = db
        self.rag = rag_client
        # sha256(chunk text) -> RAG chunk, filled by _build_rag_index_targeted().
        self._rag_index: dict[str, RAGSearchResult] = {}

    async def run(self, dry_run: bool = True, limit: int = 0) -> "BackfillResult":
        """Main entry: iterate controls missing article/paragraph, match to RAG, update.

        Args:
            dry_run: When true (default) matches are only counted and logged;
                nothing is written.
            limit: Maximum number of controls to load; 0 or less means no limit.

        Returns:
            A ``BackfillResult`` with per-tier counters and collected errors.
        """
        result = BackfillResult()

        # Load controls needing backfill
        controls = self._load_controls_needing_backfill(limit)
        result.total_controls = len(controls)
        logger.info("Backfill: %d controls need article/paragraph enrichment", len(controls))

        if not controls:
            return result

        # Hashes of every source text we will look up in tier 1.
        needed_hashes: set[str] = {
            hashlib.sha256(src.encode()).hexdigest()
            for src in (ctrl.get("source_original_text") for ctrl in controls)
            if src
        }

        if needed_hashes:
            # Build targeted RAG index — only scroll collections that our controls reference
            logger.info("Building targeted RAG hash index for %d source texts...", len(needed_hashes))
            await self._build_rag_index_targeted(controls, needed_hashes)
            logger.info("RAG index built: %d chunks indexed, %d hashes needed", len(self._rag_index), len(needed_hashes))
        else:
            logger.info("No source_original_text found — skipping RAG index build")

        # Process each control
        for i, ctrl in enumerate(controls):
            if i > 0 and i % 100 == 0:
                logger.info("Backfill progress: %d/%d processed", i, result.total_controls)

            try:
                match = await self._match_control(ctrl)
                if match:
                    if match.method == "hash":
                        result.matched_hash += 1
                    elif match.method == "regex":
                        result.matched_regex += 1
                    elif match.method == "llm":
                        result.matched_llm += 1

                    if not dry_run:
                        # One SAVEPOINT per update: a failing UPDATE is rolled
                        # back on its own instead of leaving the shared
                        # transaction aborted for every later control.
                        with self.db.begin_nested():
                            self._update_control(ctrl, match)
                        result.updated += 1
                    else:
                        logger.debug(
                            "DRY RUN: Would update %s with article=%s paragraph=%s (method=%s)",
                            ctrl["control_id"], match.article, match.paragraph, match.method,
                        )
                else:
                    result.unmatched += 1

            except Exception as e:
                error_msg = f"Error backfilling {ctrl.get('control_id', '?')}: {e}"
                logger.error(error_msg)
                result.errors.append(error_msg)

        if not dry_run:
            try:
                self.db.commit()
            except Exception as e:
                logger.error("Backfill commit failed: %s", e)
                result.errors.append(f"Commit failed: {e}")
                # Leave the session usable for the caller after a failed commit.
                self.db.rollback()

        logger.info(
            "Backfill complete: %d total, hash=%d regex=%d llm=%d unmatched=%d updated=%d",
            result.total_controls, result.matched_hash, result.matched_regex,
            result.matched_llm, result.unmatched, result.updated,
        )
        return result

    def _load_controls_needing_backfill(self, limit: int = 0) -> list[dict]:
        """Load controls where source_citation exists but lacks separate 'article' key.

        Only controls with ``license_rule`` 1 or 2 are considered.

        Args:
            limit: Maximum number of rows; 0 or less loads all.

        Returns:
            One dict per row with ``id`` as text and the JSON columns
            ``source_citation`` / ``generation_metadata`` decoded when the
            driver returned them as strings (undecodable JSON becomes ``{}``).
        """
        query = """
            SELECT id, control_id, source_citation, source_original_text,
                   generation_metadata, license_rule
            FROM canonical_controls
            WHERE license_rule IN (1, 2)
              AND source_citation IS NOT NULL
              AND (
                  source_citation->>'article' IS NULL
                  OR source_citation->>'article' = ''
              )
            ORDER BY control_id
        """
        params: dict = {}
        if limit > 0:
            # Bound parameter instead of string interpolation.
            query += " LIMIT :limit"
            params["limit"] = int(limit)

        result = self.db.execute(text(query), params)
        cols = result.keys()
        controls = []
        for row in result:
            ctrl = dict(zip(cols, row))
            ctrl["id"] = str(ctrl["id"])
            # Parse JSON fields
            for jf in ("source_citation", "generation_metadata"):
                if isinstance(ctrl.get(jf), str):
                    try:
                        ctrl[jf] = json.loads(ctrl[jf])
                    except (json.JSONDecodeError, TypeError):
                        ctrl[jf] = {}
            controls.append(ctrl)
        return controls

    async def _build_rag_index_targeted(
        self,
        controls: list[dict],
        needed_hashes: Optional[set[str]] = None,
    ):
        """Build RAG index by scrolling only collections relevant to our controls.

        Uses regulation codes from generation_metadata to identify which collections
        to search, falling back to all collections only if needed.

        Args:
            controls: Controls whose regulations select the collections.
            needed_hashes: When given, only chunks whose sha256 is in this set
                are kept, so the index holds just the chunks that can match
                instead of every scrolled chunk. ``None`` keeps all chunks.
        """
        # Determine which collections are relevant based on regulation codes
        regulation_to_collection = self._map_regulations_to_collections(controls)
        # Sorted so that the scroll order (and which chunk wins when the same
        # text occurs in two collections) does not depend on set ordering.
        collections_to_search = sorted(
            set(regulation_to_collection.values()) or set(ALL_COLLECTIONS)
        )

        logger.info("Targeted index: searching %d collections: %s",
                    len(collections_to_search), ", ".join(collections_to_search))

        for collection in collections_to_search:
            offset = None
            page = 0
            seen_offsets: set[str] = set()
            while True:
                chunks, next_offset = await self.rag.scroll(
                    collection=collection, offset=offset, limit=200,
                )
                if not chunks:
                    break
                for chunk in chunks:
                    # Very short chunks are ignored.
                    if not chunk.text or len(chunk.text.strip()) < 50:
                        continue
                    h = hashlib.sha256(chunk.text.encode()).hexdigest()
                    if needed_hashes is None or h in needed_hashes:
                        self._rag_index[h] = chunk
                page += 1
                if page % 50 == 0:
                    logger.info("Indexing %s: page %d (%d chunks so far)",
                                collection, page, len(self._rag_index))
                if not next_offset:
                    break
                # A repeated offset means the scroll is cycling.
                if next_offset in seen_offsets:
                    logger.warning("Scroll loop in %s at page %d — stopping", collection, page)
                    break
                seen_offsets.add(next_offset)
                offset = next_offset

            logger.info("Indexed collection %s: %d pages", collection, page)

    def _map_regulations_to_collections(self, controls: list[dict]) -> dict[str, str]:
        """Map regulation codes from controls to likely Qdrant collections.

        The code is read from ``generation_metadata["source_regulation"]`` and
        matched by prefix. A regulation with no known prefix adds every entry
        of ``ALL_COLLECTIONS`` under synthetic ``_all_<collection>`` keys.

        Returns:
            Dict of regulation code (or synthetic key) to collection name.
        """
        # Heuristic: regulation code prefix → collection
        collection_map = {
            "eu_": "bp_compliance_gesetze",
            "dsgvo": "bp_compliance_datenschutz",
            "bdsg": "bp_compliance_gesetze",
            "ttdsg": "bp_compliance_gesetze",
            "nist_": "bp_compliance_ce",
            "owasp": "bp_compliance_ce",
            "bsi_": "bp_compliance_ce",
            "enisa": "bp_compliance_ce",
            "at_": "bp_compliance_recht",
            "fr_": "bp_compliance_recht",
            "es_": "bp_compliance_recht",
        }
        result: dict[str, str] = {}
        for ctrl in controls:
            meta = ctrl.get("generation_metadata") or {}
            reg = meta.get("source_regulation", "")
            if not reg:
                continue
            for prefix, coll in collection_map.items():
                if reg.startswith(prefix):
                    result[reg] = coll
                    break
            else:
                # Unknown regulation — search all
                for coll in ALL_COLLECTIONS:
                    result[f"_all_{coll}"] = coll
        return result

    async def _match_control(self, ctrl: dict) -> "Optional[MatchResult]":
        """3-tier matching: hash → regex → LLM.

        Returns:
            The first match found, or ``None`` when no tier succeeds. The LLM
            tier is only tried for controls that have a source text.
        """
        # Tier 1: Hash match against RAG index
        source_text = ctrl.get("source_original_text")
        if source_text:
            h = hashlib.sha256(source_text.encode()).hexdigest()
            chunk = self._rag_index.get(h)
            if chunk and (chunk.article or chunk.paragraph):
                return MatchResult(
                    article=chunk.article or "",
                    paragraph=chunk.paragraph or "",
                    method="hash",
                )

        # Tier 2: Regex parse concatenated source
        citation = ctrl.get("source_citation") or {}
        source_str = citation.get("source", "")
        parsed = _parse_concatenated_source(source_str)
        if parsed and parsed["article"]:
            return MatchResult(
                article=parsed["article"],
                paragraph="",  # Regex can't extract paragraph from concatenated format
                method="regex",
            )

        # Tier 3: Ollama LLM
        if source_text:
            return await self._llm_match(ctrl)

        return None

    async def _llm_match(self, ctrl: dict) -> "Optional[MatchResult]":
        """Use Ollama to identify article/paragraph from source text.

        At most the first 2000 characters of the source text are sent.
        ``null`` or non-string values in the answer are coerced to strings so
        a ``MatchResult`` never carries ``None``.

        Returns:
            A match with method "llm" when the answer names an article or a
            paragraph; ``None`` otherwise or when the call fails.
        """
        citation = ctrl.get("source_citation") or {}
        regulation_name = citation.get("source", "")
        metadata = ctrl.get("generation_metadata") or {}
        regulation_code = metadata.get("source_regulation", "")
        source_text = ctrl.get("source_original_text") or ""

        prompt = f"""Analysiere den folgenden Gesetzestext und bestimme den genauen Artikel und Absatz.

Gesetz: {regulation_name} (Code: {regulation_code})

Text:
---
{source_text[:2000]}
---

Antworte NUR mit JSON:
{{"article": "Art. XX", "paragraph": "Abs. Y"}}

Falls kein spezifischer Absatz erkennbar ist, setze paragraph auf "".
Falls kein Artikel erkennbar ist, setze article auf "".
Bei deutschen Gesetzen mit § verwende: "§ XX" statt "Art. XX"."""

        try:
            raw = await _llm_ollama(prompt, BACKFILL_SYSTEM_PROMPT)
            data = _parse_json(raw)
            if isinstance(data, dict):
                article = str(data.get("article") or "").strip()
                paragraph = str(data.get("paragraph") or "").strip()
                if article or paragraph:
                    return MatchResult(
                        article=article,
                        paragraph=paragraph,
                        method="llm",
                    )
        except Exception as e:
            logger.warning("LLM match failed for %s: %s", ctrl.get("control_id"), e)

        return None

    def _update_control(self, ctrl: dict, match: "MatchResult"):
        """Update source_citation and generation_metadata in DB.

        The citation gets separate ``article`` / ``paragraph`` keys, and a
        concatenated ``source`` ("DSGVO Art. 35") is reduced to the regulation
        name. The metadata records how and when the backfill happened. The
        statement is executed but not committed here.
        """
        citation = ctrl.get("source_citation") or {}

        # Clean the source name: remove concatenated article if present
        source_str = citation.get("source", "")
        parsed = _parse_concatenated_source(source_str)
        if parsed:
            citation["source"] = parsed["name"]

        # Add separate article/paragraph fields
        citation["article"] = match.article
        citation["paragraph"] = match.paragraph

        # Update generation_metadata
        metadata = ctrl.get("generation_metadata") or {}
        if match.article:
            metadata["source_article"] = match.article
            metadata["source_paragraph"] = match.paragraph
        metadata["backfill_method"] = match.method
        metadata["backfill_at"] = datetime.now(timezone.utc).isoformat()

        self.db.execute(
            text("""
                UPDATE canonical_controls
                SET source_citation = :citation,
                    generation_metadata = :metadata,
                    updated_at = NOW()
                WHERE id = CAST(:id AS uuid)
            """),
            {
                "id": ctrl["id"],
                "citation": json.dumps(citation),
                "metadata": json.dumps(metadata),
            },
        )
|
||||
|
||||
|
||||
def _parse_concatenated_source(source: str) -> Optional[dict]:
|
||||
"""Parse 'DSGVO Art. 35' → {name: 'DSGVO', article: 'Art. 35'}.
|
||||
|
||||
Also handles '§' format: 'BDSG § 42' → {name: 'BDSG', article: '§ 42'}.
|
||||
"""
|
||||
if not source:
|
||||
return None
|
||||
|
||||
# Try Art./Artikel pattern
|
||||
m = _SOURCE_ARTICLE_RE.match(source)
|
||||
if m:
|
||||
return {"name": m.group(1).strip(), "article": m.group(2).strip()}
|
||||
|
||||
# Try § pattern
|
||||
m2 = re.match(r"^(.+?)\s+(§\s*\d+.*)$", source)
|
||||
if m2:
|
||||
return {"name": m2.group(1).strip(), "article": m2.group(2).strip()}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Send one non-streaming chat request to Ollama and return the answer text.

    Args:
        prompt: User message.
        system_prompt: Optional system message placed before the user message.

    Returns:
        The ``message.content`` of the response (or the ``response`` field
        when ``message`` is not an object). An empty string is returned on a
        non-200 status or when the request raises; failures are logged.
    """
    system_part = [{"role": "system", "content": system_prompt}] if system_prompt else []
    request_body = {
        "model": OLLAMA_MODEL,
        "messages": [*system_part, {"role": "user", "content": prompt}],
        "stream": False,
        "format": "json",
        "options": {"num_predict": 256},
        "think": False,
    }

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=request_body)
            if resp.status_code == 200:
                data = resp.json()
                msg = data.get("message", {})
                if not isinstance(msg, dict):
                    return data.get("response", str(msg))
                return msg.get("content", "")
            logger.error("Ollama backfill failed %d: %s", resp.status_code, resp.text[:300])
            return ""
    except Exception as e:
        logger.error("Ollama backfill request failed: %s", e)
        return ""
|
||||
|
||||
|
||||
def _parse_json(raw: str) -> Optional[dict]:
|
||||
"""Extract JSON object from LLM output."""
|
||||
if not raw:
|
||||
return None
|
||||
# Try direct parse
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Try extracting from markdown code block
|
||||
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
|
||||
if m:
|
||||
try:
|
||||
return json.loads(m.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Try finding first { ... }
|
||||
m = re.search(r"\{[^{}]*\}", raw)
|
||||
if m:
|
||||
try:
|
||||
return json.loads(m.group(0))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return None
|
||||
546
control-pipeline/services/control_composer.py
Normal file
546
control-pipeline/services/control_composer.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Control Composer — Pattern + Obligation → Master Control.
|
||||
|
||||
Takes an obligation (from ObligationExtractor) and a matched control pattern
|
||||
(from PatternMatcher), then uses LLM to compose a structured, actionable
|
||||
Master Control. Replaces the old Stage 3 (STRUCTURE/REFORM) with a
|
||||
pattern-guided approach.
|
||||
|
||||
Three composition modes based on license rules:
|
||||
Rule 1: Obligation + Pattern + original text → full control
|
||||
Rule 2: Obligation + Pattern + original text + citation → control
|
||||
Rule 3: Obligation + Pattern (NO original text) → reformulated control
|
||||
|
||||
Fallback: No pattern match → basic generation (tagged needs_pattern_assignment)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 6 of 8).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from services.obligation_extractor import (
|
||||
ObligationMatch,
|
||||
_llm_ollama,
|
||||
_parse_json,
|
||||
)
|
||||
from services.pattern_matcher import (
|
||||
ControlPattern,
|
||||
PatternMatchResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)

# Model name, overridable through the CONTROL_GEN_OLLAMA_MODEL env variable.
OLLAMA_MODEL = os.environ.get("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")

# Allowed values for generated control fields; _validate_control() falls back
# to a default when a field holds anything else.
VALID_SEVERITIES = {
    "low",
    "medium",
    "high",
    "critical",
}
VALID_EFFORTS = {
    "s",
    "m",
    "l",
    "xl",
}
VALID_VERIFICATION = {
    "code_review",
    "document",
    "tool",
    "hybrid",
}
|
||||
|
||||
|
||||
@dataclass
class ComposedControl:
    """Master Control produced by composing an obligation with a pattern.

    All fields have defaults, so an empty instance can be created and filled
    step by step. ``to_dict()`` exposes every field under its own name.
    """

    # --- Core fields (match canonical_controls schema) ---
    control_id: str = ""
    title: str = ""
    objective: str = ""
    rationale: str = ""
    scope: dict = field(default_factory=dict)
    requirements: list = field(default_factory=list)
    test_procedure: list = field(default_factory=list)
    evidence: list = field(default_factory=list)
    severity: str = "medium"
    risk_score: float = 5.0
    implementation_effort: str = "m"
    open_anchors: list = field(default_factory=list)
    release_state: str = "draft"
    tags: list = field(default_factory=list)

    # --- 3-Rule License fields ---
    license_rule: Optional[int] = None
    source_original_text: Optional[str] = None
    source_citation: Optional[dict] = None
    customer_visible: bool = True

    # --- Classification ---
    verification_method: Optional[str] = None
    category: Optional[str] = None
    target_audience: Optional[list] = None

    # --- Pattern + Obligation linkage ---
    pattern_id: Optional[str] = None
    obligation_ids: list = field(default_factory=list)

    # --- Metadata ---
    generation_metadata: dict = field(default_factory=dict)
    composition_method: str = "pattern_guided"  # pattern_guided | template_only | fallback

    def to_dict(self) -> dict:
        """Return a flat dict of all fields for DB storage or an API response.

        Values are passed through as-is (no copying), in field order.
        """
        exported = (
            "control_id",
            "title",
            "objective",
            "rationale",
            "scope",
            "requirements",
            "test_procedure",
            "evidence",
            "severity",
            "risk_score",
            "implementation_effort",
            "open_anchors",
            "release_state",
            "tags",
            "license_rule",
            "source_original_text",
            "source_citation",
            "customer_visible",
            "verification_method",
            "category",
            "target_audience",
            "pattern_id",
            "obligation_ids",
            "generation_metadata",
            "composition_method",
        )
        return {name: getattr(self, name) for name in exported}
|
||||
|
||||
|
||||
class ControlComposer:
    """Composes Master Controls from obligations + patterns.

    Usage::

        composer = ControlComposer()

        control = await composer.compose(
            obligation=obligation_match,
            pattern_result=pattern_match_result,
            chunk_text="...",
            license_rule=1,
            source_citation={...},
        )
    """

    async def compose(
        self,
        obligation: ObligationMatch,
        pattern_result: PatternMatchResult,
        chunk_text: Optional[str] = None,
        license_rule: int = 3,
        source_citation: Optional[dict] = None,
        regulation_code: Optional[str] = None,
    ) -> ComposedControl:
        """Compose a Master Control from obligation + pattern.

        Args:
            obligation: The extracted obligation (from ObligationExtractor).
            pattern_result: The matched pattern (from PatternMatcher). When it
                is falsy or carries no pattern, the fallback path is used.
            chunk_text: Original RAG chunk text (only used for Rules 1-2).
            license_rule: 1=free, 2=citation, 3=restricted.
            source_citation: Citation metadata for Rule 2.
            regulation_code: Source regulation code (stored in metadata only).

        Returns:
            ComposedControl ready for storage.
        """
        pattern = pattern_result.pattern if pattern_result else None

        if pattern:
            control = await self._compose_with_pattern(
                obligation, pattern, chunk_text, license_rule, source_citation,
            )
        else:
            control = await self._compose_fallback(
                obligation, chunk_text, license_rule, source_citation,
            )

        # Set linkage fields
        control.pattern_id = pattern.id if pattern else None
        if obligation.obligation_id:
            control.obligation_ids = [obligation.obligation_id]

        # Set license fields
        control.license_rule = license_rule
        if license_rule in (1, 2) and chunk_text:
            control.source_original_text = chunk_text
        if license_rule == 2 and source_citation:
            control.source_citation = source_citation
        if license_rule == 3:
            # Restricted sources: hide from customers and drop source traces.
            control.customer_visible = False
            control.source_original_text = None
            control.source_citation = None

        # Build metadata
        control.generation_metadata = {
            "composition_method": control.composition_method,
            "pattern_id": control.pattern_id,
            "pattern_confidence": round(pattern_result.confidence, 3) if pattern_result else 0,
            "pattern_method": pattern_result.method if pattern_result else "none",
            "obligation_id": obligation.obligation_id,
            "obligation_method": obligation.method,
            "obligation_confidence": round(obligation.confidence, 3),
            "license_rule": license_rule,
            "regulation_code": regulation_code,
        }

        # Validate and fix fields
        _validate_control(control)

        return control

    async def compose_batch(
        self,
        items: list[dict],
    ) -> list[ComposedControl]:
        """Compose multiple controls, one after another.

        Args:
            items: List of dicts with keys: obligation, pattern_result,
                chunk_text, license_rule, source_citation, regulation_code.
                Only ``obligation`` is required.

        Returns:
            List of ComposedControl instances, in input order.
        """
        results = []
        for item in items:
            control = await self.compose(
                obligation=item["obligation"],
                pattern_result=item.get("pattern_result", PatternMatchResult()),
                chunk_text=item.get("chunk_text"),
                license_rule=item.get("license_rule", 3),
                source_citation=item.get("source_citation"),
                regulation_code=item.get("regulation_code"),
            )
            results.append(control)
        return results

    # -----------------------------------------------------------------------
    # Pattern-guided composition
    # -----------------------------------------------------------------------

    async def _compose_with_pattern(
        self,
        obligation: ObligationMatch,
        pattern: ControlPattern,
        chunk_text: Optional[str],
        license_rule: int,
        source_citation: Optional[dict],
    ) -> ComposedControl:
        """Use LLM to fill the pattern template with obligation-specific details.

        Falls back to ``_compose_from_template`` when the LLM returns nothing
        or its output does not contain a usable JSON object.
        """
        prompt = _build_compose_prompt(obligation, pattern, chunk_text, license_rule)
        system_prompt = _compose_system_prompt(license_rule)

        llm_result = await _llm_ollama(prompt, system_prompt)
        parsed = _parse_json(llm_result) if llm_result else None
        # Only a non-empty JSON object can be read with .get() below.
        if not parsed or not isinstance(parsed, dict):
            return self._compose_from_template(obligation, pattern)

        # A null or non-string title from the LLM must not break the slicing.
        title = str(parsed.get("title") or pattern.name_de)[:255]

        control = ComposedControl(
            title=title,
            objective=parsed.get("objective", pattern.objective_template),
            rationale=parsed.get("rationale", pattern.rationale_template),
            requirements=_ensure_list(parsed.get("requirements", pattern.requirements_template)),
            test_procedure=_ensure_list(parsed.get("test_procedure", pattern.test_procedure_template)),
            evidence=_ensure_list(parsed.get("evidence", pattern.evidence_template)),
            severity=parsed.get("severity", pattern.severity_default),
            implementation_effort=parsed.get("implementation_effort", pattern.implementation_effort_default),
            category=parsed.get("category", pattern.category),
            tags=_ensure_list(parsed.get("tags", pattern.tags)),
            target_audience=_ensure_list(parsed.get("target_audience", [])),
            verification_method=parsed.get("verification_method"),
            open_anchors=_anchors_from_pattern(pattern),
            composition_method="pattern_guided",
        )

        return control

    def _compose_from_template(
        self,
        obligation: ObligationMatch,
        pattern: ControlPattern,
    ) -> ComposedControl:
        """Fallback: fill template directly without LLM (when LLM fails)."""
        obl_title = obligation.obligation_title or ""
        obl_text = obligation.obligation_text or ""

        title = f"{pattern.name_de}"
        if obl_title:
            title = f"{pattern.name_de} — {obl_title}"

        objective = pattern.objective_template
        # Very short obligation texts are not appended to the objective.
        if obl_text and len(obl_text) > 20:
            objective = f"{pattern.objective_template} Bezug: {obl_text[:200]}"

        return ComposedControl(
            title=title[:255],
            objective=objective,
            rationale=pattern.rationale_template,
            # list(...) copies so the control never shares lists with the pattern
            requirements=list(pattern.requirements_template),
            test_procedure=list(pattern.test_procedure_template),
            evidence=list(pattern.evidence_template),
            severity=pattern.severity_default,
            implementation_effort=pattern.implementation_effort_default,
            category=pattern.category,
            tags=list(pattern.tags),
            open_anchors=_anchors_from_pattern(pattern),
            composition_method="template_only",
        )

    # -----------------------------------------------------------------------
    # Fallback (no pattern)
    # -----------------------------------------------------------------------

    async def _compose_fallback(
        self,
        obligation: ObligationMatch,
        chunk_text: Optional[str],
        license_rule: int,
        source_citation: Optional[dict],
    ) -> ComposedControl:
        """Generate a control without a pattern template (old-style).

        If the LLM yields no usable JSON object, the control is built from the
        obligation text and static defaults; ``_validate_control`` (called by
        ``compose``) fills the remaining minimum content.
        """
        prompt = _build_fallback_prompt(obligation, chunk_text, license_rule)
        system_prompt = _compose_system_prompt(license_rule)

        llm_result = await _llm_ollama(prompt, system_prompt)
        parsed = _parse_json(llm_result) if llm_result else None
        # _parse_json returns None for unparseable output; treat that (and any
        # non-object JSON) as "no LLM data" instead of crashing on .get().
        if not isinstance(parsed, dict):
            parsed = {}

        obl_text = obligation.obligation_text or ""
        default_title = obl_text[:100] if obl_text else "Untitled Control"

        control = ComposedControl(
            title=str(parsed.get("title") or default_title)[:255],
            objective=parsed.get("objective", obl_text[:500]),
            rationale=parsed.get("rationale", "Aus gesetzlicher Pflicht abgeleitet."),
            requirements=_ensure_list(parsed.get("requirements", [])),
            test_procedure=_ensure_list(parsed.get("test_procedure", [])),
            evidence=_ensure_list(parsed.get("evidence", [])),
            severity=parsed.get("severity", "medium"),
            implementation_effort=parsed.get("implementation_effort", "m"),
            category=parsed.get("category"),
            tags=_ensure_list(parsed.get("tags", [])),
            target_audience=_ensure_list(parsed.get("target_audience", [])),
            verification_method=parsed.get("verification_method"),
            composition_method="fallback",
            release_state="needs_review",
        )

        return control
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compose_system_prompt(license_rule: int) -> str:
|
||||
"""Build the system prompt based on license rule."""
|
||||
if license_rule == 3:
|
||||
return (
|
||||
"Du bist ein Security-Compliance-Experte. Deine Aufgabe ist es, "
|
||||
"eigenstaendige Security Controls zu formulieren. "
|
||||
"Du formulierst IMMER in eigenen Worten. "
|
||||
"KOPIERE KEINE Saetze aus dem Quelltext. "
|
||||
"Verwende eigene Begriffe und Struktur. "
|
||||
"NENNE NICHT die Quelle. Keine proprietaeren Bezeichner. "
|
||||
"Antworte NUR mit validem JSON."
|
||||
)
|
||||
return (
|
||||
"Du bist ein Security-Compliance-Experte. "
|
||||
"Erstelle ein praxisorientiertes, umsetzbares Security Control. "
|
||||
"Antworte NUR mit validem JSON."
|
||||
)
|
||||
|
||||
|
||||
def _build_compose_prompt(
    obligation: ObligationMatch,
    pattern: ControlPattern,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Assemble the user prompt for pattern-guided composition.

    The prompt consists of the obligation section, the pattern section, a
    context section and the task description with the expected JSON shape.
    The original chunk text (first 2000 characters) is only embedded when the
    license rule is not 3.
    """
    obl_section = _obligation_section(obligation)
    pattern_section = _pattern_section(pattern)

    # Rule 3 never exposes the source text, even if a chunk is available.
    if license_rule != 3 and chunk_text:
        context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"
    elif license_rule == 3:
        context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    else:
        context_section = "KONTEXT: Kein Originaltext verfuegbar."

    return f"""Erstelle ein PRAXISORIENTIERTES Security Control.

{obl_section}

{pattern_section}

{context_section}

AUFGABE:
Fuelle das Muster mit pflicht-spezifischen Details.
Das Ergebnis muss UMSETZBAR sein — keine Gesetzesparaphrase.
Formuliere konkret und handlungsorientiert.

Antworte als JSON:
{{
"title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
"objective": "Was soll erreicht werden? (1-3 Saetze)",
"rationale": "Warum ist das wichtig? (1-2 Saetze)",
"requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
"test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
"evidence": ["Nachweis 1", "Nachweis 2", ...],
"severity": "low|medium|high|critical",
"implementation_effort": "s|m|l|xl",
"category": "{pattern.category}",
"tags": ["tag1", "tag2"],
"target_audience": ["unternehmen", "behoerden", "entwickler"],
"verification_method": "code_review|document|tool|hybrid"
}}"""
|
||||
|
||||
|
||||
def _build_fallback_prompt(
    obligation: ObligationMatch,
    chunk_text: Optional[str],
    license_rule: int,
) -> str:
    """Assemble the user prompt for composition without a pattern.

    Same layout as the pattern-guided prompt, minus the pattern section. The
    original chunk text (first 2000 characters) is only embedded when the
    license rule is not 3.
    """
    obl_section = _obligation_section(obligation)

    # Rule 3 never exposes the source text, even if a chunk is available.
    if license_rule != 3 and chunk_text:
        context_section = f"KONTEXT (Originaltext):\n{chunk_text[:2000]}"
    elif license_rule == 3:
        context_section = "KONTEXT: Intern analysiert (keine Quellenangabe)."
    else:
        context_section = "KONTEXT: Kein Originaltext verfuegbar."

    return f"""Erstelle ein Security Control aus der folgenden Pflicht.

{obl_section}

{context_section}

AUFGABE:
Formuliere ein umsetzbares Security Control.
Keine Gesetzesparaphrase — konkrete Massnahmen beschreiben.

Antworte als JSON:
{{
"title": "Kurzer praegnanter Titel (max 100 Zeichen, deutsch)",
"objective": "Was soll erreicht werden? (1-3 Saetze)",
"rationale": "Warum ist das wichtig? (1-2 Saetze)",
"requirements": ["Konkrete Anforderung 1", "Anforderung 2", ...],
"test_procedure": ["Pruefschritt 1", "Pruefschritt 2", ...],
"evidence": ["Nachweis 1", "Nachweis 2", ...],
"severity": "low|medium|high|critical",
"implementation_effort": "s|m|l|xl",
"category": "one of: authentication, encryption, data_protection, etc.",
"tags": ["tag1", "tag2"],
"target_audience": ["unternehmen"],
"verification_method": "code_review|document|tool|hybrid"
}}"""
|
||||
|
||||
|
||||
def _obligation_section(obligation: ObligationMatch) -> str:
|
||||
"""Format the obligation for the prompt."""
|
||||
parts = ["PFLICHT (was das Gesetz verlangt):"]
|
||||
if obligation.obligation_title:
|
||||
parts.append(f" Titel: {obligation.obligation_title}")
|
||||
if obligation.obligation_text:
|
||||
parts.append(f" Beschreibung: {obligation.obligation_text[:500]}")
|
||||
if obligation.obligation_id:
|
||||
parts.append(f" ID: {obligation.obligation_id}")
|
||||
if obligation.regulation_id:
|
||||
parts.append(f" Rechtsgrundlage: {obligation.regulation_id}")
|
||||
if not obligation.obligation_text and not obligation.obligation_title:
|
||||
parts.append(" (Keine spezifische Pflicht extrahiert)")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _pattern_section(pattern: ControlPattern) -> str:
|
||||
"""Format the pattern for the prompt."""
|
||||
reqs = "\n ".join(f"- {r}" for r in pattern.requirements_template[:5])
|
||||
tests = "\n ".join(f"- {t}" for t in pattern.test_procedure_template[:3])
|
||||
return f"""MUSTER (wie man es typischerweise umsetzt):
|
||||
Pattern: {pattern.name_de} ({pattern.id})
|
||||
Domain: {pattern.domain}
|
||||
Ziel-Template: {pattern.objective_template}
|
||||
Anforderungs-Template:
|
||||
{reqs}
|
||||
Pruefverfahren-Template:
|
||||
{tests}"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_list(value) -> list:
|
||||
"""Ensure a value is a list of strings."""
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value if v]
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
return []
|
||||
|
||||
|
||||
def _anchors_from_pattern(pattern: ControlPattern) -> list:
|
||||
"""Convert pattern's open_anchor_refs to control anchor format."""
|
||||
anchors = []
|
||||
for ref in pattern.open_anchor_refs:
|
||||
anchors.append({
|
||||
"framework": ref.get("framework", ""),
|
||||
"control_id": ref.get("ref", ""),
|
||||
"title": "",
|
||||
"alignment_score": 0.8,
|
||||
})
|
||||
return anchors
|
||||
|
||||
|
||||
def _validate_control(control: ComposedControl) -> None:
    """Validate and fix control field values in place.

    * ``severity`` / ``implementation_effort``: must be one of the allowed
      values, otherwise reset to ``"medium"`` / ``"m"``.
    * ``verification_method``: a non-empty value that is not allowed is reset
      to ``None``; empty values are left untouched.
    * ``risk_score``: must be a number in ``[0, 10]``, otherwise derived from
      the (already validated) severity.
    * ``title``: cut to 255 characters (with a trailing ``...``).
    * ``objective``, ``rationale``, ``requirements``, ``test_procedure``,
      ``evidence``: filled with generic defaults when empty.

    Enumerated values are compared case-insensitively and ignoring
    surrounding whitespace, so LLM output such as ``"High"`` is kept as
    ``"high"`` instead of being replaced by the default. Non-string values
    (e.g. a list) are treated as invalid rather than raising ``TypeError``
    in the set-membership test.
    """

    def _choice(value, allowed, default):
        """Return the normalized value if it is allowed, else the default."""
        if isinstance(value, str):
            candidate = value.strip().lower()
            if candidate in allowed:
                return candidate
        return default

    # Severity
    control.severity = _choice(control.severity, VALID_SEVERITIES, "medium")

    # Implementation effort
    control.implementation_effort = _choice(
        control.implementation_effort, VALID_EFFORTS, "m",
    )

    # Verification method (falsy values stay as they are)
    if control.verification_method:
        control.verification_method = _choice(
            control.verification_method, VALID_VERIFICATION, None,
        )

    # Risk score
    score = control.risk_score
    if not isinstance(score, (int, float)) or not (0 <= score <= 10):
        control.risk_score = _severity_to_risk(control.severity)

    # Title length
    if len(control.title) > 255:
        control.title = control.title[:252] + "..."

    # Ensure minimum content
    if not control.objective:
        control.objective = control.title
    if not control.rationale:
        control.rationale = "Aus regulatorischer Anforderung abgeleitet."
    if not control.requirements:
        control.requirements = ["Anforderung gemaess Pflichtbeschreibung umsetzen"]
    if not control.test_procedure:
        control.test_procedure = ["Umsetzung der Anforderungen pruefen"]
    if not control.evidence:
        control.evidence = ["Dokumentation der Umsetzung"]
|
||||
|
||||
|
||||
def _severity_to_risk(severity: str) -> float:
|
||||
"""Map severity to a default risk score."""
|
||||
return {
|
||||
"critical": 9.0,
|
||||
"high": 7.0,
|
||||
"medium": 5.0,
|
||||
"low": 3.0,
|
||||
}.get(severity, 5.0)
|
||||
745
control-pipeline/services/control_dedup.py
Normal file
745
control-pipeline/services/control_dedup.py
Normal file
@@ -0,0 +1,745 @@
|
||||
"""Control Deduplication Engine — 4-Stage Matching Pipeline.
|
||||
|
||||
Prevents duplicate atomic controls during Pass 0b by checking candidates
|
||||
against existing controls before insertion.
|
||||
|
||||
Stages:
|
||||
1. Pattern-Gate: pattern_id must match (hard gate)
|
||||
2. Action-Check: normalized action verb must match (hard gate)
|
||||
3. Object-Norm: normalized object must match (soft gate with high threshold)
|
||||
4. Embedding: cosine similarity with tiered thresholds (Qdrant)
|
||||
|
||||
Verdicts:
|
||||
- NEW: create a new atomic control
|
||||
- LINK: add parent link to existing control (similarity > LINK_THRESHOLD)
|
||||
- REVIEW: queue for human review (REVIEW_THRESHOLD < sim < LINK_THRESHOLD)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Callable, Awaitable
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)

# ── Configuration ────────────────────────────────────────────────────
# Every value can be overridden through the environment variable named in
# the lookup; the second argument is the built-in default.

# Only the exact (case-insensitive) string "true" enables dedup.
DEDUP_ENABLED = os.environ.get("DEDUP_ENABLED", "true").lower() == "true"

# Similarity thresholds
LINK_THRESHOLD = float(os.environ.get("DEDUP_LINK_THRESHOLD", "0.92"))
REVIEW_THRESHOLD = float(os.environ.get("DEDUP_REVIEW_THRESHOLD", "0.85"))
LINK_THRESHOLD_DIFF_OBJECT = float(os.environ.get("DEDUP_LINK_THRESHOLD_DIFF_OBJ", "0.95"))
CROSS_REG_LINK_THRESHOLD = float(os.environ.get("DEDUP_CROSS_REG_THRESHOLD", "0.95"))

# Service endpoints
QDRANT_COLLECTION = os.environ.get("DEDUP_QDRANT_COLLECTION", "atomic_controls")
QDRANT_URL = os.environ.get("QDRANT_URL", "http://host.docker.internal:6333")
EMBEDDING_URL = os.environ.get("EMBEDDING_URL", "http://embedding-service:8087")
|
||||
|
||||
|
||||
# ── Result Dataclass ─────────────────────────────────────────────────
|
||||
|
||||
@dataclass
class DedupResult:
    """Result of checking one candidate control for duplicates.

    Only ``verdict`` is required; the remaining fields describe the matched
    control (if any) and how the decision was reached.
    """

    # "new" | "link" | "review"
    verdict: str
    matched_control_uuid: Optional[str] = None
    matched_control_id: Optional[str] = None
    matched_title: Optional[str] = None
    # Name of the stage that decided the verdict
    stage: str = ""
    similarity_score: float = 0.0
    # "dedup_merge" | "cross_regulation"
    link_type: str = "dedup_merge"
    details: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Action Normalization ─────────────────────────────────────────────
|
||||
|
||||
_ACTION_SYNONYMS: dict[str, str] = {
|
||||
# German → canonical English
|
||||
"implementieren": "implement",
|
||||
"umsetzen": "implement",
|
||||
"einrichten": "implement",
|
||||
"einführen": "implement",
|
||||
"aufbauen": "implement",
|
||||
"bereitstellen": "implement",
|
||||
"aktivieren": "implement",
|
||||
"konfigurieren": "configure",
|
||||
"einstellen": "configure",
|
||||
"parametrieren": "configure",
|
||||
"testen": "test",
|
||||
"prüfen": "test",
|
||||
"überprüfen": "test",
|
||||
"verifizieren": "test",
|
||||
"validieren": "test",
|
||||
"kontrollieren": "test",
|
||||
"auditieren": "audit",
|
||||
"dokumentieren": "document",
|
||||
"protokollieren": "log",
|
||||
"aufzeichnen": "log",
|
||||
"loggen": "log",
|
||||
"überwachen": "monitor",
|
||||
"monitoring": "monitor",
|
||||
"beobachten": "monitor",
|
||||
"schulen": "train",
|
||||
"trainieren": "train",
|
||||
"sensibilisieren": "train",
|
||||
"löschen": "delete",
|
||||
"entfernen": "delete",
|
||||
"verschlüsseln": "encrypt",
|
||||
"sperren": "block",
|
||||
"beschränken": "restrict",
|
||||
"einschränken": "restrict",
|
||||
"begrenzen": "restrict",
|
||||
"autorisieren": "authorize",
|
||||
"genehmigen": "authorize",
|
||||
"freigeben": "authorize",
|
||||
"authentifizieren": "authenticate",
|
||||
"identifizieren": "identify",
|
||||
"melden": "report",
|
||||
"benachrichtigen": "notify",
|
||||
"informieren": "notify",
|
||||
"aktualisieren": "update",
|
||||
"erneuern": "update",
|
||||
"sichern": "backup",
|
||||
"wiederherstellen": "restore",
|
||||
# English passthrough
|
||||
"implement": "implement",
|
||||
"configure": "configure",
|
||||
"test": "test",
|
||||
"verify": "test",
|
||||
"validate": "test",
|
||||
"audit": "audit",
|
||||
"document": "document",
|
||||
"log": "log",
|
||||
"monitor": "monitor",
|
||||
"train": "train",
|
||||
"delete": "delete",
|
||||
"encrypt": "encrypt",
|
||||
"restrict": "restrict",
|
||||
"authorize": "authorize",
|
||||
"authenticate": "authenticate",
|
||||
"report": "report",
|
||||
"update": "update",
|
||||
"backup": "backup",
|
||||
"restore": "restore",
|
||||
}
|
||||
|
||||
|
||||
def normalize_action(action: str) -> str:
|
||||
"""Normalize an action verb to a canonical English form."""
|
||||
if not action:
|
||||
return ""
|
||||
action = action.strip().lower()
|
||||
# Strip German infinitive/conjugation suffixes for lookup
|
||||
action_base = re.sub(r"(en|t|st|e|te|tet|end)$", "", action)
|
||||
# Try exact match first, then base form
|
||||
if action in _ACTION_SYNONYMS:
|
||||
return _ACTION_SYNONYMS[action]
|
||||
if action_base in _ACTION_SYNONYMS:
|
||||
return _ACTION_SYNONYMS[action_base]
|
||||
# Fuzzy: check if action starts with any known verb
|
||||
for verb, canonical in _ACTION_SYNONYMS.items():
|
||||
if action.startswith(verb) or verb.startswith(action):
|
||||
return canonical
|
||||
return action # fallback: return as-is
|
||||
|
||||
|
||||
# ── Object Normalization ─────────────────────────────────────────────
|
||||
|
||||
_OBJECT_SYNONYMS: dict[str, str] = {
|
||||
# Authentication / Access
|
||||
"mfa": "multi_factor_auth",
|
||||
"multi-faktor-authentifizierung": "multi_factor_auth",
|
||||
"mehrfaktorauthentifizierung": "multi_factor_auth",
|
||||
"multi-factor authentication": "multi_factor_auth",
|
||||
"two-factor": "multi_factor_auth",
|
||||
"2fa": "multi_factor_auth",
|
||||
"passwort": "password_policy",
|
||||
"kennwort": "password_policy",
|
||||
"password": "password_policy",
|
||||
"zugangsdaten": "credentials",
|
||||
"credentials": "credentials",
|
||||
"admin-konten": "privileged_access",
|
||||
"admin accounts": "privileged_access",
|
||||
"administratorkonten": "privileged_access",
|
||||
"privilegierte zugriffe": "privileged_access",
|
||||
"privileged accounts": "privileged_access",
|
||||
"remote-zugriff": "remote_access",
|
||||
"fernzugriff": "remote_access",
|
||||
"remote access": "remote_access",
|
||||
"session": "session_management",
|
||||
"sitzung": "session_management",
|
||||
"sitzungsverwaltung": "session_management",
|
||||
# Encryption
|
||||
"verschlüsselung": "encryption",
|
||||
"encryption": "encryption",
|
||||
"kryptografie": "encryption",
|
||||
"kryptografische verfahren": "encryption",
|
||||
"schlüssel": "key_management",
|
||||
"key management": "key_management",
|
||||
"schlüsselverwaltung": "key_management",
|
||||
"zertifikat": "certificate_management",
|
||||
"certificate": "certificate_management",
|
||||
"tls": "transport_encryption",
|
||||
"ssl": "transport_encryption",
|
||||
"https": "transport_encryption",
|
||||
# Network
|
||||
"firewall": "firewall",
|
||||
"netzwerk": "network_security",
|
||||
"network": "network_security",
|
||||
"vpn": "vpn",
|
||||
"segmentierung": "network_segmentation",
|
||||
"segmentation": "network_segmentation",
|
||||
# Logging / Monitoring
|
||||
"audit-log": "audit_logging",
|
||||
"audit log": "audit_logging",
|
||||
"protokoll": "audit_logging",
|
||||
"logging": "audit_logging",
|
||||
"monitoring": "monitoring",
|
||||
"überwachung": "monitoring",
|
||||
"alerting": "alerting",
|
||||
"alarmierung": "alerting",
|
||||
"siem": "siem",
|
||||
# Data
|
||||
"personenbezogene daten": "personal_data",
|
||||
"personal data": "personal_data",
|
||||
"sensible daten": "sensitive_data",
|
||||
"sensitive data": "sensitive_data",
|
||||
"datensicherung": "backup",
|
||||
"backup": "backup",
|
||||
"wiederherstellung": "disaster_recovery",
|
||||
"disaster recovery": "disaster_recovery",
|
||||
# Policy / Process
|
||||
"richtlinie": "policy",
|
||||
"policy": "policy",
|
||||
"verfahrensanweisung": "procedure",
|
||||
"procedure": "procedure",
|
||||
"prozess": "process",
|
||||
"schulung": "training",
|
||||
"training": "training",
|
||||
"awareness": "awareness",
|
||||
"sensibilisierung": "awareness",
|
||||
# Incident
|
||||
"vorfall": "incident",
|
||||
"incident": "incident",
|
||||
"sicherheitsvorfall": "security_incident",
|
||||
"security incident": "security_incident",
|
||||
# Vulnerability
|
||||
"schwachstelle": "vulnerability",
|
||||
"vulnerability": "vulnerability",
|
||||
"patch": "patch_management",
|
||||
"update": "patch_management",
|
||||
"patching": "patch_management",
|
||||
}
|
||||
|
||||
# Precompile for substring matching (longest first)
|
||||
_OBJECT_KEYS_SORTED = sorted(_OBJECT_SYNONYMS.keys(), key=len, reverse=True)
|
||||
|
||||
|
||||
def normalize_object(obj: str) -> str:
|
||||
"""Normalize a compliance object to a canonical token."""
|
||||
if not obj:
|
||||
return ""
|
||||
obj_lower = obj.strip().lower()
|
||||
# Exact match
|
||||
if obj_lower in _OBJECT_SYNONYMS:
|
||||
return _OBJECT_SYNONYMS[obj_lower]
|
||||
# Substring match (longest first)
|
||||
for phrase in _OBJECT_KEYS_SORTED:
|
||||
if phrase in obj_lower:
|
||||
return _OBJECT_SYNONYMS[phrase]
|
||||
# Fallback: strip articles/prepositions, join with underscore
|
||||
cleaned = re.sub(r"\b(der|die|das|den|dem|des|ein|eine|eines|einem|einen"
|
||||
r"|für|von|zu|auf|in|an|bei|mit|nach|über|unter|the|a|an"
|
||||
r"|for|of|to|on|in|at|by|with)\b", "", obj_lower)
|
||||
tokens = [t for t in cleaned.split() if len(t) > 2]
|
||||
return "_".join(tokens[:4]) if tokens else obj_lower.replace(" ", "_")
|
||||
|
||||
|
||||
# ── Canonicalization ─────────────────────────────────────────────────
|
||||
|
||||
def canonicalize_text(action: str, obj: str, title: str = "") -> str:
    """Build the canonical text that is sent to the embedding service.

    The text starts with the normalized action and the normalized object.
    When a title is given, up to five of its words (lower-cased, common German
    filler words removed, words of three characters or fewer dropped) are
    appended after the word ``for``.
    """
    pieces = [normalize_action(action), normalize_object(obj)]
    if not title:
        return " ".join(pieces)

    # Remove common filler words from the title before picking keywords
    stripped_title = re.sub(
        r"\b(und|oder|für|von|zu|der|die|das|den|dem|des|ein|eine"
        r"|bei|mit|nach|gemäß|gem\.|laut|entsprechend)\b",
        "",
        title.lower(),
    )
    keywords = [word for word in stripped_title.split() if len(word) > 3][:5]
    if keywords:
        pieces += ["for", *keywords]
    return " ".join(pieces)
|
||||
|
||||
|
||||
# ── Embedding Helper ─────────────────────────────────────────────────
|
||||
|
||||
async def get_embedding(text: str) -> list[float]:
    """Get the embedding vector for a single text via the embedding service.

    Best-effort: any failure is logged as a warning and reported as an empty
    list, never raised.

    Args:
        text: Text to embed; sent as the only element of ``texts``.

    Returns:
        The first vector of the ``embeddings`` array in the response, or
        ``[]`` when the request fails, the service answers with an HTTP error
        status, or the response contains no embedding.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
        # Log HTTP errors explicitly (as the Qdrant helpers do) instead of
        # silently parsing an error body as "no embeddings".
        if resp.status_code >= 400:
            logger.warning("Embedding request failed: %d", resp.status_code)
            return []
        embeddings = resp.json().get("embeddings", [])
        first = embeddings[0] if embeddings else None
        # A null entry would otherwise leak None to callers expecting a list.
        return first or []
    except Exception as e:
        logger.warning("Embedding failed: %s", e)
        return []
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Yields ``0.0`` when either vector is empty, the lengths differ, or one
    of the vectors has zero magnitude.
    """
    if not a or not b:
        return 0.0
    if len(a) != len(b):
        return 0.0

    dot_product = sum(left * right for left, right in zip(a, b))
    magnitude_a = sum(value * value for value in a) ** 0.5
    magnitude_b = sum(value * value for value in b) ** 0.5
    if magnitude_a == 0 or magnitude_b == 0:
        return 0.0
    return dot_product / (magnitude_a * magnitude_b)
|
||||
|
||||
|
||||
# ── Qdrant Helpers ───────────────────────────────────────────────────
|
||||
|
||||
async def qdrant_search(
    embedding: list[float],
    pattern_id: str,
    top_k: int = 10,
    collection: Optional[str] = None,
) -> list[dict]:
    """Run a vector search in Qdrant restricted to one ``pattern_id``.

    Args:
        embedding: Query vector; an empty vector short-circuits to ``[]``.
        pattern_id: Payload value the hits must match on key ``pattern_id``.
        top_k: Maximum number of hits requested.
        collection: Collection name; falls back to ``QDRANT_COLLECTION``.

    Returns:
        The ``"result"`` list of the Qdrant response, or ``[]`` on a non-200
        status or any exception (both are logged as warnings).
    """
    if not embedding:
        return []

    target = collection or QDRANT_COLLECTION
    pattern_filter = {
        "must": [
            {"key": "pattern_id", "match": {"value": pattern_id}}
        ]
    }
    request_body: dict = {
        "vector": embedding,
        "limit": top_k,
        "with_payload": True,
        "filter": pattern_filter,
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as http:
            response = await http.post(
                f"{QDRANT_URL}/collections/{target}/points/search",
                json=request_body,
            )
            if response.status_code == 200:
                return response.json().get("result", [])
            logger.warning("Qdrant search failed: %d", response.status_code)
            return []
    except Exception as exc:
        logger.warning("Qdrant search error: %s", exc)
        return []
|
||||
|
||||
|
||||
async def qdrant_search_cross_regulation(
    embedding: list[float],
    top_k: int = 5,
    collection: Optional[str] = None,
) -> list[dict]:
    """Run an unfiltered vector search in Qdrant (no ``pattern_id`` filter).

    Used for cross-regulation linking (e.g. DSGVO Art. 25 ↔ NIS2 Art. 21).

    Args:
        embedding: Query vector; an empty vector short-circuits to ``[]``.
        top_k: Maximum number of hits requested.
        collection: Collection name; falls back to ``QDRANT_COLLECTION``.

    Returns:
        The ``"result"`` list of the Qdrant response, or ``[]`` on a non-200
        status or any exception (both are logged as warnings).
    """
    if not embedding:
        return []

    target = collection or QDRANT_COLLECTION
    request_body: dict = {
        "vector": embedding,
        "limit": top_k,
        "with_payload": True,
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as http:
            response = await http.post(
                f"{QDRANT_URL}/collections/{target}/points/search",
                json=request_body,
            )
            if response.status_code == 200:
                return response.json().get("result", [])
            logger.warning("Qdrant cross-reg search failed: %d", response.status_code)
            return []
    except Exception as exc:
        logger.warning("Qdrant cross-reg search error: %s", exc)
        return []
|
||||
|
||||
|
||||
async def qdrant_upsert(
    point_id: str,
    embedding: list[float],
    payload: dict,
    collection: Optional[str] = None,
) -> bool:
    """Write one point (id, vector, payload) into a Qdrant collection.

    Args:
        point_id: Identifier of the point.
        embedding: Vector to store; an empty vector short-circuits to ``False``.
        payload: Payload stored alongside the vector.
        collection: Collection name; falls back to ``QDRANT_COLLECTION``.

    Returns:
        ``True`` only when Qdrant answers with status 200; ``False`` on any
        other status or on an exception (logged as a warning).
    """
    if not embedding:
        return False

    target = collection or QDRANT_COLLECTION
    point = {
        "id": point_id,
        "vector": embedding,
        "payload": payload,
    }
    try:
        async with httpx.AsyncClient(timeout=10.0) as http:
            response = await http.put(
                f"{QDRANT_URL}/collections/{target}/points",
                json={"points": [point]},
            )
            return response.status_code == 200
    except Exception as exc:
        logger.warning("Qdrant upsert error: %s", exc)
        return False
|
||||
|
||||
|
||||
async def ensure_qdrant_collection(
    vector_size: int = 1024,
    collection: Optional[str] = None,
) -> bool:
    """Make sure the Qdrant collection exists, creating it when missing.

    A GET on the collection that answers 200 counts as "exists". Otherwise
    the collection is created with cosine distance and keyword payload
    indexes are requested for ``pattern_id``, ``action_normalized``,
    ``object_normalized`` and ``control_id``. The responses of the index
    requests are not inspected.

    Args:
        vector_size: Vector dimension used when the collection is created.
        collection: Collection name; falls back to ``QDRANT_COLLECTION``.

    Returns:
        ``True`` if the collection exists or was created, ``False`` if the
        create call did not answer 200 or an exception occurred.
    """
    target = collection or QDRANT_COLLECTION
    try:
        async with httpx.AsyncClient(timeout=10.0) as http:
            existing = await http.get(f"{QDRANT_URL}/collections/{target}")
            if existing.status_code == 200:
                return True

            created = await http.put(
                f"{QDRANT_URL}/collections/{target}",
                json={
                    "vectors": {"size": vector_size, "distance": "Cosine"},
                },
            )
            if created.status_code != 200:
                logger.error("Failed to create Qdrant collection: %d", created.status_code)
                return False

            logger.info("Created Qdrant collection: %s", target)
            # Keyword indexes for the payload fields used in filters/lookups.
            for field_name in ["pattern_id", "action_normalized", "object_normalized", "control_id"]:
                await http.put(
                    f"{QDRANT_URL}/collections/{target}/index",
                    json={"field_name": field_name, "field_schema": "keyword"},
                )
            return True
    except Exception as exc:
        logger.warning("Qdrant collection check error: %s", exc)
        return False
|
||||
|
||||
|
||||
# ── Main Dedup Checker ───────────────────────────────────────────────
|
||||
|
||||
class ControlDedupChecker:
    """4-stage dedup checker for atomic controls.

    Usage:
        checker = ControlDedupChecker(db_session)
        result = await checker.check_duplicate(action, obj, title, pattern_id)
        if result.verdict == "link":
            checker.add_parent_link(result.matched_control_uuid, parent_uuid)
        elif result.verdict == "review":
            checker.write_review(control_id, title, objective, result)
        else:
            # Insert new control
            ...
    """

    def __init__(
        self,
        db,
        embed_fn: Optional[Callable[[str], Awaitable[list[float]]]] = None,
        search_fn: Optional[Callable] = None,
        cross_search_fn: Optional[Callable] = None,
    ):
        """Create a checker bound to one DB session.

        Args:
            db: Session object exposing ``execute()`` and ``commit()``.
            embed_fn: Async callable ``text -> vector``; defaults to
                ``get_embedding``.
            search_fn: Async callable ``(embedding, pattern_id, top_k=...)``;
                defaults to ``qdrant_search``.
            cross_search_fn: Async callable ``(embedding, top_k=...)`` for the
                unfiltered cross-regulation search; defaults to
                ``qdrant_search_cross_regulation``.
        """
        self.db = db
        self._embed = embed_fn or get_embedding
        self._search = search_fn or qdrant_search
        # Kept as None when not injected so the module-level default is looked
        # up at call time, exactly as before this parameter existed.
        self._cross_search = cross_search_fn
        self._cache: dict[str, list[dict]] = {}  # pattern_id → existing controls

    @staticmethod
    def _matched_fields(payload: dict) -> dict:
        """Map a Qdrant hit payload to the ``matched_*`` kwargs of DedupResult."""
        return {
            "matched_control_uuid": payload.get("control_uuid"),
            "matched_control_id": payload.get("control_id"),
            "matched_title": payload.get("title"),
        }

    def _load_existing(self, pattern_id: str) -> list[dict]:
        """Load existing atomic controls with the same pattern_id from the DB.

        Results are cached per ``pattern_id`` for the lifetime of this
        checker instance.
        """
        if pattern_id in self._cache:
            return self._cache[pattern_id]

        from sqlalchemy import text

        rows = self.db.execute(text("""
            SELECT id::text, control_id, title, objective,
                   pattern_id,
                   generation_metadata->>'obligation_type' as obligation_type
            FROM canonical_controls
            WHERE parent_control_uuid IS NOT NULL
              AND release_state != 'deprecated'
              AND pattern_id = :pid
        """), {"pid": pattern_id}).fetchall()
        result = [
            {
                "uuid": r[0], "control_id": r[1], "title": r[2],
                "objective": r[3], "pattern_id": r[4],
                "obligation_type": r[5],
            }
            for r in rows
        ]
        self._cache[pattern_id] = result
        return result

    async def check_duplicate(
        self,
        action: str,
        obj: str,
        title: str,
        pattern_id: Optional[str],
    ) -> DedupResult:
        """Run the 4-stage dedup pipeline + cross-regulation linking.

        Stages: pattern gate (DB lookup), action comparison, object
        comparison, embedding similarity against the best Qdrant hit. Every
        intra-pattern "new" verdict reached after the embedding was computed
        is passed through ``_check_cross_regulation``.

        Returns:
            DedupResult with verdict ``new``, ``link`` or ``review``.
        """
        # No pattern_id → can't dedup meaningfully
        if not pattern_id:
            return DedupResult(verdict="new", stage="no_pattern")

        # Stage 1: Pattern-Gate
        existing = self._load_existing(pattern_id)
        if not existing:
            return DedupResult(
                verdict="new", stage="pattern_gate",
                details={"reason": "no existing controls with this pattern_id"},
            )

        # Stage 2 + 3: normalized action/object of the candidate. The DB rows
        # do not carry these, so they are compared against the payload of the
        # best Qdrant hit further down.
        norm_action = normalize_action(action)
        norm_object = normalize_object(obj)

        # Stage 4: Embedding Similarity
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            # Can't compute embedding → default to new
            return DedupResult(
                verdict="new", stage="embedding_unavailable",
                details={"canonical_text": canonical},
            )

        results = await self._search(embedding, pattern_id, top_k=5)
        if not results:
            # No intra-pattern matches → try cross-regulation
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="no_qdrant_matches",
                details={"canonical_text": canonical, "action": norm_action, "object": norm_object},
            ))

        # Evaluate best match
        best = results[0]
        best_score = best.get("score", 0.0)
        # ``or {}`` also covers a hit whose payload is present but null.
        best_payload = best.get("payload") or {}
        best_action = best_payload.get("action_normalized", "")
        best_object = best_payload.get("object_normalized", "")

        # Action differs → NEW (even if embedding is high)
        if best_action and norm_action and best_action != norm_action:
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="action_mismatch",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_action": norm_action,
                    "existing_action": best_action,
                    "similarity": best_score,
                },
            ))

        # Object differs → use higher threshold
        if best_object and norm_object and best_object != norm_object:
            if best_score > LINK_THRESHOLD_DIFF_OBJECT:
                return DedupResult(
                    verdict="link", stage="embedding_diff_object",
                    similarity_score=best_score,
                    details={"candidate_object": norm_object, "existing_object": best_object},
                    **self._matched_fields(best_payload),
                )
            return await self._check_cross_regulation(embedding, DedupResult(
                verdict="new", stage="object_mismatch_below_threshold",
                similarity_score=best_score,
                matched_control_id=best_payload.get("control_id"),
                details={
                    "candidate_object": norm_object,
                    "existing_object": best_object,
                    "threshold": LINK_THRESHOLD_DIFF_OBJECT,
                },
            ))

        # Same action + same object → tiered thresholds
        if best_score > LINK_THRESHOLD:
            return DedupResult(
                verdict="link", stage="embedding_match",
                similarity_score=best_score,
                **self._matched_fields(best_payload),
            )
        if best_score > REVIEW_THRESHOLD:
            return DedupResult(
                verdict="review", stage="embedding_review",
                similarity_score=best_score,
                **self._matched_fields(best_payload),
            )
        return await self._check_cross_regulation(embedding, DedupResult(
            verdict="new", stage="embedding_below_threshold",
            similarity_score=best_score,
            details={"threshold": REVIEW_THRESHOLD},
        ))

    async def _check_cross_regulation(
        self,
        embedding: list[float],
        intra_result: DedupResult,
    ) -> DedupResult:
        """Second pass: cross-regulation linking for controls deemed 'new'.

        Searches Qdrant WITHOUT a pattern_id filter and links only when the
        best hit scores above ``CROSS_REG_LINK_THRESHOLD``. Any other outcome
        returns ``intra_result`` unchanged.

        NOTE(review): because the search is unfiltered, the best hit may be
        a control of the same pattern that the intra-pattern stages just
        rejected (e.g. on action mismatch) — confirm this is intended.
        """
        if intra_result.verdict != "new" or not embedding:
            return intra_result

        search = self._cross_search or qdrant_search_cross_regulation
        cross_results = await search(embedding, top_k=5)
        if not cross_results:
            return intra_result

        best = cross_results[0]
        best_score = best.get("score", 0.0)
        if best_score > CROSS_REG_LINK_THRESHOLD:
            best_payload = best.get("payload") or {}
            return DedupResult(
                verdict="link",
                stage="cross_regulation",
                similarity_score=best_score,
                link_type="cross_regulation",
                details={
                    "cross_reg_score": best_score,
                    "cross_reg_threshold": CROSS_REG_LINK_THRESHOLD,
                },
                **self._matched_fields(best_payload),
            )

        return intra_result

    def add_parent_link(
        self,
        control_uuid: str,
        parent_control_uuid: str,
        link_type: str = "dedup_merge",
        confidence: float = 0.0,
        source_regulation: Optional[str] = None,
        source_article: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Add a parent link to an existing atomic control.

        Inserts into ``control_parent_links``; an existing
        ``(control_uuid, parent_control_uuid)`` pair is left untouched
        (``ON CONFLICT ... DO NOTHING``). Commits the session.
        """
        from sqlalchemy import text

        self.db.execute(text("""
            INSERT INTO control_parent_links
                (control_uuid, parent_control_uuid, link_type, confidence,
                 source_regulation, source_article, obligation_candidate_id)
            VALUES (:cu, :pu, :lt, :conf, :sr, :sa, :oci::uuid)
            ON CONFLICT (control_uuid, parent_control_uuid) DO NOTHING
        """), {
            "cu": control_uuid,
            "pu": parent_control_uuid,
            "lt": link_type,
            "conf": confidence,
            "sr": source_regulation,
            "sa": source_article,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    def write_review(
        self,
        candidate_control_id: str,
        candidate_title: str,
        candidate_objective: str,
        result: DedupResult,
        parent_control_uuid: Optional[str] = None,
        obligation_candidate_id: Optional[str] = None,
    ) -> None:
        """Write a dedup review queue entry.

        Inserts the candidate together with the match stored on ``result``
        into ``control_dedup_reviews`` (``result.details`` is serialized as
        JSON) and commits the session.
        """
        import json

        from sqlalchemy import text

        self.db.execute(text("""
            INSERT INTO control_dedup_reviews
                (candidate_control_id, candidate_title, candidate_objective,
                 matched_control_uuid, matched_control_id,
                 similarity_score, dedup_stage, dedup_details,
                 parent_control_uuid, obligation_candidate_id)
            VALUES (:ccid, :ct, :co, :mcu::uuid, :mci, :ss, :ds,
                    :dd::jsonb, :pcu::uuid, :oci)
        """), {
            "ccid": candidate_control_id,
            "ct": candidate_title,
            "co": candidate_objective,
            "mcu": result.matched_control_uuid,
            "mci": result.matched_control_id,
            "ss": result.similarity_score,
            "ds": result.stage,
            "dd": json.dumps(result.details),
            "pcu": parent_control_uuid,
            "oci": obligation_candidate_id,
        })
        self.db.commit()

    async def index_control(
        self,
        control_uuid: str,
        control_id: str,
        title: str,
        action: str,
        obj: str,
        pattern_id: str,
        collection: Optional[str] = None,
    ) -> bool:
        """Index a new atomic control in Qdrant for future dedup checks.

        Embeds the same canonical text that ``check_duplicate`` uses and
        upserts it with ``control_uuid`` as the point id.

        Returns:
            ``False`` when no embedding could be obtained, otherwise the
            result of ``qdrant_upsert``.
        """
        norm_action = normalize_action(action)
        norm_object = normalize_object(obj)
        canonical = canonicalize_text(action, obj, title)
        embedding = await self._embed(canonical)
        if not embedding:
            return False
        return await qdrant_upsert(
            point_id=control_uuid,
            embedding=embedding,
            payload={
                "control_uuid": control_uuid,
                "control_id": control_id,
                "title": title,
                "pattern_id": pattern_id,
                "action_normalized": norm_action,
                "object_normalized": norm_object,
                "canonical_text": canonical,
            },
            collection=collection,
        )
|
||||
2249
control-pipeline/services/control_generator.py
Normal file
2249
control-pipeline/services/control_generator.py
Normal file
File diff suppressed because it is too large
Load Diff
154
control-pipeline/services/control_status_machine.py
Normal file
154
control-pipeline/services/control_status_machine.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
Control Status Transition State Machine.
|
||||
|
||||
Enforces that controls cannot be set to "pass" without sufficient evidence.
|
||||
Prevents Compliance-Theater where controls claim compliance without real proof.
|
||||
|
||||
Transition rules:
|
||||
planned → in_progress : always allowed
|
||||
  in_progress → pass     : requires ≥1 evidence with confidence ≥ E2 and
                           truth_status in (uploaded, observed, validated_internal,
                           accepted_by_auditor, provided_to_auditor)
|
||||
in_progress → partial : requires ≥1 evidence (any level)
|
||||
pass → fail : always allowed (degradation)
|
||||
any → n/a : requires status_justification
|
||||
any → planned : always allowed (reset)
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
# EvidenceDB is an ORM model from compliance — we only need duck-typed objects
# with .confidence_level and .truth_status attributes.
EvidenceDB = Any

# Confidence level ordering for comparisons
CONFIDENCE_ORDER = {"E0": 0, "E1": 1, "E2": 2, "E3": 3, "E4": 4}

# Truth statuses that qualify as "real" evidence for pass transitions
VALID_TRUTH_STATUSES = {"uploaded", "observed", "validated_internal", "accepted_by_auditor", "provided_to_auditor"}


def _as_status_value(raw: Any, default: str) -> Any:
    """Return ``raw.value`` for enum-like objects, ``str(raw)`` otherwise.

    A falsy ``raw`` without a ``value`` attribute (e.g. ``None``) yields
    ``default``.
    """
    if hasattr(raw, "value"):
        return raw.value
    return str(raw) if raw else default


def _has_qualifying_evidence(evidence_list: List[EvidenceDB]) -> bool:
    """Tell whether any evidence has confidence >= E2 and a valid truth status.

    Missing attributes default to confidence ``"E1"`` and truth status
    ``"uploaded"``; unknown confidence labels rank like ``E1``.
    """
    for evidence in evidence_list:
        conf_val = _as_status_value(getattr(evidence, "confidence_level", None), "E1")
        truth_val = _as_status_value(getattr(evidence, "truth_status", None), "uploaded")
        if CONFIDENCE_ORDER.get(conf_val, 1) >= CONFIDENCE_ORDER["E2"] and truth_val in VALID_TRUTH_STATUSES:
            return True
    return False


def validate_transition(
    current_status: str,
    new_status: str,
    evidence_list: Optional[List[EvidenceDB]] = None,
    status_justification: Optional[str] = None,
    bypass_for_auto_updater: bool = False,
) -> Tuple[bool, List[str]]:
    """
    Validate whether a control status transition is allowed.

    Rules as implemented (checked in this order):
      * same status, or any → planned        : allowed
      * any → n/a                            : needs a non-blank justification
      * any → partial                        : needs ≥1 evidence (or bypass)
      * any → pass                           : needs ≥1 evidence with
        confidence ≥ E2 and a truth status in VALID_TRUTH_STATUSES (or bypass)
      * everything else (e.g. → fail, → in_progress) : allowed

    NOTE(review): earlier revisions carried two extra branches after the
    generic "→ pass" check (rejecting direct planned/fail → pass, and a
    partial → pass variant with its own message). The generic check returned
    on every path, so those branches were never reached and have been
    removed. Direct planned/fail → pass is therefore NOT blocked when
    qualifying evidence exists — confirm whether that rule should be
    enforced.

    Args:
        current_status: Current control status value (e.g. "planned", "pass")
        new_status: Requested new status
        evidence_list: List of EvidenceDB objects linked to this control
        status_justification: Text justification (required for n/a transitions)
        bypass_for_auto_updater: If True, skip evidence checks (used by CI/CD auto-updater
            which creates evidence atomically with status change)

    Returns:
        Tuple of (allowed: bool, violations: list[str])
    """
    evidence_list = evidence_list or []

    # Same status is a no-op; a reset to planned is always allowed.
    if current_status == new_status or new_status == "planned":
        return True, []

    # n/a requires justification
    if new_status == "n/a":
        if status_justification and status_justification.strip():
            return True, []
        return False, [
            "Transition to 'n/a' requires a status_justification explaining why this control is not applicable."
        ]

    # partial: needs at least 1 evidence of any level
    if new_status == "partial":
        if bypass_for_auto_updater or evidence_list:
            return True, []
        return False, ["Transition to 'partial' requires at least 1 evidence record."]

    # pass: strict requirements
    if new_status == "pass":
        if bypass_for_auto_updater:
            return True, []
        if not evidence_list:
            return False, ["Transition to 'pass' requires at least 1 evidence record."]
        if _has_qualifying_evidence(evidence_list):
            return True, []
        return False, [
            "Transition to 'pass' requires at least 1 evidence with confidence >= E2 "
            "and truth_status in (uploaded, observed, validated_internal, accepted_by_auditor). "
            "Current evidence does not meet this threshold."
        ]

    # All other transitions (→ fail, → in_progress, ...) are allowed
    return True, []
|
||||
3877
control-pipeline/services/decomposition_pass.py
Normal file
3877
control-pipeline/services/decomposition_pass.py
Normal file
File diff suppressed because it is too large
Load Diff
714
control-pipeline/services/framework_decomposition.py
Normal file
714
control-pipeline/services/framework_decomposition.py
Normal file
@@ -0,0 +1,714 @@
|
||||
"""Framework Decomposition Engine — decomposes framework-container obligations.
|
||||
|
||||
Sits between Pass 0a (obligation extraction) and Pass 0b (atomic control
|
||||
composition). Detects obligations that reference a framework domain (e.g.
|
||||
"CCM-Praktiken fuer AIS") and decomposes them into concrete sub-obligations
|
||||
using an internal framework registry.
|
||||
|
||||
Three routing types:
|
||||
atomic → pass through to Pass 0b unchanged
|
||||
compound → split compound verbs, then Pass 0b
|
||||
framework_container → decompose via registry, then Pass 0b
|
||||
|
||||
The registry is a set of JSON files in the data/frameworks/ directory located
two levels above this module (see _REGISTRY_DIR).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Directory scanned for framework definitions: <two levels above this file>/data/frameworks
_REGISTRY_DIR = Path(__file__).resolve().parent.parent / "data" / "frameworks"
# Process-wide cache filled by get_registry() / reload_registry().
_REGISTRY: dict[str, dict] = {}  # framework_id → framework dict
|
||||
|
||||
|
||||
def _load_registry() -> dict[str, dict]:
    """Read every ``*.json`` file in the registry directory into a dict.

    Files are processed in sorted order. The key is the file's
    ``framework_id`` field, falling back to the file name without suffix.
    A file that cannot be read or parsed is logged (with traceback) and
    skipped; a missing directory yields an empty registry and a warning.
    """
    loaded: dict[str, dict] = {}
    if not _REGISTRY_DIR.is_dir():
        logger.warning("Framework registry dir not found: %s", _REGISTRY_DIR)
        return loaded

    for json_path in sorted(_REGISTRY_DIR.glob("*.json")):
        try:
            with json_path.open(encoding="utf-8") as handle:
                framework = json.load(handle)
            framework_id = framework.get("framework_id", json_path.stem)
            loaded[framework_id] = framework
            logger.info(
                "Loaded framework: %s (%d domains)",
                framework_id,
                len(framework.get("domains", [])),
            )
        except Exception:
            # One broken file must not prevent the others from loading.
            logger.exception("Failed to load framework file: %s", json_path)
    return loaded
|
||||
|
||||
|
||||
def get_registry() -> dict[str, dict]:
    """Return the process-wide framework registry, loading it on first use.

    "Loaded" is detected by the cache being non-empty, so while the registry
    is empty (e.g. no framework files found) every call hits the disk again.
    """
    global _REGISTRY
    if _REGISTRY:
        return _REGISTRY
    _REGISTRY = _load_registry()
    return _REGISTRY
|
||||
|
||||
|
||||
def reload_registry() -> dict[str, dict]:
    """Discard the cached registry, re-read it from disk and return it."""
    global _REGISTRY
    refreshed = _load_registry()
    _REGISTRY = refreshed
    return refreshed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Framework alias index (built from registry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_alias_index(registry: dict[str, dict]) -> dict[str, str]:
|
||||
"""Build a lowercase alias → framework_id lookup."""
|
||||
idx: dict[str, str] = {}
|
||||
for fw_id, fw in registry.items():
|
||||
# Framework-level aliases
|
||||
idx[fw_id.lower()] = fw_id
|
||||
name = fw.get("display_name", "")
|
||||
if name:
|
||||
idx[name.lower()] = fw_id
|
||||
# Common short forms
|
||||
for part in fw_id.lower().replace("_", " ").split():
|
||||
if len(part) >= 3:
|
||||
idx[part] = fw_id
|
||||
return idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routing — classify obligation type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Extended patterns for framework detection (beyond the simple _COMPOSITE_RE
# in decomposition_pass.py — here we also capture the framework name).
#
# Shape: <collection noun> <preposition> <reference text> <terminator>
#   - collection noun: praktiken / kontrollen / massnahmen / ... / requirements
#   - preposition: fuer / aus / gemaess / nach / from / of / for / per
#   - group(1): the reference text, matched lazily up to the first modal or
#     implementation verb, a period, a comma, or the end of the string
# Umlauts are accepted both as "ü"/"ä"/"ß" and in their ASCII transliteration.
_FRAMEWORK_PATTERN = re.compile(
    r"(?:praktiken|kontrollen|ma(?:ss|ß)nahmen|anforderungen|vorgaben|controls|practices|measures|requirements)"
    r"\s+(?:f(?:ue|ü)r|aus|gem(?:ae|ä)(?:ss|ß)|nach|from|of|for|per)\s+"
    r"(.+?)(?:\s+(?:m(?:ue|ü)ssen|sollen|sind|werden|implementieren|umsetzen|einf(?:ue|ü)hren)|\.|,|$)",
    re.IGNORECASE,
)

# Direct framework name references (case-insensitive, whole-word match on a
# fixed list of well-known framework/standard names).
_DIRECT_FRAMEWORK_RE = re.compile(
    r"\b(?:CSA\s*CCM|NIST\s*(?:SP\s*)?800-53|OWASP\s*(?:ASVS|SAMM|Top\s*10)"
    r"|CIS\s*Controls|BSI\s*(?:IT-)?Grundschutz|ENISA|ISO\s*2700[12]"
    r"|COBIT|SOX|PCI\s*DSS|HITRUST|SOC\s*2|KRITIS)\b",
    re.IGNORECASE,
)

# Compound verb patterns (multiple main verbs): matches German and English
# coordinating conjunctions as whole words.
_COMPOUND_VERB_RE = re.compile(
    r"\b(?:und|sowie|als\s+auch|or|and)\b",
    re.IGNORECASE,
)

# No-split phrases that look compound but aren't (fixed verb pairs that
# describe a single activity). All entries are lower-case.
# NOTE(review): the consumer of this list is outside this excerpt —
# presumably the compound-obligation check; confirm.
_NO_SPLIT_PHRASES = [
    "pflegen und aufrechterhalten",
    "dokumentieren und pflegen",
    "definieren und dokumentieren",
    "erstellen und freigeben",
    "pruefen und genehmigen",
    "identifizieren und bewerten",
    "erkennen und melden",
    "define and maintain",
    "create and maintain",
    "establish and maintain",
    "monitor and review",
    "detect and respond",
]
|
||||
|
||||
|
||||
@dataclass
class RoutingResult:
    """Result of obligation routing classification."""

    # One of: atomic | compound | framework_container | unknown_review
    routing_type: str
    # Registry id of the detected framework; None when no framework was
    # resolved against the registry.
    framework_ref: Optional[str] = None
    # domain_id of the matched domain inside that framework, if any.
    framework_domain: Optional[str] = None
    # Title of the matched domain, if any.
    domain_title: Optional[str] = None
    # Heuristic confidence assigned by the classifier that produced the result.
    confidence: float = 0.0
    # Short tag naming the rule that produced the result
    # (e.g. "multiple_main_verbs", "pattern_match:<ref>").
    reason: str = ""
|
||||
|
||||
|
||||
def classify_routing(
    obligation_text: str,
    action_raw: str,
    object_raw: str,
    condition_raw: Optional[str] = None,
) -> RoutingResult:
    """Classify an obligation into atomic / compound / framework_container.

    Checks are applied in priority order: framework container first, then
    compound obligation, otherwise atomic.

    Args:
        obligation_text: Full text of the obligation.
        action_raw: Raw action phrase extracted from the obligation.
        object_raw: Raw object phrase extracted from the obligation.
        condition_raw: Accepted for interface compatibility; not used by the
            classification.

    Returns:
        The RoutingResult of the first rule that applies.
    """
    # --- Step 1: Framework container detection ---
    fw_result = _detect_framework(obligation_text, object_raw)
    if fw_result.routing_type == "framework_container":
        return fw_result

    # --- Step 2: Compound verb detection ---
    if _is_compound_obligation(action_raw, obligation_text):
        return RoutingResult(
            routing_type="compound",
            confidence=0.7,
            reason="multiple_main_verbs",
        )

    # --- Step 3: Default = atomic ---
    return RoutingResult(
        routing_type="atomic",
        confidence=0.9,
        reason="single_action_single_object",
    )
|
||||
|
||||
|
||||
def _detect_framework(
    obligation_text: str, object_raw: str,
) -> RoutingResult:
    """Decide whether an obligation refers to a framework (domain).

    Three strategies are tried in order on ``obligation_text`` + ``object_raw``:

    1. A well-known framework name appears literally. If it resolves to a
       registry entry the domain is matched too; if not, the obligation is
       still routed as ``framework_container`` but without a reference.
    2. A "<practices> for <X>" phrase whose reference text resolves to a
       registry framework/domain.
    3. ``object_raw`` alone contains framework keywords.

    Returns:
        A ``framework_container`` result for the first strategy that hits,
        otherwise an ``atomic`` result with confidence 0.0.
    """
    haystack = f"{obligation_text} {object_raw}"
    registry = get_registry()
    aliases = _build_alias_index(registry)

    # Strategy 1: direct framework name match
    direct = _DIRECT_FRAMEWORK_RE.search(haystack)
    if direct is not None:
        matched_name = direct.group(0).strip()
        framework_id = _resolve_framework_id(matched_name, aliases, registry)
        if not framework_id:
            # Framework name recognized but not in registry
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=None,
                framework_domain=None,
                confidence=0.6,
                reason=f"direct_framework_match_no_registry:{matched_name}",
            )
        dom_id, dom_title = _match_domain(haystack, registry[framework_id])
        return RoutingResult(
            routing_type="framework_container",
            framework_ref=framework_id,
            framework_domain=dom_id,
            domain_title=dom_title,
            confidence=0.95 if dom_id else 0.75,
            reason=f"direct_framework_match:{matched_name}",
        )

    # Strategy 2: pattern match ("Praktiken fuer X")
    phrase = _FRAMEWORK_PATTERN.search(haystack)
    if phrase is not None:
        reference = phrase.group(1).strip()
        framework_id, dom_id, dom_title = _resolve_from_ref_text(
            reference, registry, aliases,
        )
        if framework_id:
            return RoutingResult(
                routing_type="framework_container",
                framework_ref=framework_id,
                framework_domain=dom_id,
                domain_title=dom_title,
                confidence=0.85 if dom_id else 0.65,
                reason=f"pattern_match:{reference}",
            )

    # Strategy 3: keyword-heavy object
    if _has_framework_keywords(object_raw):
        return RoutingResult(
            routing_type="framework_container",
            framework_ref=None,
            framework_domain=None,
            confidence=0.5,
            reason="framework_keywords_in_object",
        )

    return RoutingResult(routing_type="atomic", confidence=0.0)
|
||||
|
||||
|
||||
def _resolve_framework_id(
|
||||
name: str,
|
||||
alias_idx: dict[str, str],
|
||||
registry: dict[str, dict],
|
||||
) -> Optional[str]:
|
||||
"""Resolve a framework name to its registry ID."""
|
||||
normalized = re.sub(r"\s+", " ", name.strip().lower())
|
||||
# Direct alias match
|
||||
if normalized in alias_idx:
|
||||
return alias_idx[normalized]
|
||||
# Try compact form (strip spaces, hyphens, underscores)
|
||||
compact = re.sub(r"[\s_\-]+", "", normalized)
|
||||
for alias, fw_id in alias_idx.items():
|
||||
if re.sub(r"[\s_\-]+", "", alias) == compact:
|
||||
return fw_id
|
||||
# Substring match in display names
|
||||
for fw_id, fw in registry.items():
|
||||
display = fw.get("display_name", "").lower()
|
||||
if normalized in display or display in normalized:
|
||||
return fw_id
|
||||
# Partial match: check if normalized contains any alias (for multi-word refs)
|
||||
for alias, fw_id in alias_idx.items():
|
||||
if len(alias) >= 4 and alias in normalized:
|
||||
return fw_id
|
||||
return None
|
||||
|
||||
|
||||
def _match_domain(
|
||||
text: str, framework: dict,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""Match a domain within a framework from text references."""
|
||||
text_lower = text.lower()
|
||||
best_id: Optional[str] = None
|
||||
best_title: Optional[str] = None
|
||||
best_score = 0
|
||||
|
||||
for domain in framework.get("domains", []):
|
||||
score = 0
|
||||
domain_id = domain["domain_id"]
|
||||
title = domain.get("title", "")
|
||||
|
||||
# Exact domain ID match (e.g. "AIS")
|
||||
if re.search(rf"\b{re.escape(domain_id)}\b", text, re.IGNORECASE):
|
||||
score += 10
|
||||
|
||||
# Full title match
|
||||
if title.lower() in text_lower:
|
||||
score += 8
|
||||
|
||||
# Alias match
|
||||
for alias in domain.get("aliases", []):
|
||||
if alias.lower() in text_lower:
|
||||
score += 6
|
||||
break
|
||||
|
||||
# Keyword overlap
|
||||
kw_hits = sum(
|
||||
1 for kw in domain.get("keywords", [])
|
||||
if kw.lower() in text_lower
|
||||
)
|
||||
score += kw_hits
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_id = domain_id
|
||||
best_title = title
|
||||
|
||||
if best_score >= 3:
|
||||
return best_id, best_title
|
||||
return None, None
|
||||
|
||||
|
||||
def _resolve_from_ref_text(
|
||||
ref_text: str,
|
||||
registry: dict[str, dict],
|
||||
alias_idx: dict[str, str],
|
||||
) -> tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
"""Resolve framework + domain from a reference text like 'AIS' or 'Application Security'."""
|
||||
ref_lower = ref_text.lower()
|
||||
|
||||
for fw_id, fw in registry.items():
|
||||
for domain in fw.get("domains", []):
|
||||
# Check domain ID
|
||||
if domain["domain_id"].lower() in ref_lower:
|
||||
return fw_id, domain["domain_id"], domain.get("title")
|
||||
# Check title
|
||||
if domain.get("title", "").lower() in ref_lower:
|
||||
return fw_id, domain["domain_id"], domain.get("title")
|
||||
# Check aliases
|
||||
for alias in domain.get("aliases", []):
|
||||
if alias.lower() in ref_lower or ref_lower in alias.lower():
|
||||
return fw_id, domain["domain_id"], domain.get("title")
|
||||
|
||||
return None, None, None
|
||||
|
||||
|
||||
_FRAMEWORK_KW_SET = {
|
||||
"praktiken", "kontrollen", "massnahmen", "maßnahmen",
|
||||
"anforderungen", "vorgaben", "framework", "standard",
|
||||
"baseline", "katalog", "domain", "family", "category",
|
||||
"practices", "controls", "measures", "requirements",
|
||||
}
|
||||
|
||||
|
||||
def _has_framework_keywords(text: str) -> bool:
|
||||
"""Check if text contains framework-indicator keywords."""
|
||||
words = set(re.findall(r"[a-zäöüß]+", text.lower()))
|
||||
return len(words & _FRAMEWORK_KW_SET) >= 2
|
||||
|
||||
|
||||
def _is_compound_obligation(action_raw: str, obligation_text: str) -> bool:
    """Tell whether an action string bundles several competing main verbs.

    The check runs on the lower-cased, stripped ``action_raw`` only;
    ``obligation_text`` is accepted but not inspected.

    Returns False when the action is empty, when it contains any entry of
    ``_NO_SPLIT_PHRASES``, or when ``_COMPOUND_VERB_RE`` finds no match.
    Otherwise the action is split on the conjunctions und / sowie /
    als auch / or / and, and the result is True if at least two fragments
    are three or more characters long after stripping.
    """
    if not action_raw:
        return False

    lowered = action_raw.lower().strip()

    # Fixed expressions that must stay together veto the split.
    if any(phrase in lowered for phrase in _NO_SPLIT_PHRASES):
        return False

    # Without a conjunction there is nothing to split.
    if _COMPOUND_VERB_RE.search(lowered) is None:
        return False

    # Count the fragments that are long enough to be a verb of their own.
    fragments = re.split(r"\b(?:und|sowie|als\s+auch|or|and)\b", lowered)
    substantial = 0
    for fragment in fragments:
        if len(fragment.strip()) >= 3:
            substantial += 1
    return substantial >= 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Framework Decomposition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class DecomposedObligation:
    """A concrete obligation derived from a framework container.

    The per-field notes below describe how ``decompose_framework_container``
    fills the record; other producers may populate it differently.
    """
    # Source candidate ID suffixed with "-<subcontrol_id>".
    obligation_candidate_id: str
    # Control the source obligation belongs to (passed through unchanged).
    parent_control_id: str
    # ID of the container this record was split out of ("FWC-" + 8 hex chars).
    parent_framework_container_id: str
    # Framework display name.
    source_ref_law: str
    # Subcontrol ID within the framework.
    source_ref_article: str
    # The subcontrol's "statement" text.
    obligation_text: str
    actor: str
    # Subcontrol action_hint, or a verb inferred from the statement.
    action_raw: str
    # Subcontrol object_hint, or an object inferred from the statement.
    object_raw: str
    # Not set by decompose_framework_container; remain None there.
    condition_raw: Optional[str] = None
    trigger_raw: Optional[str] = None
    routing_type: str = "atomic"
    release_state: str = "decomposed"
    subcontrol_id: str = ""
    # Metadata copied from the registry subcontrol entry
    action_hint: str = ""
    object_hint: str = ""
    object_class: str = ""
    keywords: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class FrameworkDecompositionResult:
    """Result of framework decomposition.

    Returned by ``decompose_framework_container`` both on success and when
    no framework / subcontrol could be matched.
    """
    # Generated container ID ("FWC-" + 8 hex chars).
    framework_container_id: str
    # ID of the obligation candidate that was decomposed.
    source_obligation_candidate_id: str
    # Registry ID of the resolved framework; None if it could not be resolved.
    framework_ref: Optional[str]
    # Domain ID inside the framework, if one was given or matched.
    framework_domain: Optional[str]
    domain_title: Optional[str]
    # Subcontrol IDs, in the same order as decomposed_obligations.
    matched_subcontrols: list[str]
    # Score in [0.0, 1.0], rounded to two decimals.
    decomposition_confidence: float
    release_state: str  # decomposed | unmatched | error
    decomposed_obligations: list[DecomposedObligation]
    # Human-readable notes prefixed with "ERROR:" or "WARN:".
    issues: list[str]
|
||||
|
||||
|
||||
def decompose_framework_container(
    obligation_candidate_id: str,
    parent_control_id: str,
    obligation_text: str,
    framework_ref: Optional[str],
    framework_domain: Optional[str],
    actor: str = "organization",
) -> FrameworkDecompositionResult:
    """Decompose a framework-container obligation into concrete sub-obligations.

    Steps:
      1. Resolve framework from registry
      2. Resolve domain within framework
      3. Select relevant subcontrols (keyword filter or full domain)
      4. Generate decomposed obligations

    Args:
        obligation_candidate_id: ID of the container obligation; used as the
            prefix of every generated obligation ID.
        parent_control_id: Control the container obligation belongs to.
        obligation_text: Text of the container obligation.
        framework_ref: Registry ID of the framework, if already known.
            Otherwise the framework is searched for in ``obligation_text``.
        framework_domain: Domain ID inside the framework, if already known.
            Otherwise the domain is matched from ``obligation_text``.
        actor: Actor assigned to every generated obligation.

    Returns:
        A ``FrameworkDecompositionResult``. ``release_state`` is
        ``"unmatched"`` (with no obligations) when the framework cannot be
        resolved, or when no domain matches and no subcontrol is selected
        across all domains; otherwise ``"decomposed"``.

    The registry is treated as read-only: subcontrol entries are copied
    before being annotated, and keyword lists are copied into the results.
    """
    container_id = f"FWC-{uuid.uuid4().hex[:8]}"
    registry = get_registry()
    issues: list[str] = []

    # Step 1: Resolve framework
    fw = None
    if framework_ref and framework_ref in registry:
        fw = registry[framework_ref]
    else:
        # Try to find by name in text
        fw, framework_ref = _find_framework_in_text(obligation_text, registry)

    if not fw:
        issues.append("ERROR: framework_not_matched")
        return FrameworkDecompositionResult(
            framework_container_id=container_id,
            source_obligation_candidate_id=obligation_candidate_id,
            framework_ref=framework_ref,
            framework_domain=framework_domain,
            domain_title=None,
            matched_subcontrols=[],
            decomposition_confidence=0.0,
            release_state="unmatched",
            decomposed_obligations=[],
            issues=issues,
        )

    # Step 2: Resolve domain
    domain_data = None
    domain_title = None
    if framework_domain:
        for d in fw.get("domains", []):
            if d["domain_id"].lower() == framework_domain.lower():
                domain_data = d
                domain_title = d.get("title")
                break
    if not domain_data:
        # Try matching from text
        domain_id, domain_title = _match_domain(obligation_text, fw)
        if domain_id:
            for d in fw.get("domains", []):
                if d["domain_id"] == domain_id:
                    domain_data = d
                    framework_domain = domain_id
                    break

    if not domain_data:
        issues.append("WARN: domain_not_matched — using all domains")
        # Fall back to all subcontrols across all domains. Each entry is
        # copied before it is tagged with its domain so that the shared
        # registry dicts are never mutated.
        all_subcontrols = [
            {**sc, "_domain_id": d["domain_id"]}
            for d in fw.get("domains", [])
            for sc in d.get("subcontrols", [])
        ]
        subcontrols = _select_subcontrols(obligation_text, all_subcontrols)
        if not subcontrols:
            issues.append("ERROR: no_subcontrols_matched")
            return FrameworkDecompositionResult(
                framework_container_id=container_id,
                source_obligation_candidate_id=obligation_candidate_id,
                framework_ref=framework_ref,
                framework_domain=framework_domain,
                domain_title=None,
                matched_subcontrols=[],
                decomposition_confidence=0.0,
                release_state="unmatched",
                decomposed_obligations=[],
                issues=issues,
            )
    else:
        # Step 3: Select subcontrols from domain
        raw_subcontrols = domain_data.get("subcontrols", [])
        subcontrols = _select_subcontrols(obligation_text, raw_subcontrols)
        if not subcontrols:
            # Full domain decomposition
            subcontrols = raw_subcontrols

    # Quality check: too many subcontrols
    if len(subcontrols) > 25:
        issues.append(f"WARN: {len(subcontrols)} subcontrols — may be too broad")

    # Step 4: Generate decomposed obligations
    display_name = fw.get("display_name", framework_ref or "Unknown")
    decomposed: list[DecomposedObligation] = []
    matched_ids: list[str] = []

    for sc in subcontrols:
        sc_id = sc.get("subcontrol_id", "")
        matched_ids.append(sc_id)

        action_hint = sc.get("action_hint", "")
        object_hint = sc.get("object_hint", "")
        statement = sc.get("statement", "")

        # Quality warnings
        if not action_hint:
            issues.append(f"WARN: {sc_id} missing action_hint")
        if not object_hint:
            issues.append(f"WARN: {sc_id} missing object_hint")

        obl_id = f"{obligation_candidate_id}-{sc_id}"

        decomposed.append(DecomposedObligation(
            obligation_candidate_id=obl_id,
            parent_control_id=parent_control_id,
            parent_framework_container_id=container_id,
            source_ref_law=display_name,
            source_ref_article=sc_id,
            obligation_text=statement,
            actor=actor,
            action_raw=action_hint or _infer_action(statement),
            object_raw=object_hint or _infer_object(statement),
            routing_type="atomic",
            release_state="decomposed",
            subcontrol_id=sc_id,
            action_hint=action_hint,
            object_hint=object_hint,
            object_class=sc.get("object_class", ""),
            # Copy so later edits to the obligation cannot leak into the registry.
            keywords=list(sc.get("keywords", [])),
        ))

    # Check if decomposed are identical to container
    container_text = obligation_text.strip()
    for d in decomposed:
        if d.obligation_text.strip() == container_text:
            issues.append(f"WARN: {d.subcontrol_id} identical to container text")

    confidence = _compute_decomposition_confidence(
        framework_ref, framework_domain, domain_data, len(subcontrols), issues,
    )

    return FrameworkDecompositionResult(
        framework_container_id=container_id,
        source_obligation_candidate_id=obligation_candidate_id,
        framework_ref=framework_ref,
        framework_domain=framework_domain,
        domain_title=domain_title,
        matched_subcontrols=matched_ids,
        decomposition_confidence=confidence,
        release_state="decomposed",
        decomposed_obligations=decomposed,
        issues=issues,
    )
|
||||
|
||||
|
||||
def _find_framework_in_text(
    text: str, registry: dict[str, dict],
) -> tuple[Optional[dict], Optional[str]]:
    """Locate a known framework mentioned in *text*.

    The first match of ``_DIRECT_FRAMEWORK_RE`` is resolved to a framework ID
    through ``_resolve_framework_id``.

    Returns:
        ``(framework_definition, framework_id)`` when the resolved ID exists
        in *registry*, otherwise ``(None, None)``.
    """
    alias_idx = _build_alias_index(registry)

    mention = _DIRECT_FRAMEWORK_RE.search(text)
    if not mention:
        return None, None

    resolved_id = _resolve_framework_id(mention.group(0), alias_idx, registry)
    if not resolved_id or resolved_id not in registry:
        return None, None

    return registry[resolved_id], resolved_id
|
||||
|
||||
|
||||
def _select_subcontrols(
|
||||
obligation_text: str, subcontrols: list[dict],
|
||||
) -> list[dict]:
|
||||
"""Select relevant subcontrols based on keyword matching.
|
||||
|
||||
Returns empty list if no targeted match found (caller falls back to
|
||||
full domain).
|
||||
"""
|
||||
text_lower = obligation_text.lower()
|
||||
scored: list[tuple[int, dict]] = []
|
||||
|
||||
for sc in subcontrols:
|
||||
score = 0
|
||||
for kw in sc.get("keywords", []):
|
||||
if kw.lower() in text_lower:
|
||||
score += 1
|
||||
# Title match
|
||||
title = sc.get("title", "").lower()
|
||||
if title and title in text_lower:
|
||||
score += 3
|
||||
# Object hint in text
|
||||
obj = sc.get("object_hint", "").lower()
|
||||
if obj and obj in text_lower:
|
||||
score += 2
|
||||
|
||||
if score > 0:
|
||||
scored.append((score, sc))
|
||||
|
||||
if not scored:
|
||||
return []
|
||||
|
||||
# Only return those with meaningful overlap (score >= 2)
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
return [sc for score, sc in scored if score >= 2]
|
||||
|
||||
|
||||
def _infer_action(statement: str) -> str:
|
||||
"""Infer a basic action verb from a statement."""
|
||||
s = statement.lower()
|
||||
if any(w in s for w in ["definiert", "definieren", "define"]):
|
||||
return "definieren"
|
||||
if any(w in s for w in ["implementiert", "implementieren", "implement"]):
|
||||
return "implementieren"
|
||||
if any(w in s for w in ["dokumentiert", "dokumentieren", "document"]):
|
||||
return "dokumentieren"
|
||||
if any(w in s for w in ["ueberwacht", "ueberwachen", "monitor"]):
|
||||
return "ueberwachen"
|
||||
if any(w in s for w in ["getestet", "testen", "test"]):
|
||||
return "testen"
|
||||
if any(w in s for w in ["geschuetzt", "schuetzen", "protect"]):
|
||||
return "implementieren"
|
||||
if any(w in s for w in ["verwaltet", "verwalten", "manage"]):
|
||||
return "pflegen"
|
||||
if any(w in s for w in ["gemeldet", "melden", "report"]):
|
||||
return "melden"
|
||||
return "implementieren"
|
||||
|
||||
|
||||
def _infer_object(statement: str) -> str:
|
||||
"""Infer the primary object from a statement (first noun phrase)."""
|
||||
# Simple heuristic: take the text after "muessen"/"muss" up to the verb
|
||||
m = re.search(
|
||||
r"(?:muessen|muss|m(?:ü|ue)ssen)\s+(.+?)(?:\s+werden|\s+sein|\.|,|$)",
|
||||
statement,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if m:
|
||||
return m.group(1).strip()[:80]
|
||||
# Fallback: first 80 chars
|
||||
return statement[:80] if statement else ""
|
||||
|
||||
|
||||
def _compute_decomposition_confidence(
|
||||
framework_ref: Optional[str],
|
||||
domain: Optional[str],
|
||||
domain_data: Optional[dict],
|
||||
num_subcontrols: int,
|
||||
issues: list[str],
|
||||
) -> float:
|
||||
"""Compute confidence score for the decomposition."""
|
||||
score = 0.3
|
||||
if framework_ref:
|
||||
score += 0.25
|
||||
if domain:
|
||||
score += 0.20
|
||||
if domain_data:
|
||||
score += 0.10
|
||||
if 1 <= num_subcontrols <= 15:
|
||||
score += 0.10
|
||||
elif num_subcontrols > 15:
|
||||
score += 0.05 # less confident with too many
|
||||
|
||||
# Penalize errors
|
||||
errors = sum(1 for i in issues if i.startswith("ERROR:"))
|
||||
score -= errors * 0.15
|
||||
return round(max(min(score, 1.0), 0.0), 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry statistics (for admin/debugging)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def registry_stats() -> dict:
    """Summarise the loaded framework registry.

    Returns:
        A dict with the number of frameworks (``frameworks``), one entry per
        framework in ``details`` (ID, display name, domain count, subcontrol
        count) and the overall ``total_domains`` / ``total_subcontrols``.
    """
    registry = get_registry()

    per_framework = []
    domain_total = 0
    subcontrol_total = 0

    for framework_id, framework in registry.items():
        domain_list = framework.get("domains", [])
        subcontrol_count = 0
        for domain in domain_list:
            subcontrol_count += len(domain.get("subcontrols", []))

        domain_total += len(domain_list)
        subcontrol_total += subcontrol_count
        per_framework.append({
            "framework_id": framework_id,
            "display_name": framework.get("display_name", ""),
            "domains": len(domain_list),
            "subcontrols": subcontrol_count,
        })

    return {
        "frameworks": len(registry),
        "details": per_framework,
        "total_domains": domain_total,
        "total_subcontrols": subcontrol_total,
    }
|
||||
116
control-pipeline/services/license_gate.py
Normal file
116
control-pipeline/services/license_gate.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
License Gate — checks whether a given source may be used for a specific purpose.
|
||||
|
||||
Usage types:
|
||||
- analysis: Read + analyse internally (TDM under UrhG 44b)
|
||||
- store_excerpt: Store verbatim excerpt in vault
|
||||
- ship_embeddings: Ship embeddings in product
|
||||
- ship_in_product: Ship text/content in product
|
||||
|
||||
Policy is driven by the canonical_control_sources table columns:
|
||||
allowed_analysis, allowed_store_excerpt, allowed_ship_embeddings, allowed_ship_in_product
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Usage type -> boolean permission column in canonical_control_sources.
USAGE_COLUMN_MAP = {
    "analysis": "allowed_analysis",
    "store_excerpt": "allowed_store_excerpt",
    "ship_embeddings": "allowed_ship_embeddings",
    "ship_in_product": "allowed_ship_in_product",
}


def check_source_allowed(db: Session, source_id: str, usage_type: str) -> bool:
    """Tell whether *source_id* may be used for *usage_type*.

    Args:
        db: Database session used for the lookup.
        source_id: Value matched against ``canonical_control_sources.source_id``.
        usage_type: One of the keys of ``USAGE_COLUMN_MAP``.

    Returns:
        The truthiness of the matching permission column. False (with a
        logged warning) when the usage type is unknown or the source does
        not exist.
    """
    column = USAGE_COLUMN_MAP.get(usage_type)
    if column is None:
        logger.warning("Unknown usage_type=%s", usage_type)
        return False

    # The column name comes from the fixed map above, never from the caller,
    # so interpolating it is safe; the source ID is bound as a parameter.
    query = text(f"SELECT {column} FROM canonical_control_sources WHERE source_id = :sid")
    result = db.execute(query, {"sid": source_id}).fetchone()

    if result is None:
        logger.warning("Source %s not found in registry", source_id)
        return False

    return bool(result[0])
|
||||
|
||||
|
||||
def get_license_matrix(db: Session) -> list[dict[str, Any]]:
    """Return every row of ``canonical_control_licenses`` as a dict.

    Rows are ordered by ``license_id``; each dict carries the selected
    columns under their column names.
    """
    columns = (
        "license_id",
        "name",
        "terms_url",
        "commercial_use",
        "ai_training_restriction",
        "tdm_allowed_under_44b",
        "deletion_required",
        "notes",
    )

    query = text("""
        SELECT license_id, name, terms_url, commercial_use,
               ai_training_restriction, tdm_allowed_under_44b,
               deletion_required, notes
        FROM canonical_control_licenses
        ORDER BY license_id
    """)
    records = db.execute(query).fetchall()

    matrix = []
    for record in records:
        matrix.append({column: getattr(record, column) for column in columns})
    return matrix
|
||||
|
||||
|
||||
def get_source_permissions(db: Session) -> list[dict[str, Any]]:
    """Return every source with its permission flags and license info.

    Sources are joined to their license (inner join, so a source without a
    matching license row is not listed) and ordered by ``source_id``.
    """
    # Output keys in the order they appear in each result dict.
    columns = (
        "source_id",
        "title",
        "publisher",
        "url",
        "version_label",
        "language",
        "license_id",
        "license_name",
        "commercial_use",
        "allowed_analysis",
        "allowed_store_excerpt",
        "allowed_ship_embeddings",
        "allowed_ship_in_product",
        "vault_retention_days",
        "vault_access_tier",
    )

    query = text("""
        SELECT s.source_id, s.title, s.publisher, s.url, s.version_label,
               s.language, s.license_id,
               s.allowed_analysis, s.allowed_store_excerpt,
               s.allowed_ship_embeddings, s.allowed_ship_in_product,
               s.vault_retention_days, s.vault_access_tier,
               l.name AS license_name, l.commercial_use
        FROM canonical_control_sources s
        JOIN canonical_control_licenses l ON l.license_id = s.license_id
        ORDER BY s.source_id
    """)
    records = db.execute(query).fetchall()

    permissions = []
    for record in records:
        permissions.append({column: getattr(record, column) for column in columns})
    return permissions
|
||||
624
control-pipeline/services/llm_provider.py
Normal file
624
control-pipeline/services/llm_provider.py
Normal file
@@ -0,0 +1,624 @@
|
||||
"""
|
||||
LLM Provider Abstraction for Compliance AI Features.
|
||||
|
||||
Supports:
|
||||
- Anthropic Claude API (default)
|
||||
- Self-Hosted LLMs (Ollama, vLLM, LocalAI, etc.)
|
||||
- HashiCorp Vault integration for secure API key storage
|
||||
|
||||
Configuration via environment variables:
|
||||
- COMPLIANCE_LLM_PROVIDER: "anthropic" or "self_hosted"
|
||||
- ANTHROPIC_API_KEY: API key for Claude (or loaded from Vault)
|
||||
- ANTHROPIC_MODEL: Model name (default: claude-sonnet-4-20250514)
|
||||
- SELF_HOSTED_LLM_URL: Base URL for self-hosted LLM
|
||||
- SELF_HOSTED_LLM_MODEL: Model name for self-hosted
|
||||
- SELF_HOSTED_LLM_KEY: Optional API key for self-hosted
|
||||
|
||||
Vault Configuration:
|
||||
- VAULT_ADDR: Vault server address (e.g., http://vault:8200)
|
||||
- VAULT_TOKEN: Vault authentication token
|
||||
- USE_VAULT_SECRETS: Set to "true" to enable Vault integration
|
||||
- VAULT_SECRET_PATH: Path to secrets (default: secret/breakpilot/api_keys)
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Vault Integration
|
||||
# =============================================================================
|
||||
|
||||
class VaultClient:
    """
    HashiCorp Vault client for retrieving secrets.

    Supports KV v2 secrets engine. Secrets that were fetched successfully are
    kept in an in-process cache for ``_cache_ttl`` seconds, after which the
    next lookup goes back to Vault (so rotated secrets are picked up).
    """

    def __init__(
        self,
        addr: Optional[str] = None,
        token: Optional[str] = None
    ):
        """Create a client.

        Args:
            addr: Vault base URL. Falls back to ``VAULT_ADDR``, then to
                ``http://localhost:8200``.
            token: Vault token. Falls back to ``VAULT_TOKEN``.
        """
        self.addr = addr or os.getenv("VAULT_ADDR", "http://localhost:8200")
        self.token = token or os.getenv("VAULT_TOKEN")
        self._cache: Dict[str, Any] = {}
        # Monotonic-clock deadline per cache key. A key present in _cache
        # without a deadline is treated as non-expiring.
        self._cache_expiry: Dict[str, float] = {}
        self._cache_ttl = 300  # 5 minutes cache

    @staticmethod
    def _now() -> float:
        """Return the current monotonic time in seconds."""
        # Local import keeps this class independent of the module's import block.
        import time

        return time.monotonic()

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers with Vault token."""
        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["X-Vault-Token"] = self.token
        return headers

    def _get_cached(self, cache_key: str) -> Optional[Any]:
        """Return the cached value for *cache_key*, or None if absent/expired.

        Expired entries are evicted as a side effect.
        """
        if cache_key not in self._cache:
            return None
        deadline = self._cache_expiry.get(cache_key)
        if deadline is not None and self._now() >= deadline:
            self._cache.pop(cache_key, None)
            self._cache_expiry.pop(cache_key, None)
            return None
        return self._cache[cache_key]

    def get_secret(self, path: str, key: str = "value") -> Optional[str]:
        """
        Get a secret from Vault KV v2.

        Args:
            path: Secret path (e.g., "breakpilot/api_keys/anthropic")
            key: Key within the secret data (default: "value")

        Returns:
            Secret value or None if not found. Connection problems and
            unexpected errors are logged and also result in None.
        """
        cache_key = f"{path}:{key}"

        # Check cache first (honouring the TTL)
        cached = self._get_cached(cache_key)
        if cached is not None:
            return cached

        try:
            # KV v2 uses /data/ in the path
            full_path = f"{self.addr}/v1/secret/data/{path}"

            response = httpx.get(
                full_path,
                headers=self._get_headers(),
                timeout=10.0
            )

            if response.status_code == 200:
                data = response.json()
                # KV v2 nests the payload as {"data": {"data": {...}}}.
                secret_data = data.get("data", {}).get("data", {})
                secret_value = secret_data.get(key)

                if secret_value:
                    self._cache[cache_key] = secret_value
                    self._cache_expiry[cache_key] = self._now() + self._cache_ttl
                    logger.info(f"Successfully loaded secret from Vault: {path}")
                    return secret_value

            elif response.status_code == 404:
                logger.warning(f"Secret not found in Vault: {path}")
            else:
                logger.error(f"Vault error {response.status_code}: {response.text}")

        except httpx.RequestError as e:
            logger.error(f"Failed to connect to Vault at {self.addr}: {e}")
        except Exception as e:
            # Best effort: callers fall back to other secret sources on None.
            logger.error(f"Error retrieving secret from Vault: {e}")

        return None

    def get_anthropic_key(self) -> Optional[str]:
        """Get Anthropic API key from Vault.

        The secret path is read from ``VAULT_ANTHROPIC_PATH`` (default
        ``breakpilot/api_keys/anthropic``); the key is ``value``.
        """
        path = os.getenv("VAULT_ANTHROPIC_PATH", "breakpilot/api_keys/anthropic")
        return self.get_secret(path, "value")

    def is_available(self) -> bool:
        """Check whether the Vault server answers on its health endpoint.

        Returns True when ``/v1/sys/health`` replies with one of the status
        codes 200, 429, 472, 473, 501 or 503, and False on any other status
        or on any exception. The token is sent along but this call does not
        verify it.
        NOTE(review): 501/503 presumably mean uninitialised/sealed, i.e.
        reachable but unable to serve secrets — confirm against the Vault docs.
        """
        try:
            response = httpx.get(
                f"{self.addr}/v1/sys/health",
                headers=self._get_headers(),
                timeout=5.0
            )
            return response.status_code in (200, 429, 472, 473, 501, 503)
        except Exception:
            return False
|
||||
|
||||
|
||||
# Process-wide Vault client, created on first use by get_vault_client().
_vault_client: Optional[VaultClient] = None


def get_vault_client() -> VaultClient:
    """Return the shared ``VaultClient``, creating it on the first call."""
    global _vault_client
    if _vault_client is not None:
        return _vault_client
    _vault_client = VaultClient()
    return _vault_client
|
||||
|
||||
|
||||
def get_secret_from_vault_or_env(
    vault_path: str,
    env_var: str,
    vault_key: str = "value"
) -> Optional[str]:
    """
    Look a secret up in Vault and fall back to an environment variable.

    Vault is only consulted when ``USE_VAULT_SECRETS`` is set to "true", "1"
    or "yes" (case-insensitive). If Vault is disabled or yields no value, the
    environment variable is returned instead.

    Args:
        vault_path: Path in Vault (e.g., "breakpilot/api_keys/anthropic")
        env_var: Environment variable name as fallback
        vault_key: Key within Vault secret data

    Returns:
        Secret value or None
    """
    vault_flag = os.getenv("USE_VAULT_SECRETS", "").lower()
    vault_enabled = vault_flag in ("true", "1", "yes")

    if not vault_enabled:
        return os.getenv(env_var)

    from_vault = get_vault_client().get_secret(vault_path, vault_key)
    if from_vault:
        return from_vault

    logger.info(f"Vault secret not found, falling back to env: {env_var}")
    return os.getenv(env_var)
|
||||
|
||||
|
||||
class LLMProviderType(str, Enum):
    """Supported LLM provider types.

    Members are also ``str`` instances, so they compare equal to their
    plain string values (e.g. ``LLMProviderType.ANTHROPIC == "anthropic"``).
    """
    ANTHROPIC = "anthropic"
    SELF_HOSTED = "self_hosted"
    OLLAMA = "ollama"  # Alias for self_hosted (Ollama-specific)
    MOCK = "mock"  # For testing
|
||||
|
||||
|
||||
@dataclass
class LLMResponse:
    """Standard response from LLM."""
    # Generated text. Batch helpers store "Error: <message>" here when a
    # single prompt fails.
    content: str
    # Model name the provider was configured with.
    model: str
    # Name of the provider that produced the response (e.g. "anthropic").
    provider: str
    # Usage block as reported by the API, if any.
    usage: Optional[Dict[str, int]] = None
    # Full decoded JSON body of the API response, if available.
    raw_response: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """Configuration for LLM provider."""
    provider_type: LLMProviderType
    # Required by AnthropicProvider; optional bearer token for SelfHostedProvider.
    api_key: Optional[str] = None
    model: str = "claude-sonnet-4-20250514"
    # Required by SelfHostedProvider.
    base_url: Optional[str] = None
    # Default completion length when a call does not pass max_tokens.
    max_tokens: int = 4096
    # Default sampling temperature when a call does not pass one.
    temperature: float = 0.3
    # Passed to httpx.AsyncClient as the request timeout.
    timeout: float = 60.0
|
||||
|
||||
|
||||
class LLMProvider(ABC):
    """Common interface every LLM backend has to implement."""

    def __init__(self, config: LLMConfig):
        # Kept on the instance so subclasses can read defaults from it.
        self.config = config

    @abstractmethod
    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Produce one completion for *prompt*.

        ``max_tokens`` and ``temperature`` override the configured defaults
        for this call when given.
        """

    @abstractmethod
    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Produce one completion per prompt, throttled by *rate_limit*."""

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Short identifier of the backend."""
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
    """Claude API Provider using Anthropic's official API."""

    ANTHROPIC_API_URL = "https://api.anthropic.com/v1/messages"

    def __init__(self, config: LLMConfig):
        """Create the provider.

        Raises:
            ValueError: If ``config.api_key`` is missing.
        """
        super().__init__(config)
        if not config.api_key:
            raise ValueError("Anthropic API key is required")
        self.api_key = config.api_key
        self.model = config.model or "claude-sonnet-4-20250514"

    @property
    def provider_name(self) -> str:
        return "anthropic"

    @staticmethod
    def _extract_text(data: Dict[str, Any]) -> str:
        """Concatenate the text of all text blocks in a Messages API response.

        ``content`` is a list of blocks. Blocks whose ``type`` is "text" (or
        that carry no ``type`` at all) contribute their ``text``; any other
        block type is skipped. Returns "" when there is no text.
        """
        blocks = data.get("content") or []
        return "".join(
            block.get("text") or ""
            for block in blocks
            if isinstance(block, dict) and block.get("type", "text") == "text"
        )

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Generate completion using Claude API.

        Args:
            prompt: User message.
            system_prompt: Optional system prompt.
            max_tokens: Overrides ``config.max_tokens`` when truthy.
            temperature: Overrides ``config.temperature`` when not None.

        Returns:
            The response with the text of all returned text blocks joined.

        Raises:
            httpx.HTTPStatusError: On a non-2xx API response (logged first).
            Exception: Any other request/decoding failure (logged first).
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

        messages = [{"role": "user", "content": prompt}]

        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens or self.config.max_tokens,
            "messages": messages
        }

        if system_prompt:
            payload["system"] = system_prompt

        # An explicit temperature of 0 must win over the configured default,
        # hence the "is not None" checks instead of truthiness.
        if temperature is not None:
            payload["temperature"] = temperature
        elif self.config.temperature is not None:
            payload["temperature"] = self.config.temperature

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            try:
                response = await client.post(
                    self.ANTHROPIC_API_URL,
                    headers=headers,
                    json=payload
                )
                response.raise_for_status()
                data = response.json()

                return LLMResponse(
                    content=self._extract_text(data),
                    model=self.model,
                    provider=self.provider_name,
                    usage=data.get("usage"),
                    raw_response=data
                )

            except httpx.HTTPStatusError as e:
                logger.error(f"Anthropic API error: {e.response.status_code} - {e.response.text}")
                raise
            except Exception as e:
                logger.error(f"Anthropic API request failed: {e}")
                raise

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 1.0
    ) -> List[LLMResponse]:
        """Process multiple prompts with rate limiting.

        Prompts are sent one after another with ``rate_limit`` seconds of
        sleep between calls. A failing prompt does not abort the batch: its
        slot holds an ``LLMResponse`` whose content is ``"Error: <message>"``,
        so the result always has one entry per prompt, in order.
        """
        results = []

        for i, prompt in enumerate(prompts):
            if i > 0:
                await asyncio.sleep(rate_limit)

            try:
                result = await self.complete(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens
                )
                results.append(result)
            except Exception as e:
                logger.error(f"Failed to process prompt {i}: {e}")
                # Append error response
                results.append(LLMResponse(
                    content=f"Error: {str(e)}",
                    model=self.model,
                    provider=self.provider_name
                ))

        return results
|
||||
|
||||
|
||||
class SelfHostedProvider(LLMProvider):
    """LLM provider for self-hosted servers (Ollama, vLLM, LocalAI, etc.).

    Two wire formats are supported and selected heuristically from the base
    URL (see ``_detect_api_format``): Ollama's ``/api/generate`` endpoint and
    the OpenAI-compatible ``/v1/chat/completions`` endpoint.
    """

    def __init__(self, config: LLMConfig):
        """Store connection settings taken from ``config``.

        Raises:
            ValueError: If ``config.base_url`` is empty.
        """
        super().__init__(config)
        if not config.base_url:
            raise ValueError("Base URL is required for self-hosted provider")
        # Trailing slashes are dropped so endpoint paths can be appended.
        self.base_url = config.base_url.rstrip("/")
        self.model = config.model
        self.api_key = config.api_key

    @property
    def provider_name(self) -> str:
        """Identifier reported in every ``LLMResponse``."""
        return "self_hosted"

    def _detect_api_format(self) -> str:
        """Guess the API format (``"ollama"`` or ``"openai"``) from the URL.

        This is a substring heuristic: port ``11434`` or the word "ollama"
        selects Ollama, "openai" or "v1" selects the OpenAI-compatible
        format, and anything else falls back to Ollama.
        """
        lowered = self.base_url.lower()
        if "11434" in self.base_url or "ollama" in lowered:
            return "ollama"
        if "openai" in lowered or "v1" in self.base_url:
            return "openai"
        return "ollama"  # Default to Ollama format

    def _chat_completions_url(self) -> str:
        """Return the OpenAI-compatible chat endpoint for the base URL.

        A base URL that already ends in ``/v1`` (the usual way such servers
        are configured, and what triggers OpenAI detection) must not get a
        second ``/v1`` segment appended.
        """
        if self.base_url.lower().endswith("/v1"):
            return f"{self.base_url}/chat/completions"
        return f"{self.base_url}/v1/chat/completions"

    def _build_request(
        self,
        api_format: str,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: Optional[int],
        temperature: Optional[float],
    ) -> tuple:
        """Return ``(endpoint, payload)`` for the given API format."""
        if api_format == "ollama":
            # Ollama's generate API takes a single prompt string, so the
            # system prompt is prepended to the user prompt.
            full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
            payload = {
                "model": self.model,
                "prompt": full_prompt,
                "stream": False,
                "think": False,  # Disable thinking mode (qwen3.5 etc.)
                "options": {}
            }
            # NOTE(review): unlike the OpenAI branch, the configured default
            # max_tokens / temperature are not applied here; only explicit
            # arguments are forwarded. Confirm whether that is intended.
            if max_tokens:
                payload["options"]["num_predict"] = max_tokens
            if temperature is not None:
                payload["options"]["temperature"] = temperature
            return f"{self.base_url}/api/generate", payload

        # OpenAI-compatible format (vLLM, LocalAI, etc.)
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        payload = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens or self.config.max_tokens,
            "temperature": temperature if temperature is not None else self.config.temperature
        }
        return self._chat_completions_url(), payload

    @staticmethod
    def _extract_content(api_format: str, data: dict):
        """Pull the generated text out of a decoded response body."""
        if api_format == "ollama":
            return data.get("response", "")
        # Tolerate a missing or empty "choices" list and a null "message"
        # instead of failing with IndexError / AttributeError.
        choices = data.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return message.get("content", "")

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Generate a completion using the self-hosted LLM.

        Args:
            prompt: User prompt.
            system_prompt: Optional system instruction.
            max_tokens: Optional generation limit.
            temperature: Optional sampling temperature.

        Returns:
            An ``LLMResponse`` carrying the generated text, the decoded
            response body and its ``usage`` entry (if the server sent one).

        Raises:
            httpx.HTTPStatusError: On a non-2xx response (logged first).
            Exception: Any other request or decoding failure is logged and
                re-raised unchanged.
        """
        api_format = self._detect_api_format()

        headers = {"content-type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        endpoint, payload = self._build_request(
            api_format, prompt, system_prompt, max_tokens, temperature
        )

        async with httpx.AsyncClient(timeout=self.config.timeout) as client:
            try:
                response = await client.post(endpoint, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()

                return LLMResponse(
                    content=self._extract_content(api_format, data),
                    model=self.model,
                    provider=self.provider_name,
                    usage=data.get("usage"),
                    raw_response=data
                )

            except httpx.HTTPStatusError as e:
                logger.error(f"Self-hosted LLM error: {e.response.status_code} - {e.response.text}")
                raise
            except Exception as e:
                logger.error(f"Self-hosted LLM request failed: {e}")
                raise

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 0.5  # Self-hosted can be faster
    ) -> List[LLMResponse]:
        """Process multiple prompts sequentially with rate limiting.

        Args:
            prompts: Prompts to send, processed in order.
            system_prompt: Optional system prompt forwarded to every call.
            max_tokens: Optional token limit forwarded to every call.
            rate_limit: Seconds to sleep before each request except the first.

        Returns:
            One ``LLMResponse`` per prompt, in input order. A failed prompt
            yields a placeholder whose content is ``"Error: <message>"``.
        """
        results = []

        for i, prompt in enumerate(prompts):
            if i > 0:
                await asyncio.sleep(rate_limit)

            try:
                result = await self.complete(
                    prompt=prompt,
                    system_prompt=system_prompt,
                    max_tokens=max_tokens
                )
                results.append(result)
            except Exception as e:
                logger.error(f"Failed to process prompt {i}: {e}")
                # Keep output aligned with input by inserting a placeholder.
                results.append(LLMResponse(
                    content=f"Error: {str(e)}",
                    model=self.model,
                    provider=self.provider_name
                ))

        return results
|
||||
|
||||
|
||||
class MockProvider(LLMProvider):
    """Provider that fabricates responses locally; it makes no API calls."""

    def __init__(self, config: LLMConfig):
        """Start with no canned responses and a zeroed call counter."""
        super().__init__(config)
        self.call_count = 0
        self.responses: List[str] = []

    @property
    def provider_name(self) -> str:
        """Identifier reported in every ``LLMResponse``."""
        return "mock"

    def set_responses(self, responses: List[str]):
        """Install canned responses and restart the call counter."""
        self.call_count = 0
        self.responses = responses

    async def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> LLMResponse:
        """Return the next canned response, or an echo of the prompt.

        Canned responses are served round-robin by call count. Without any,
        the content echoes the first 50 characters of ``prompt``. The usage
        figures are character counts, not real token counts.
        """
        canned = self.responses
        if not canned:
            content = f"Mock response for: {prompt[:50]}..."
        else:
            content = canned[self.call_count % len(canned)]

        self.call_count += 1

        usage = {"input_tokens": len(prompt), "output_tokens": len(content)}
        return LLMResponse(
            content=content,
            model="mock-model",
            provider=self.provider_name,
            usage=usage
        )

    async def batch_complete(
        self,
        prompts: List[str],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        rate_limit: float = 0.0
    ) -> List[LLMResponse]:
        """Return one mock response per prompt; ``rate_limit`` is ignored."""
        replies = []
        for item in prompts:
            replies.append(await self.complete(item, system_prompt, max_tokens))
        return replies
|
||||
|
||||
|
||||
def get_llm_config() -> LLMConfig:
    """Build the LLM configuration from environment variables or Vault.

    The provider comes from ``COMPLIANCE_LLM_PROVIDER`` (default
    ``"anthropic"``); an unknown value is logged and replaced by the mock
    provider. API keys are resolved through ``get_secret_from_vault_or_env``
    using the Vault path and environment variable of the selected provider.
    Token limit, temperature and timeout come from the
    ``COMPLIANCE_LLM_*`` variables.
    """
    requested = os.getenv("COMPLIANCE_LLM_PROVIDER", "anthropic")

    try:
        provider_type = LLMProviderType(requested)
    except ValueError:
        logger.warning(f"Unknown LLM provider: {requested}, falling back to mock")
        provider_type = LLMProviderType.MOCK

    self_hosted_kinds = (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA)

    # Resolve credentials and model together, per provider family.
    if provider_type == LLMProviderType.ANTHROPIC:
        api_key = get_secret_from_vault_or_env(
            vault_path="breakpilot/api_keys/anthropic",
            env_var="ANTHROPIC_API_KEY"
        )
        model = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
    elif provider_type in self_hosted_kinds:
        api_key = get_secret_from_vault_or_env(
            vault_path="breakpilot/api_keys/self_hosted_llm",
            env_var="SELF_HOSTED_LLM_KEY"
        )
        model = os.getenv("SELF_HOSTED_LLM_MODEL", "qwen2.5:14b")
    else:
        api_key = None
        model = "mock-model"

    return LLMConfig(
        provider_type=provider_type,
        api_key=api_key,
        model=model,
        base_url=os.getenv("SELF_HOSTED_LLM_URL"),
        max_tokens=int(os.getenv("COMPLIANCE_LLM_MAX_TOKENS", "4096")),
        temperature=float(os.getenv("COMPLIANCE_LLM_TEMPERATURE", "0.3")),
        timeout=float(os.getenv("COMPLIANCE_LLM_TIMEOUT", "60.0"))
    )
|
||||
|
||||
|
||||
def get_llm_provider(config: Optional[LLMConfig] = None) -> LLMProvider:
    """Create the provider that matches ``config``.

    When ``config`` is omitted it is built with ``get_llm_config()``. A real
    provider that lacks its required setting (API key for Anthropic, base
    URL for self-hosted) is replaced by ``MockProvider`` after a warning.

    Usage:
        provider = get_llm_provider()
        response = await provider.complete("Analyze this requirement...")

    Raises:
        ValueError: If the provider type is not one of the supported kinds.
    """
    cfg = get_llm_config() if config is None else config
    kind = cfg.provider_type

    if kind == LLMProviderType.ANTHROPIC:
        if cfg.api_key:
            return AnthropicProvider(cfg)
        logger.warning("No Anthropic API key found, using mock provider")
        return MockProvider(cfg)

    if kind in (LLMProviderType.SELF_HOSTED, LLMProviderType.OLLAMA):
        if cfg.base_url:
            return SelfHostedProvider(cfg)
        logger.warning("No self-hosted LLM URL found, using mock provider")
        return MockProvider(cfg)

    if kind == LLMProviderType.MOCK:
        return MockProvider(cfg)

    raise ValueError(f"Unsupported LLM provider type: {cfg.provider_type}")
|
||||
|
||||
|
||||
# Singleton instance for reuse
|
||||
_provider_instance: Optional[LLMProvider] = None
|
||||
|
||||
|
||||
def get_shared_provider() -> LLMProvider:
    """Return the module-wide provider, creating it on first use."""
    global _provider_instance
    if _provider_instance is not None:
        return _provider_instance
    _provider_instance = get_llm_provider()
    return _provider_instance
|
||||
|
||||
|
||||
def reset_shared_provider():
    """Drop the cached shared provider (useful for testing).

    The next ``get_shared_provider()`` call builds a fresh instance.
    """
    global _provider_instance
    _provider_instance = None
|
||||
59
control-pipeline/services/normative_patterns.py
Normal file
59
control-pipeline/services/normative_patterns.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Shared normative language patterns for assertion classification.
|
||||
|
||||
Extracted from decomposition_pass.py for reuse in the assertion engine.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
def _compile_signals(signals):
    """Compile a list of regex fragments into one case-insensitive alternation."""
    return re.compile("|".join(signals), re.IGNORECASE)


# Mandatory wording ("must"-level obligations), German and English.
_PFLICHT_SIGNALS = [
    # Modal verbs and explicit duty phrases
    r"\bmüssen\b",
    r"\bmuss\b",
    r"\bhat\s+sicherzustellen\b",
    r"\bhaben\s+sicherzustellen\b",
    r"\bsind\s+verpflichtet\b",
    r"\bist\s+verpflichtet\b",
    # "ist/sind/hat/haben zu <verb>en" constructions
    r"\bist\s+zu\s+\w+en\b",
    r"\bsind\s+zu\s+\w+en\b",
    r"\bhat\s+zu\s+\w+en\b",
    r"\bhaben\s+zu\s+\w+en\b",
    # Separable verbs with embedded "zu"
    r"\bist\s+\w+zu\w+en\b",
    r"\bsind\s+\w+zu\w+en\b",
    # Same constructions with one word in between
    r"\bist\s+\w+\s+zu\s+\w+en\b",
    r"\bsind\s+\w+\s+zu\s+\w+en\b",
    r"\bhat\s+\w+\s+zu\s+\w+en\b",
    r"\bhaben\s+\w+\s+zu\s+\w+en\b",
    # English
    r"\bshall\b",
    r"\bmust\b",
    r"\brequired\b",
    # Specific "…zu…" infinitives
    r"\b\w+zuteilen\b",
    r"\b\w+zuwenden\b",
    r"\b\w+zustellen\b",
    r"\b\w+zulegen\b",
    r"\b\w+zunehmen\b",
    r"\b\w+zuführen\b",
    r"\b\w+zuhalten\b",
    r"\b\w+zusetzen\b",
    r"\b\w+zuweisen\b",
    r"\b\w+zuordnen\b",
    r"\b\w+zufügen\b",
    r"\b\w+zugeben\b",
    # "ist/sind … zu <verb>en" with up to 80 characters in between
    r"\bist\b.{1,80}\bzu\s+\w+en\b",
    r"\bsind\b.{1,80}\bzu\s+\w+en\b",
]
PFLICHT_RE = _compile_signals(_PFLICHT_SIGNALS)

# Recommendation wording ("should"-level).
_EMPFEHLUNG_SIGNALS = [
    r"\bsoll\b",
    r"\bsollen\b",
    r"\bsollte\b",
    r"\bsollten\b",
    r"\bgewährleisten\b",
    r"\bsicherstellen\b",
    r"\bshould\b",
    r"\bensure\b",
    r"\brecommend\w*\b",
    r"\bnachweisen\b",
    r"\beinhalten\b",
    r"\bunterlassen\b",
    r"\bwahren\b",
    r"\bdokumentieren\b",
    r"\bimplementieren\b",
    r"\büberprüfen\b",
    r"\büberwachen\b",
    r"\bprüfen,\s+ob\b",
    r"\bkontrollieren,\s+ob\b",
]
EMPFEHLUNG_RE = _compile_signals(_EMPFEHLUNG_SIGNALS)

# Permission wording ("may"-level).
_KANN_SIGNALS = [
    r"\bkann\b",
    r"\bkönnen\b",
    r"\bdarf\b",
    r"\bdürfen\b",
    r"\bmay\b",
    r"\boptional\b",
]
KANN_RE = _compile_signals(_KANN_SIGNALS)

# Any normative signal: mandatory, recommended or permitted, in that order.
NORMATIVE_RE = _compile_signals(
    [*_PFLICHT_SIGNALS, *_EMPFEHLUNG_SIGNALS, *_KANN_SIGNALS]
)

# Wording that introduces a justification rather than a requirement.
_RATIONALE_SIGNALS = [
    r"\bda\s+",
    r"\bweil\b",
    r"\bgrund\b",
    r"\berwägung",
    r"\bbecause\b",
    r"\breason\b",
    r"\brationale\b",
    r"\bkönnen\s+.*\s+verursachen\b",
    r"\bführt\s+zu\b",
]
RATIONALE_RE = _compile_signals(_RATIONALE_SIGNALS)

# Evidence-related keywords (for fact detection)
_EVIDENCE_KEYWORDS = [
    r"\bnachweis\b",
    r"\bzertifikat\b",
    r"\baudit.report\b",
    r"\bprotokoll\b",
    r"\bdokumentation\b",
    r"\bbericht\b",
    r"\bcertificate\b",
    r"\bevidence\b",
    r"\bproof\b",
]
EVIDENCE_RE = _compile_signals(_EVIDENCE_KEYWORDS)
|
||||
563
control-pipeline/services/obligation_extractor.py
Normal file
563
control-pipeline/services/obligation_extractor.py
Normal file
@@ -0,0 +1,563 @@
|
||||
"""Obligation Extractor — 3-Tier Chunk-to-Obligation Linking.
|
||||
|
||||
Maps RAG chunks to obligations from the v2 obligation framework using
|
||||
three tiers (fastest first):
|
||||
|
||||
Tier 1: EXACT MATCH — regulation_code + article → obligation_id (~40%)
|
||||
Tier 2: EMBEDDING — chunk text vs. obligation descriptions (~30%)
|
||||
Tier 3: LLM EXTRACT — local Ollama extracts obligation text (~25%)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 4 of 8).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Service endpoints and model settings; each can be overridden through the
# environment variable named in the lookup.
EMBEDDING_URL = os.environ.get("EMBEDDING_URL", "http://embedding-service:8087")
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.environ.get("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
LLM_TIMEOUT = float(os.environ.get("CONTROL_GEN_LLM_TIMEOUT", "180"))

# Embedding similarity thresholds for Tier 2
EMBEDDING_MATCH_THRESHOLD = 0.80
EMBEDDING_CANDIDATE_THRESHOLD = 0.60
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regulation code mapping: RAG chunk codes → obligation file regulation IDs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Every alias (RAG chunk code) of a regulation maps to its obligation-file
# regulation ID. Grouped per regulation; insertion order follows the groups.
_REGULATION_CODE_TO_ID = {
    alias: regulation_id
    for regulation_id, aliases in (
        ("dsgvo", ("eu_2016_679", "dsgvo", "gdpr")),          # DSGVO
        ("ai_act", ("eu_2024_1689", "ai_act", "aiact")),      # AI Act
        ("nis2", ("eu_2022_2555", "nis2", "bsig")),           # NIS2
        ("bdsg", ("bdsg",)),                                  # BDSG
        ("ttdsg", ("ttdsg",)),                                # TTDSG
        ("dsa", ("eu_2022_2065", "dsa")),                     # DSA
        ("data_act", ("eu_2023_2854", "data_act")),           # Data Act
        ("eu_machinery", ("eu_2023_1230", "eu_machinery")),   # EU Machinery
        ("dora", ("eu_2022_2554", "dora")),                   # DORA
    )
    for alias in aliases
}
|
||||
|
||||
|
||||
@dataclass
class ObligationMatch:
    """Outcome of linking a chunk to an obligation.

    All fields default to "nothing found" values, so an empty instance
    represents a failed extraction.
    """

    # Identity and text of the matched (or extracted) obligation
    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    # How the match was produced
    method: str = "none"  # exact_match | embedding_match | llm_extracted | inferred
    confidence: float = 0.0
    # Normalized regulation the obligation belongs to
    regulation_id: Optional[str] = None  # e.g. "dsgvo"

    def to_dict(self) -> dict:
        """Return the match as a plain dict with one entry per field."""
        keys = (
            "obligation_id",
            "obligation_title",
            "obligation_text",
            "method",
            "confidence",
            "regulation_id",
        )
        return {key: getattr(self, key) for key in keys}
|
||||
|
||||
|
||||
@dataclass
class _ObligationEntry:
    """One obligation as loaded from the obligation framework files."""

    # Required identity fields
    id: str
    title: str
    description: str
    regulation_id: str
    # Filled in while loading / embedding; each instance gets its own list
    articles: list[str] = field(default_factory=list)  # normalized: ["art. 30", "§ 38"]
    embedding: list[float] = field(default_factory=list)
|
||||
|
||||
|
||||
class ObligationExtractor:
    """3-tier obligation extraction from RAG chunks.

    The tiers are tried in order and the first hit wins:

    1. exact lookup by regulation + normalized article,
    2. embedding similarity against all loaded obligations,
    3. free-text extraction by the local LLM (always returns a match object).

    Usage::

        extractor = ObligationExtractor()
        await extractor.initialize()  # loads obligations + embeddings

        match = await extractor.extract(
            chunk_text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            paragraph="Abs. 1",
        )
    """

    # Score bonus for obligations from the chunk's own regulation. It only
    # influences which candidate wins; the threshold check and the reported
    # confidence use the raw similarity.
    _SAME_REGULATION_BONUS = 0.05

    def __init__(self):
        self._article_lookup: dict[str, list[str]] = {}  # "dsgvo/art. 30" → ["DSGVO-OBL-001"]
        self._obligations: dict[str, _ObligationEntry] = {}  # id → entry
        self._obligation_embeddings: list[list[float]] = []
        self._obligation_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load all obligations from v2 JSON files and compute embeddings.

        Subsequent calls are no-ops once initialization has completed.
        """
        if self._initialized:
            return

        self._load_obligations()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "ObligationExtractor initialized: %d obligations, %d article lookups, %d embeddings",
            len(self._obligations),
            len(self._article_lookup),
            sum(1 for e in self._obligation_embeddings if e),
        )

    async def extract(
        self,
        chunk_text: str,
        regulation_code: str,
        article: Optional[str] = None,
        paragraph: Optional[str] = None,
    ) -> ObligationMatch:
        """Extract obligation from a chunk using the 3-tier strategy.

        Args:
            chunk_text: Text of the RAG chunk.
            regulation_code: Regulation code attached to the chunk.
            article: Optional article reference; enables Tier 1.
            paragraph: Accepted for interface compatibility; not used by
                any tier.

        Returns:
            The first successful tier's match. Tier 3 always returns an
            ``ObligationMatch`` (with confidence 0.0 if the LLM gave no
            answer).
        """
        if not self._initialized:
            await self.initialize()

        reg_id = _normalize_regulation(regulation_code)

        # Tier 1: Exact match via article lookup
        if article:
            match = self._tier1_exact(reg_id, article)
            if match:
                return match

        # Tier 2: Embedding similarity
        match = await self._tier2_embedding(chunk_text, reg_id)
        if match:
            return match

        # Tier 3: LLM extraction
        return await self._tier3_llm(chunk_text, regulation_code, article)

    # -----------------------------------------------------------------------
    # Tier 1: Exact Match
    # -----------------------------------------------------------------------

    def _tier1_exact(self, reg_id: Optional[str], article: str) -> Optional[ObligationMatch]:
        """Look up obligation by regulation + article.

        Returns ``None`` when the regulation is unknown or no obligation is
        registered for the normalized article.
        """
        if not reg_id:
            return None

        key = f"{reg_id}/{_normalize_article(article)}"
        obl_ids = self._article_lookup.get(key)
        if not obl_ids:
            return None

        # Take the first match (highest priority)
        entry = self._obligations.get(obl_ids[0])
        if not entry:
            return None

        return ObligationMatch(
            obligation_id=entry.id,
            obligation_title=entry.title,
            obligation_text=entry.description,
            method="exact_match",
            confidence=1.0,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, chunk_text: str, reg_id: Optional[str]
    ) -> Optional[ObligationMatch]:
        """Find the nearest obligation by embedding similarity.

        Candidates from the chunk's own regulation get a small ranking
        bonus. A match is returned only if the winner's raw similarity
        reaches ``EMBEDDING_MATCH_THRESHOLD``.
        """
        if not self._obligation_embeddings:
            return None

        chunk_embedding = await _get_embedding(chunk_text[:2000])
        if not chunk_embedding:
            return None

        best_id: Optional[str] = None
        best_entry = None
        best_ranked = 0.0
        best_raw = 0.0

        # zip() keeps ids and embeddings paired even if the lists differ in
        # length.
        for obl_id, obl_emb in zip(self._obligation_ids, self._obligation_embeddings):
            if not obl_emb:
                continue
            entry = self._obligations.get(obl_id)
            raw = _cosine_sim(chunk_embedding, obl_emb)

            # Prefer same-regulation matches. The raw score is tracked
            # separately so the bonus never has to be subtracted again
            # (a float round-trip that can flip the threshold check).
            ranked = raw
            if entry and reg_id and entry.regulation_id == reg_id:
                ranked = raw + self._SAME_REGULATION_BONUS

            if ranked > best_ranked:
                best_ranked = ranked
                best_raw = raw
                best_id = obl_id
                best_entry = entry

        if best_id is None or best_raw < EMBEDDING_MATCH_THRESHOLD:
            return None

        return ObligationMatch(
            obligation_id=best_entry.id if best_entry else best_id,
            obligation_title=best_entry.title if best_entry else None,
            obligation_text=best_entry.description if best_entry else None,
            method="embedding_match",
            confidence=round(min(best_raw, 1.0), 3),
            regulation_id=best_entry.regulation_id if best_entry else reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 3: LLM Extraction
    # -----------------------------------------------------------------------

    async def _tier3_llm(
        self, chunk_text: str, regulation_code: str, article: Optional[str]
    ) -> ObligationMatch:
        """Use the local LLM to extract the obligation from the chunk.

        Returns a match with confidence 0.0 and no text when the LLM call
        yields nothing; otherwise confidence 0.60 and the extracted
        ``obligation_text`` (falling back to the first 500 characters of
        the raw answer when that key is absent).
        """
        prompt = f"""Analysiere den folgenden Gesetzestext und extrahiere die zentrale rechtliche Pflicht.

Text:
{chunk_text[:3000]}

Quelle: {regulation_code} {article or ''}

Antworte NUR als JSON:
{{
"obligation_text": "Die zentrale Pflicht in einem Satz",
"actor": "Wer muss handeln (z.B. Verantwortlicher, Auftragsverarbeiter)",
"action": "Was muss getan werden",
"normative_strength": "muss|soll|kann"
}}"""

        system_prompt = (
            "Du bist ein Rechtsexperte fuer EU-Datenschutz- und Digitalrecht. "
            "Extrahiere die zentrale rechtliche Pflicht aus Gesetzestexten. "
            "Antworte ausschliesslich als JSON."
        )

        reg_id = _normalize_regulation(regulation_code)

        result_text = await _llm_ollama(prompt, system_prompt)
        if not result_text:
            return ObligationMatch(
                method="llm_extracted",
                confidence=0.0,
                regulation_id=reg_id,
            )

        parsed = _parse_json(result_text)
        # The model may answer with a JSON list or scalar; only a JSON
        # object can carry the expected key.
        if not isinstance(parsed, dict):
            parsed = {}
        obligation_text = parsed.get("obligation_text", result_text[:500])

        return ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text=obligation_text,
            method="llm_extracted",
            confidence=0.60,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_obligations(self) -> None:
        """Load all obligation files from the v2 framework.

        Populates the obligation table and the article lookup. A missing
        directory, manifest or regulation file is logged and skipped.
        """
        v2_dir = _find_obligations_dir()
        if not v2_dir:
            logger.warning("Obligations v2 directory not found — Tier 1 disabled")
            return

        manifest_path = v2_dir / "_manifest.json"
        if not manifest_path.exists():
            logger.warning("Manifest not found at %s", manifest_path)
            return

        # JSON is read as UTF-8 explicitly; the platform default encoding
        # would mis-decode non-ASCII text on some systems.
        with open(manifest_path, encoding="utf-8") as f:
            manifest = json.load(f)

        regulations = manifest.get("regulations", [])
        for reg_info in regulations:
            reg_id = reg_info["id"]
            reg_file = v2_dir / reg_info["file"]
            if not reg_file.exists():
                logger.warning("Regulation file not found: %s", reg_file)
                continue

            with open(reg_file, encoding="utf-8") as f:
                data = json.load(f)

            for obl in data.get("obligations", []):
                obl_id = obl["id"]
                entry = _ObligationEntry(
                    id=obl_id,
                    title=obl.get("title", ""),
                    description=obl.get("description", ""),
                    regulation_id=reg_id,
                )

                # Build article lookup from legal_basis
                for basis in obl.get("legal_basis", []):
                    article_raw = basis.get("article", "")
                    if not article_raw:
                        continue
                    norm_art = _normalize_article(article_raw)
                    key = f"{reg_id}/{norm_art}"
                    self._article_lookup.setdefault(key, []).append(obl_id)
                    entry.articles.append(norm_art)

                self._obligations[obl_id] = entry

        logger.info(
            "Loaded %d obligations from %d regulations",
            len(self._obligations),
            len(regulations),
        )

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all obligation descriptions.

        Embeddings are stored positionally, parallel to the id list. If the
        embedding helper returns a different number of vectors than texts,
        that alignment cannot be trusted and Tier 2 is disabled.
        """
        if not self._obligations:
            return

        self._obligation_ids = list(self._obligations.keys())
        texts = [
            f"{self._obligations[oid].title}: {self._obligations[oid].description}"
            for oid in self._obligation_ids
        ]

        logger.info("Computing embeddings for %d obligations...", len(texts))
        embeddings = await _get_embeddings_batch(texts)

        if len(embeddings) != len(texts):
            logger.error(
                "Embedding count mismatch (%d embeddings for %d obligations) — Tier 2 disabled",
                len(embeddings),
                len(texts),
            )
            self._obligation_embeddings = []
            return

        self._obligation_embeddings = embeddings
        valid = sum(1 for e in self._obligation_embeddings if e)
        logger.info("Got %d/%d valid embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Stats
    # -----------------------------------------------------------------------

    def stats(self) -> dict:
        """Return initialization statistics.

        ``regulations`` is sorted so the output is deterministic.
        """
        return {
            "total_obligations": len(self._obligations),
            "article_lookups": len(self._article_lookup),
            "embeddings_valid": sum(1 for e in self._obligation_embeddings if e),
            "regulations": sorted(
                {e.regulation_id for e in self._obligations.values()}
            ),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level helpers (reusable by other modules)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_regulation(regulation_code: str) -> Optional[str]:
    """Translate a RAG regulation code into an obligation-framework ID.

    The code is lower-cased and stripped, then looked up directly; if that
    fails, known EU regulation number prefixes are tried so that variants
    of a code resolve to the same regulation. Returns ``None`` for empty
    or unknown codes.
    """
    if not regulation_code:
        return None

    code = regulation_code.lower().strip()

    # Direct lookup
    direct = _REGULATION_CODE_TO_ID.get(code) if code in _REGULATION_CODE_TO_ID else None
    if code in _REGULATION_CODE_TO_ID:
        return direct

    # Prefix matching for families
    family_prefixes = (
        ("eu_2016_679", "dsgvo"),
        ("eu_2024_1689", "ai_act"),
        ("eu_2022_2555", "nis2"),
        ("eu_2022_2065", "dsa"),
        ("eu_2023_2854", "data_act"),
        ("eu_2023_1230", "eu_machinery"),
        ("eu_2022_2554", "dora"),
    )
    return next(
        (reg_id for prefix, reg_id in family_prefixes if code.startswith(prefix)),
        None,
    )
|
||||
|
||||
|
||||
def _normalize_article(article: str) -> str:
    """Normalize article references for consistent lookup.

    Paragraph-level suffixes and a trailing law name are removed, the
    article keyword is unified to "art." and the result is lower-cased.

    Examples:
        "Art. 30" → "art. 30"
        "§ 38 BDSG" → "§ 38"
        "Article 10" → "art. 10"
        "Art. 30 Abs. 1" → "art. 30"
        "Art. 30 DSGVO Abs. 1" → "art. 30"
        "Artikel 35" → "art. 35"
        "Art.30" → "art. 30"

    Returns:
        The normalized reference, or "" for empty input.
    """
    if not article:
        return ""
    s = article.strip()

    # Remove paragraph references first: "Art. 30 Abs. 1" → "Art. 30".
    # Doing this before the law-name step means a law name placed in front
    # of the paragraph ("Art. 30 DSGVO Abs. 1") ends up at the end of the
    # string, where the next step can strip it.
    s = re.sub(r"\s+(Abs|Absatz|para|paragraph|lit|Satz)\.?\s+.*$", "", s, flags=re.IGNORECASE)

    # Remove trailing law name: "§ 38 BDSG" → "§ 38"
    s = re.sub(r"\s+(DSGVO|BDSG|TTDSG|DSA|NIS2|DORA|AI.?Act)\s*$", "", s, flags=re.IGNORECASE)

    # Normalize "Article" / "Artikel" → "Art."
    s = re.sub(r"^(Article|Artikel)\s+", "Art. ", s, flags=re.IGNORECASE)

    # Unify compact spellings ("Art.30", "Art 30", "Art.  30") so they yield
    # the same key as "Art. 30". Only applied when a number follows.
    s = re.sub(r"^Art\.?\s*(?=\d)", "Art. ", s, flags=re.IGNORECASE)

    return s.lower().strip()
|
||||
|
||||
|
||||
def _cosine_sim(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of ``a`` and ``b``.

    Yields 0.0 when either vector is empty, the lengths differ, or either
    vector has zero magnitude.
    """
    if not (a and b) or len(a) != len(b):
        return 0.0

    magnitude_a = sum(value * value for value in a) ** 0.5
    magnitude_b = sum(value * value for value in b) ** 0.5
    if not magnitude_a or not magnitude_b:
        return 0.0

    inner = sum(left * right for left, right in zip(a, b))
    return inner / (magnitude_a * magnitude_b)
|
||||
|
||||
|
||||
def _find_obligations_dir() -> Optional[Path]:
    """Return the first obligations v2 directory that contains a manifest.

    Checked in order: four directory levels above this file, the fixed
    ``/app`` location, and a path relative to the working directory.
    Returns ``None`` if none of them is a directory with a
    ``_manifest.json`` inside.
    """
    four_levels_up = Path(__file__).resolve().parent.parent.parent.parent
    search_paths = (
        four_levels_up / "ai-compliance-sdk" / "policies" / "obligations" / "v2",
        Path("/app/ai-compliance-sdk/policies/obligations/v2"),
        Path("ai-compliance-sdk/policies/obligations/v2"),
    )
    return next(
        (
            directory
            for directory in search_paths
            if directory.is_dir() and (directory / "_manifest.json").exists()
        ),
        None,
    )
|
||||
|
||||
|
||||
async def _get_embedding(text: str) -> list[float]:
    """Get the embedding vector for a single text.

    Best-effort: any failure (connection error, non-2xx status, unexpected
    response body) is logged and results in an empty list, so callers can
    treat ``[]`` as "no embedding available".

    Args:
        text: Text to embed.

    Returns:
        The first embedding returned by the service, or ``[]``.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            return embeddings[0] if embeddings else []
    except Exception as e:
        # Logged like the batch variant so a failing embedding service is
        # visible instead of silently disabling embedding matching.
        logger.warning("Embedding request failed: %s", e)
        return []
|
||||
|
||||
|
||||
async def _get_embeddings_batch(
    texts: list[str], batch_size: int = 32
) -> list[list[float]]:
    """Get embeddings for multiple texts in batches.

    The result always has exactly one entry per input text, in input
    order. A batch that fails, or for which the service returns a
    different number of embeddings than texts sent, contributes empty
    lists, so positions never shift relative to ``texts``.

    Args:
        texts: Texts to embed.
        batch_size: Texts per request; values below 1 are treated as 1.

    Returns:
        A list parallel to ``texts``; ``[]`` marks a missing embedding.
    """
    if not texts:
        return []

    # range() rejects a zero step and yields nothing for a negative one.
    step = max(1, batch_size)
    all_embeddings: list[list[float]] = []

    # One client for all batches so connections can be reused.
    async with httpx.AsyncClient(timeout=30.0) as client:
        for i in range(0, len(texts), step):
            batch = texts[i : i + step]
            try:
                resp = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                if len(embeddings) != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got {len(embeddings)}"
                    )
            except Exception as e:
                logger.warning("Batch embedding failed for %d texts: %s", len(batch), e)
                embeddings = [[] for _ in batch]
            all_embeddings.extend(embeddings)

    return all_embeddings
|
||||
|
||||
|
||||
async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Send a chat request to the local Ollama server and return its text.

    The request asks for a non-streamed JSON-formatted answer limited to
    512 predicted tokens with thinking disabled. Any failure (non-200
    status or an exception) is logged and yields an empty string.
    """
    conversation = [{"role": "user", "content": prompt}]
    if system_prompt:
        # The system message, when present, goes first.
        conversation.insert(0, {"role": "system", "content": system_prompt})

    request_body = {
        "model": OLLAMA_MODEL,
        "messages": conversation,
        "stream": False,
        "format": "json",
        "options": {"num_predict": 512},
        "think": False,
    }

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=request_body)
            if resp.status_code == 200:
                return resp.json().get("message", {}).get("content", "")
            logger.error(
                "Ollama chat failed %d: %s", resp.status_code, resp.text[:300]
            )
            return ""
    except Exception as e:
        logger.warning("Ollama call failed: %s", e)
        return ""
|
||||
|
||||
|
||||
def _parse_json(text: str) -> dict:
    """Extract a JSON object from LLM response text.

    The whole text is tried first. If it is not a JSON object, every "{"
    in the text is tried as the start of an embedded object, so answers
    wrapped in prose or code fences and objects with nested braces are
    both handled.

    Args:
        text: Raw LLM output.

    Returns:
        The first JSON object found, or ``{}`` if there is none. The result
        is always a dict: a top-level JSON list, string or number does not
        count as a match.
    """
    if not text:
        return {}

    # Try direct parse
    try:
        parsed = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Try extracting an embedded JSON object. raw_decode() parses one
    # complete value from the start of its input and ignores what follows,
    # which copes with nested objects (a flat brace regex cannot).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)

    return {}
|
||||
532
control-pipeline/services/pattern_matcher.py
Normal file
532
control-pipeline/services/pattern_matcher.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""Pattern Matcher — Obligation-to-Control-Pattern Linking.
|
||||
|
||||
Maps obligations (from the ObligationExtractor) to control patterns
|
||||
using two tiers:
|
||||
|
||||
Tier 1: KEYWORD MATCH — obligation_match_keywords from patterns (~70%)
|
||||
Tier 2: EMBEDDING — cosine similarity with domain bonus (~25%)
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 5 of 8).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from services.obligation_extractor import (
|
||||
_cosine_sim,
|
||||
_get_embedding,
|
||||
_get_embeddings_batch,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum keyword score to accept a match (at least 2 keyword hits)
|
||||
KEYWORD_MATCH_MIN_HITS = 2

# Embedding threshold for Tier 2
EMBEDDING_PATTERN_THRESHOLD = 0.75

# Domain bonus when regulation maps to the pattern's domain
DOMAIN_BONUS = 0.10

# Pattern domains considered relevant for each regulation ID
_REGULATION_DOMAIN_AFFINITY = dict(
    dsgvo=["DATA", "COMP", "GOV"],
    bdsg=["DATA", "COMP"],
    ttdsg=["DATA"],
    ai_act=["AI", "COMP", "DATA"],
    nis2=["SEC", "INC", "NET", "LOG", "CRYP"],
    dsa=["DATA", "COMP"],
    data_act=["DATA", "COMP"],
    eu_machinery=["SEC", "COMP"],
    dora=["SEC", "INC", "FIN", "COMP"],
)
|
||||
|
||||
|
||||
@dataclass
class ControlPattern:
    """In-memory form of a single control pattern entry read from YAML.

    The first eight fields are mandatory. Every other field is optional and
    falls back to an empty list or the default value shown below.
    """

    # --- Required: identity and descriptive texts -------------------------
    id: str
    name: str
    name_de: str
    domain: str
    category: str
    description: str
    objective_template: str
    rationale_template: str

    # --- Optional: list-valued templates ----------------------------------
    requirements_template: list[str] = field(default_factory=list)
    test_procedure_template: list[str] = field(default_factory=list)
    evidence_template: list[str] = field(default_factory=list)

    # --- Optional: scalar defaults ----------------------------------------
    severity_default: str = "medium"
    implementation_effort_default: str = "m"

    # --- Optional: matching and linking metadata --------------------------
    # Keywords searched (case-insensitively) in obligation texts.
    obligation_match_keywords: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    # IDs of other patterns this one can be combined with.
    composable_with: list[str] = field(default_factory=list)
    open_anchor_refs: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PatternMatchResult:
    """Outcome of matching one obligation text against the pattern library.

    A default-constructed instance (no pattern, ``method="none"``, zero
    confidence) represents "no match".
    """

    pattern: Optional[ControlPattern] = None
    pattern_id: Optional[str] = None
    method: str = "none"  # one of: keyword | embedding | combined | none
    confidence: float = 0.0
    keyword_hits: int = 0
    total_keywords: int = 0
    embedding_score: float = 0.0
    domain_bonus_applied: bool = False
    composable_patterns: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a plain-dict summary of the match.

        The ``pattern`` object itself is left out (only its ID is included)
        and both scores are rounded to three decimals.
        """
        rounded_confidence = round(self.confidence, 3)
        rounded_embedding = round(self.embedding_score, 3)
        summary = {
            "pattern_id": self.pattern_id,
            "method": self.method,
            "confidence": rounded_confidence,
            "keyword_hits": self.keyword_hits,
            "total_keywords": self.total_keywords,
            "embedding_score": rounded_embedding,
            "domain_bonus_applied": self.domain_bonus_applied,
            "composable_patterns": self.composable_patterns,
        }
        return summary
|
||||
|
||||
|
||||
class PatternMatcher:
    """Links obligations to control patterns using keyword + embedding matching.

    Usage::

        matcher = PatternMatcher()
        await matcher.initialize()

        result = await matcher.match(
            obligation_text="Fuehrung eines Verarbeitungsverzeichnisses...",
            regulation_id="dsgvo",
        )
        print(result.pattern_id)   # e.g. "CP-COMP-001"
        print(result.confidence)   # e.g. 0.85
    """

    def __init__(self):
        """Create an empty matcher; call :meth:`initialize` to load patterns."""
        self._patterns: list[ControlPattern] = []
        self._by_id: dict[str, ControlPattern] = {}
        self._by_domain: dict[str, list[ControlPattern]] = {}
        self._keyword_index: dict[str, list[str]] = {}  # keyword → [pattern_ids]
        # Parallel lists: _pattern_embeddings[i] belongs to _pattern_ids[i].
        self._pattern_embeddings: list[list[float]] = []
        self._pattern_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load patterns from YAML, build the keyword index, compute embeddings.

        Does nothing when the matcher has already been initialized.
        """
        if self._initialized:
            return

        self._load_patterns()
        self._build_keyword_index()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "PatternMatcher initialized: %d patterns, %d keywords, %d embeddings",
            len(self._patterns),
            len(self._keyword_index),
            sum(1 for e in self._pattern_embeddings if e),
        )

    async def match(
        self,
        obligation_text: str,
        regulation_id: Optional[str] = None,
        top_n: int = 1,
    ) -> PatternMatchResult:
        """Match obligation text to the best control pattern.

        Args:
            obligation_text: The obligation description to match against.
            regulation_id: Source regulation (for domain bonus).
            top_n: Currently not used by this method; kept so existing
                callers that pass it keep working.

        Returns:
            PatternMatchResult with the best match, or an empty result when
            the text is empty, no patterns are loaded, or neither tier
            produced a match.
        """
        if not self._initialized:
            await self.initialize()

        if not obligation_text or not self._patterns:
            return PatternMatchResult()

        # Tier 1: Keyword matching
        keyword_result = self._tier1_keyword(obligation_text, regulation_id)

        # Tier 2: Embedding matching
        embedding_result = await self._tier2_embedding(obligation_text, regulation_id)

        # Combine scores: prefer keyword match, boost with embedding if available
        best = self._combine_results(keyword_result, embedding_result)

        # Attach composable patterns (only IDs that are actually loaded)
        if best.pattern:
            best.composable_patterns = [
                pid for pid in best.pattern.composable_with
                if pid in self._by_id
            ]

        return best

    async def match_top_n(
        self,
        obligation_text: str,
        regulation_id: Optional[str] = None,
        n: int = 3,
    ) -> list[PatternMatchResult]:
        """Return top-N pattern matches sorted by confidence descending.

        Unlike :meth:`match`, this method applies neither the minimum keyword
        hit count nor the embedding threshold; every pattern with any keyword
        or embedding score is ranked.
        """
        if not self._initialized:
            await self.initialize()

        if not obligation_text or not self._patterns:
            return []

        keyword_scores = self._keyword_scores(obligation_text, regulation_id)
        embedding_scores = await self._embedding_scores(obligation_text, regulation_id)

        # Merge scores
        all_pattern_ids = set(keyword_scores.keys()) | set(embedding_scores.keys())
        results: list[PatternMatchResult] = []

        for pid in all_pattern_ids:
            pattern = self._by_id.get(pid)
            if not pattern:
                continue

            kw_score = keyword_scores.get(pid, (0, 0, 0.0))  # (hits, total, score)
            emb_score = embedding_scores.get(pid, (0.0, False))  # (score, bonus_applied)

            kw_hits, kw_total, kw_confidence = kw_score
            emb_confidence, bonus_applied = emb_score

            # Combined confidence: max of keyword and embedding, with boost if both
            if kw_confidence > 0 and emb_confidence > 0:
                combined = max(kw_confidence, emb_confidence) + 0.05
                method = "combined"
            elif kw_confidence > 0:
                combined = kw_confidence
                method = "keyword"
            else:
                combined = emb_confidence
                method = "embedding"

            results.append(PatternMatchResult(
                pattern=pattern,
                pattern_id=pid,
                method=method,
                confidence=min(combined, 1.0),
                keyword_hits=kw_hits,
                total_keywords=kw_total,
                embedding_score=emb_confidence,
                domain_bonus_applied=bonus_applied,
                composable_patterns=[
                    p for p in pattern.composable_with if p in self._by_id
                ],
            ))

        # Sort by confidence descending
        results.sort(key=lambda r: r.confidence, reverse=True)
        return results[:n]

    # -----------------------------------------------------------------------
    # Tier 1: Keyword Match
    # -----------------------------------------------------------------------

    def _tier1_keyword(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> Optional[PatternMatchResult]:
        """Match by counting keyword hits in the obligation text.

        Returns the pattern with the highest hit ratio among those that reach
        ``KEYWORD_MATCH_MIN_HITS`` hits, or None when no pattern qualifies.
        """
        scores = self._keyword_scores(obligation_text, regulation_id)

        # The minimum-hits filter is applied BEFORE picking the best ratio, so
        # a pattern with a single keyword (1/1 = 1.0) cannot shadow a pattern
        # that actually qualifies (e.g. 3/5 = 0.6).
        best_pid = self._best_keyword_candidate(scores)
        if best_pid is None:
            return None

        hits, total, confidence = scores[best_pid]

        pattern = self._by_id.get(best_pid)
        if not pattern:
            return None

        # Check domain bonus
        bonus_applied = False
        if regulation_id and self._domain_matches(pattern.domain, regulation_id):
            confidence = min(confidence + DOMAIN_BONUS, 1.0)
            bonus_applied = True

        return PatternMatchResult(
            pattern=pattern,
            pattern_id=best_pid,
            method="keyword",
            confidence=confidence,
            keyword_hits=hits,
            total_keywords=total,
            domain_bonus_applied=bonus_applied,
        )

    @staticmethod
    def _best_keyword_candidate(
        scores: dict[str, tuple[int, int, float]],
        min_hits: Optional[int] = None,
    ) -> Optional[str]:
        """Pick the pattern ID with the highest confidence among eligible ones.

        Args:
            scores: pattern_id → (hits, total_keywords, confidence).
            min_hits: Minimum number of keyword hits a pattern needs to be
                eligible; defaults to ``KEYWORD_MATCH_MIN_HITS``.

        Returns:
            The best eligible pattern ID, or None when nothing is eligible.
            On equal confidence the pattern that comes first in ``scores``
            wins.
        """
        if min_hits is None:
            min_hits = KEYWORD_MATCH_MIN_HITS

        eligible = [pid for pid, entry in scores.items() if entry[0] >= min_hits]
        if not eligible:
            return None
        return max(eligible, key=lambda pid: scores[pid][2])

    def _keyword_scores(
        self, text: str, regulation_id: Optional[str]
    ) -> dict[str, tuple[int, int, float]]:
        """Compute keyword match scores for all patterns.

        A keyword counts as a hit when it occurs as a substring of the
        lower-cased text. ``regulation_id`` is accepted for signature parity
        with the embedding scorer but does not influence the result.

        Returns dict: pattern_id → (hits, total_keywords, confidence).
        """
        text_lower = text.lower()
        hits_by_pattern: dict[str, int] = {}

        for keyword, pattern_ids in self._keyword_index.items():
            if keyword in text_lower:
                for pid in pattern_ids:
                    hits_by_pattern[pid] = hits_by_pattern.get(pid, 0) + 1

        result: dict[str, tuple[int, int, float]] = {}
        for pid, hits in hits_by_pattern.items():
            pattern = self._by_id.get(pid)
            if not pattern:
                continue
            total = len(pattern.obligation_match_keywords)
            confidence = hits / total if total > 0 else 0.0
            result[pid] = (hits, total, confidence)

        return result

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> Optional[PatternMatchResult]:
        """Match by embedding similarity against pattern objective_templates.

        Returns None when no embeddings are available or when the best score
        (domain bonus included) is below ``EMBEDDING_PATTERN_THRESHOLD``.
        """
        scores = await self._embedding_scores(obligation_text, regulation_id)
        if not scores:
            return None

        best_pid = max(scores, key=lambda pid: scores[pid][0])
        emb_score, bonus_applied = scores[best_pid]

        if emb_score < EMBEDDING_PATTERN_THRESHOLD:
            return None

        pattern = self._by_id.get(best_pid)
        if not pattern:
            return None

        return PatternMatchResult(
            pattern=pattern,
            pattern_id=best_pid,
            method="embedding",
            # The bonus can push the raw score above 1.0; confidence is capped.
            confidence=min(emb_score, 1.0),
            embedding_score=emb_score,
            domain_bonus_applied=bonus_applied,
        )

    async def _embedding_scores(
        self, obligation_text: str, regulation_id: Optional[str]
    ) -> dict[str, tuple[float, bool]]:
        """Compute embedding similarity scores for all patterns.

        Only the first 2000 characters of the obligation text are embedded.
        Patterns without a (non-empty) embedding are skipped.

        Returns dict: pattern_id → (score, domain_bonus_applied).
        """
        if not self._pattern_embeddings:
            return {}

        chunk_embedding = await _get_embedding(obligation_text[:2000])
        if not chunk_embedding:
            return {}

        result: dict[str, tuple[float, bool]] = {}
        # zip() keeps IDs and embeddings aligned and cannot run past the
        # shorter list if the embedding backend returned a different length.
        for pid, pat_emb in zip(self._pattern_ids, self._pattern_embeddings):
            if not pat_emb:
                continue
            pattern = self._by_id.get(pid)
            if not pattern:
                continue

            score = _cosine_sim(chunk_embedding, pat_emb)

            # Domain bonus
            bonus_applied = False
            if regulation_id and self._domain_matches(pattern.domain, regulation_id):
                score += DOMAIN_BONUS
                bonus_applied = True

            result[pid] = (score, bonus_applied)

        return result

    # -----------------------------------------------------------------------
    # Score combination
    # -----------------------------------------------------------------------

    def _combine_results(
        self,
        keyword_result: Optional[PatternMatchResult],
        embedding_result: Optional[PatternMatchResult],
    ) -> PatternMatchResult:
        """Combine keyword and embedding results into the best match.

        - Neither tier matched: an empty result.
        - Only one tier matched: that tier's result.
        - Both matched the same pattern: a "combined" result whose confidence
          is the higher of the two plus 0.05, capped at 1.0.
        - Both matched different patterns: the more confident one, with the
          keyword result winning ties.
        """
        if not keyword_result and not embedding_result:
            return PatternMatchResult()

        if not keyword_result:
            return embedding_result
        if not embedding_result:
            return keyword_result

        # Both matched — check if they agree
        if keyword_result.pattern_id == embedding_result.pattern_id:
            # Same pattern: boost confidence
            combined_confidence = min(
                max(keyword_result.confidence, embedding_result.confidence) + 0.05,
                1.0,
            )
            return PatternMatchResult(
                pattern=keyword_result.pattern,
                pattern_id=keyword_result.pattern_id,
                method="combined",
                confidence=combined_confidence,
                keyword_hits=keyword_result.keyword_hits,
                total_keywords=keyword_result.total_keywords,
                embedding_score=embedding_result.embedding_score,
                domain_bonus_applied=(
                    keyword_result.domain_bonus_applied
                    or embedding_result.domain_bonus_applied
                ),
            )

        # Different patterns: pick the one with higher confidence
        if keyword_result.confidence >= embedding_result.confidence:
            return keyword_result
        return embedding_result

    # -----------------------------------------------------------------------
    # Domain affinity
    # -----------------------------------------------------------------------

    @staticmethod
    def _domain_matches(pattern_domain: str, regulation_id: str) -> bool:
        """Check if a pattern's domain has affinity with a regulation.

        Unknown regulation IDs have no affinity with any domain.
        """
        affine_domains = _REGULATION_DOMAIN_AFFINITY.get(regulation_id, [])
        return pattern_domain in affine_domains

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_patterns(self) -> None:
        """Load control patterns from the YAML files in the patterns directory.

        Files whose name starts with "_" and files without a top-level
        ``patterns`` key are skipped. A file that fails to parse, or that
        contains a pattern with a missing required key, is logged and the
        remainder of that file is skipped; patterns loaded before the failure
        are kept.
        """
        patterns_dir = _find_patterns_dir()
        if not patterns_dir:
            logger.warning("Control patterns directory not found")
            return

        for yaml_file in sorted(patterns_dir.glob("*.yaml")):
            if yaml_file.name.startswith("_"):
                continue
            try:
                # Explicit UTF-8: the patterns contain German text and must
                # not depend on the platform's default encoding.
                with open(yaml_file, encoding="utf-8") as f:
                    data = yaml.safe_load(f)
                if not data or "patterns" not in data:
                    continue
                for p in data["patterns"]:
                    pattern = ControlPattern(
                        id=p["id"],
                        name=p["name"],
                        name_de=p["name_de"],
                        domain=p["domain"],
                        category=p["category"],
                        description=p["description"],
                        objective_template=p["objective_template"],
                        rationale_template=p["rationale_template"],
                        requirements_template=p.get("requirements_template", []),
                        test_procedure_template=p.get("test_procedure_template", []),
                        evidence_template=p.get("evidence_template", []),
                        severity_default=p.get("severity_default", "medium"),
                        implementation_effort_default=p.get("implementation_effort_default", "m"),
                        obligation_match_keywords=p.get("obligation_match_keywords", []),
                        tags=p.get("tags", []),
                        composable_with=p.get("composable_with", []),
                        open_anchor_refs=p.get("open_anchor_refs", []),
                    )
                    self._patterns.append(pattern)
                    self._by_id[pattern.id] = pattern
                    self._by_domain.setdefault(pattern.domain, []).append(pattern)
            except Exception as e:
                # Best effort: one broken file must not prevent the others
                # from loading.
                logger.error("Failed to load %s: %s", yaml_file.name, e)

        logger.info("Loaded %d patterns from %s", len(self._patterns), patterns_dir)

    def _build_keyword_index(self) -> None:
        """Build reverse index: lower-cased keyword → [pattern_ids]."""
        for pattern in self._patterns:
            for kw in pattern.obligation_match_keywords:
                self._keyword_index.setdefault(kw.lower(), []).append(pattern.id)

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all pattern objective templates.

        The embedded text for each pattern is "<name_de>: <objective_template>".
        """
        if not self._patterns:
            return

        self._pattern_ids = [p.id for p in self._patterns]
        texts = [
            f"{p.name_de}: {p.objective_template}"
            for p in self._patterns
        ]

        logger.info("Computing embeddings for %d patterns...", len(texts))
        self._pattern_embeddings = await _get_embeddings_batch(texts)
        valid = sum(1 for e in self._pattern_embeddings if e)
        logger.info("Got %d/%d valid pattern embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Public helpers
    # -----------------------------------------------------------------------

    def get_pattern(self, pattern_id: str) -> Optional[ControlPattern]:
        """Get a pattern by its ID (the lookup key is upper-cased first)."""
        return self._by_id.get(pattern_id.upper())

    def get_patterns_by_domain(self, domain: str) -> list[ControlPattern]:
        """Get all patterns for a domain (the lookup key is upper-cased first)."""
        return self._by_domain.get(domain.upper(), [])

    def stats(self) -> dict:
        """Return matcher statistics."""
        return {
            "total_patterns": len(self._patterns),
            "domains": list(self._by_domain.keys()),
            "keywords": len(self._keyword_index),
            "embeddings_valid": sum(1 for e in self._pattern_embeddings if e),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
def _find_patterns_dir() -> Optional[Path]:
    """Return the first existing ``control_patterns`` directory, or None.

    Locations are probed in this order:

    1. ``ai-compliance-sdk/policies/control_patterns`` four directory levels
       above this module file,
    2. the absolute path under ``/app``,
    3. the same relative path resolved against the current working directory.
    """
    # NOTE(review): with this file at control-pipeline/services/, four levels
    # up is the directory *above* the repository root — confirm this is still
    # the intended location after the move out of the compliance repo.
    module_relative = (
        Path(__file__).resolve().parent.parent.parent.parent
        / "ai-compliance-sdk" / "policies" / "control_patterns"
    )
    search_order = (
        module_relative,
        Path("/app/ai-compliance-sdk/policies/control_patterns"),
        Path("ai-compliance-sdk/policies/control_patterns"),
    )
    return next((location for location in search_order if location.is_dir()), None)
|
||||
670
control-pipeline/services/pipeline_adapter.py
Normal file
670
control-pipeline/services/pipeline_adapter.py
Normal file
@@ -0,0 +1,670 @@
|
||||
"""Pipeline Adapter — New 10-Stage Pipeline Integration.
|
||||
|
||||
Bridges the existing 7-stage control_generator pipeline with the new
|
||||
multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer).
|
||||
|
||||
New pipeline flow:
|
||||
chunk → license_classify
|
||||
→ obligation_extract (Stage 4 — NEW)
|
||||
→ pattern_match (Stage 5 — NEW)
|
||||
→ control_compose (Stage 6 — replaces old Stage 3)
|
||||
→ harmonize → anchor → store + crosswalk → mark processed
|
||||
|
||||
Can be used in two modes:
|
||||
1. INLINE: Called from _process_batch() to enrich the pipeline
|
||||
2. STANDALONE: Process chunks directly through new stages
|
||||
|
||||
Part of the Multi-Layer Control Architecture (Phase 7 of 8).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.control_composer import ComposedControl, ControlComposer
|
||||
from services.obligation_extractor import ObligationExtractor, ObligationMatch
|
||||
from services.pattern_matcher import PatternMatcher, PatternMatchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PipelineChunk:
    """A text chunk plus its source metadata, as fed into the new stages."""

    # The raw chunk text; the only required field.
    text: str
    collection: str = ""
    regulation_code: str = ""
    article: Optional[str] = None
    paragraph: Optional[str] = None
    license_rule: int = 3
    license_info: dict = field(default_factory=dict)
    source_citation: Optional[dict] = None
    # Filled lazily by compute_hash() unless supplied by the caller.
    chunk_hash: str = ""

    def compute_hash(self) -> str:
        """Return the chunk hash, computing and caching it on first use.

        The hash is the SHA-256 hex digest of ``text``. An already populated
        ``chunk_hash`` is returned unchanged, even if ``text`` was modified
        afterwards.
        """
        if self.chunk_hash:
            return self.chunk_hash
        digest = hashlib.sha256(self.text.encode()).hexdigest()
        self.chunk_hash = digest
        return digest
|
||||
|
||||
|
||||
@dataclass
class PipelineResult:
    """Everything produced while running one chunk through the new stages."""

    chunk: PipelineChunk
    # Both default to an empty (default-constructed) match object.
    obligation: ObligationMatch = field(default_factory=ObligationMatch)
    pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult)
    control: Optional[ComposedControl] = None
    crosswalk_written: bool = False
    # Set to the stringified exception when processing failed.
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a dict summary of the result.

        Each nested part is serialized through its own ``to_dict()``; a part
        that is missing (falsy) is represented as ``None``.
        """
        def _dump(part):
            return part.to_dict() if part else None

        return {
            "chunk_hash": self.chunk.chunk_hash,
            "obligation": _dump(self.obligation),
            "pattern": _dump(self.pattern_result),
            "control": _dump(self.control),
            "crosswalk_written": self.crosswalk_written,
            "error": self.error,
        }
|
||||
|
||||
|
||||
class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()

        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        """Create the adapter.

        Args:
            db: Database session used by :meth:`write_crosswalk`. Without a
                session, chunks can still be processed but nothing is
                persisted.
        """
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the extractor and the matcher (no-op when already done)."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Process a single chunk through the new 3-stage pipeline.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Any exception raised by a stage is logged and stored in
        ``result.error``; it is not propagated. Stages that completed before
        the failure keep their output on the returned result.
        """
        if not self._initialized:
            await self.initialize()

        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)

        try:
            # Stage 4: Obligation Extract
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )

            # Stage 5: Pattern Match — fall back from obligation text to its
            # title, then to the first 500 characters of the raw chunk.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )

            # Stage 6: Control Compose — the raw chunk text is only handed to
            # the composer for license rules 1 and 2.
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )

        except Exception as e:
            # Boundary for a single chunk: record the failure (with traceback)
            # and let the caller continue with the remaining chunks.
            logger.exception("Pipeline processing failed: %s", e)
            result.error = str(e)

        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process multiple chunks sequentially, preserving input order."""
        return [await self.process_chunk(chunk) for chunk in chunks]

    @staticmethod
    def _obligation_ids_param(obligation_ids) -> Optional[str]:
        """Encode obligation IDs for the canonical_controls UPDATE.

        Returns the JSON text for a non-empty collection and ``None`` (SQL
        NULL) otherwise, so that ``COALESCE(:obligation_ids, obligation_ids)``
        keeps the stored value instead of overwriting it with ``'[]'`` or
        ``'null'``.
        """
        return json.dumps(obligation_ids) if obligation_ids else None

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extraction + crosswalk_matrix rows for a processed chunk.

        Called AFTER the control is stored in canonical_controls.

        Args:
            result: Pipeline result whose control has been stored.
            control_uuid: UUID (as string) of the stored canonical control.

        Returns:
            True when all statements were committed. False when there is no
            database session, the result has no control, or a statement
            failed (in which case the transaction is rolled back).
        """
        if not self.db or not result.control:
            return False

        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result

        obligation_id = obligation.obligation_id if obligation else None
        pattern_id = pattern.pattern_id if pattern else None
        obligation_confidence = obligation.confidence if obligation else 0
        pattern_confidence = pattern.confidence if pattern else 0

        try:
            # 1. Write obligation_extraction row
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation_id,
                    # Stored text is capped at 2000 characters.
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation_confidence,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern_id,
                    "pattern_match_score": pattern_confidence,
                    "control_uuid": control_uuid,
                },
            )

            # 2. Write crosswalk_matrix row
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation_id,
                    "pattern_id": pattern_id,
                    "master_control_id": result.control.control_id,
                    "master_control_uuid": control_uuid,
                    # The link is only as trustworthy as its weaker half.
                    "confidence": min(obligation_confidence, pattern_confidence),
                    "source": "auto",
                },
            )

            # 3. Update canonical_controls with pattern_id + obligation_ids
            if result.control.pattern_id or result.control.obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.control.pattern_id,
                        # NULL (not '[]') when empty, so COALESCE keeps the
                        # obligation_ids already stored on the control.
                        "obligation_ids": self._obligation_ids_param(
                            result.control.obligation_ids
                        ),
                        "control_uuid": control_uuid,
                    },
                )

            self.db.commit()
            result.crosswalk_written = True
            return True

        except Exception as e:
            logger.exception("Failed to write crosswalk: %s", e)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Migration Passes — Backfill existing 4,800+ controls
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MigrationPasses:
|
||||
"""Non-destructive migration passes for existing controls.
|
||||
|
||||
Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
|
||||
Pass 2: Pattern Classification (keyword-based matching)
|
||||
Pass 3: Quality Triage (categorize by linkage completeness)
|
||||
Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
|
||||
Pass 5: Deduplication (mark duplicate controls)
|
||||
|
||||
Usage::
|
||||
|
||||
migration = MigrationPasses(db)
|
||||
await migration.initialize()
|
||||
|
||||
result = await migration.run_pass1_obligation_linkage(limit=100)
|
||||
result = await migration.run_pass2_pattern_classification(limit=100)
|
||||
result = migration.run_pass3_quality_triage()
|
||||
result = migration.run_pass4_crosswalk_backfill()
|
||||
result = migration.run_pass5_deduplication()
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._extractor = ObligationExtractor()
|
||||
self._matcher = PatternMatcher()
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize extractors (loads obligations + patterns)."""
|
||||
if self._initialized:
|
||||
return
|
||||
self._extractor._load_obligations()
|
||||
self._matcher._load_patterns()
|
||||
self._matcher._build_keyword_index()
|
||||
self._initialized = True
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 1: Obligation Linkage (deterministic)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
|
||||
"""Link existing controls to obligations via source_citation article.
|
||||
|
||||
For each control with source_citation → extract regulation + article
|
||||
→ look up in obligation framework → set obligation_ids.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
query = """
|
||||
SELECT id, control_id, source_citation, generation_metadata
|
||||
FROM canonical_controls
|
||||
WHERE release_state NOT IN ('deprecated')
|
||||
AND (obligation_ids IS NULL OR obligation_ids = '[]')
|
||||
"""
|
||||
if limit > 0:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
rows = self.db.execute(text(query)).fetchall()
|
||||
|
||||
stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}
|
||||
|
||||
for row in rows:
|
||||
control_uuid = str(row[0])
|
||||
control_id = row[1]
|
||||
citation = row[2]
|
||||
metadata = row[3]
|
||||
|
||||
# Extract regulation + article from citation or metadata
|
||||
reg_code, article = _extract_regulation_article(citation, metadata)
|
||||
if not reg_code:
|
||||
stats["no_citation"] += 1
|
||||
continue
|
||||
|
||||
# Tier 1: Exact match
|
||||
match = self._extractor._tier1_exact(reg_code, article or "")
|
||||
if match and match.obligation_id:
|
||||
self.db.execute(
|
||||
text("""
|
||||
UPDATE canonical_controls
|
||||
SET obligation_ids = :obl_ids
|
||||
WHERE id = CAST(:uuid AS uuid)
|
||||
"""),
|
||||
{
|
||||
"obl_ids": json.dumps([match.obligation_id]),
|
||||
"uuid": control_uuid,
|
||||
},
|
||||
)
|
||||
stats["linked"] += 1
|
||||
else:
|
||||
stats["no_match"] += 1
|
||||
|
||||
self.db.commit()
|
||||
logger.info("Pass 1: %s", stats)
|
||||
return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 2: Pattern Classification (keyword-based)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
async def run_pass2_pattern_classification(self, limit: int = 0) -> dict:
|
||||
"""Classify existing controls into patterns via keyword matching.
|
||||
|
||||
For each control without pattern_id → keyword-match title+objective
|
||||
against pattern library → assign best match.
|
||||
"""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
query = """
|
||||
SELECT id, control_id, title, objective
|
||||
FROM canonical_controls
|
||||
WHERE release_state NOT IN ('deprecated')
|
||||
AND (pattern_id IS NULL OR pattern_id = '')
|
||||
"""
|
||||
if limit > 0:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
rows = self.db.execute(text(query)).fetchall()
|
||||
|
||||
stats = {"total": len(rows), "classified": 0, "no_match": 0}
|
||||
|
||||
for row in rows:
|
||||
control_uuid = str(row[0])
|
||||
title = row[2] or ""
|
||||
objective = row[3] or ""
|
||||
|
||||
# Keyword match
|
||||
match_text = f"{title} {objective}"
|
||||
result = self._matcher._tier1_keyword(match_text, None)
|
||||
|
||||
if result and result.pattern_id and result.keyword_hits >= 2:
|
||||
self.db.execute(
|
||||
text("""
|
||||
UPDATE canonical_controls
|
||||
SET pattern_id = :pattern_id
|
||||
WHERE id = CAST(:uuid AS uuid)
|
||||
"""),
|
||||
{
|
||||
"pattern_id": result.pattern_id,
|
||||
"uuid": control_uuid,
|
||||
},
|
||||
)
|
||||
stats["classified"] += 1
|
||||
else:
|
||||
stats["no_match"] += 1
|
||||
|
||||
self.db.commit()
|
||||
logger.info("Pass 2: %s", stats)
|
||||
return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 3: Quality Triage
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass3_quality_triage(self) -> dict:
    """Tag every non-deprecated control with a triage status.

    Writes ``generation_metadata.triage_status`` according to which links
    the control already has:

    - "review": obligation_ids and pattern_id are both set
    - "needs_obligation": pattern_id set, obligation_ids missing/empty
    - "needs_pattern": obligation_ids set, pattern_id missing/empty
    - "legacy_unlinked": neither is set

    Returns:
        Dict mapping each status to the number of rows the corresponding
        UPDATE reported as affected.
    """
    statements = {
        "review": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"review"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
              AND pattern_id IS NOT NULL AND pattern_id != ''
        """,
        "needs_obligation": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"needs_obligation"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
              AND pattern_id IS NOT NULL AND pattern_id != ''
        """,
        "needs_pattern": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"needs_pattern"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
              AND (pattern_id IS NULL OR pattern_id = '')
        """,
        "legacy_unlinked": """
            UPDATE canonical_controls
            SET generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{triage_status}', '"legacy_unlinked"'
            )
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
              AND (pattern_id IS NULL OR pattern_id = '')
        """,
    }

    # Run the four updates in declaration order and record their row counts.
    stats = {
        status: self.db.execute(text(statement)).rowcount
        for status, statement in statements.items()
    }

    self.db.commit()
    logger.info("Pass 3: %s", stats)
    return stats
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 4: Crosswalk Backfill
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass4_crosswalk_backfill(self) -> dict:
    """Backfill crosswalk_matrix from controls that have obligation + pattern.

    For every non-deprecated control with a pattern and a non-empty
    obligation list, one row per obligation id is inserted with confidence
    0.80 and source 'migrated'. Pairs of (control uuid, obligation id) that
    are already present in crosswalk_matrix are left untouched.

    Returns:
        ``{"rows_inserted": <rowcount reported by the INSERT>}``.
    """
    backfill_sql = text("""
        INSERT INTO crosswalk_matrix (
            regulation_code, obligation_id, pattern_id,
            master_control_id, master_control_uuid,
            confidence, source
        )
        SELECT
            COALESCE(
                (generation_metadata::jsonb->>'source_regulation'),
                ''
            ) AS regulation_code,
            obl.value::text AS obligation_id,
            cc.pattern_id,
            cc.control_id,
            cc.id,
            0.80,
            'migrated'
        FROM canonical_controls cc,
             jsonb_array_elements_text(
                 COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
             ) AS obl(value)
        WHERE cc.release_state NOT IN ('deprecated')
          AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
          AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
          AND NOT EXISTS (
              SELECT 1 FROM crosswalk_matrix cw
              WHERE cw.master_control_uuid = cc.id
                AND cw.obligation_id = obl.value::text
          )
    """)

    inserted = self.db.execute(backfill_sql).rowcount
    self.db.commit()

    logger.info("Pass 4: %d crosswalk rows inserted", inserted)
    return {"rows_inserted": inserted}
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Pass 5: Deduplication
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def run_pass5_deduplication(self) -> dict:
    """Deprecate duplicate controls that share an obligation and a pattern.

    Non-deprecated controls with a pattern are grouped by
    (pattern_id, obligation_id). Within each group that has more than one
    member, the control with the highest evidence_confidence (ties broken
    by newest created_at) is kept; the others are set to
    ``release_state = 'deprecated'`` and get
    ``generation_metadata.deprecated_reason`` set to
    ``"duplicate_same_obligation_pattern"``.

    Returns:
        Dict with ``groups_found`` (number of duplicate groups) and
        ``controls_deprecated`` (number of rows actually updated).
    """
    # Find groups with duplicates; ids are ordered so the keeper comes first.
    groups = self.db.execute(text("""
        SELECT cc.pattern_id,
               obl.value::text AS obligation_id,
               array_agg(cc.id ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC) AS ids,
               count(*) AS cnt
        FROM canonical_controls cc,
             jsonb_array_elements_text(
                 COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
             ) AS obl(value)
        WHERE cc.release_state NOT IN ('deprecated')
          AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
        GROUP BY cc.pattern_id, obl.value::text
        HAVING count(*) > 1
    """)).fetchall()

    stats = {"groups_found": len(groups), "controls_deprecated": 0}

    deprecate_sql = text("""
        UPDATE canonical_controls
        SET release_state = 'deprecated',
            generation_metadata = jsonb_set(
                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
            )
        WHERE id = CAST(:uuid AS uuid)
          AND release_state != 'deprecated'
    """)

    for group in groups:
        ids = group[2]  # Array of UUIDs, first is the keeper

        # Keep first (highest confidence), deprecate rest.
        for dep_id in ids[1:]:
            result = self.db.execute(deprecate_sql, {"uuid": str(dep_id)})
            # A control can be a duplicate in several groups. The UPDATE's
            # release_state guard makes the repeats a no-op, so count only
            # rows that were really changed.
            stats["controls_deprecated"] += max(result.rowcount, 0)

    self.db.commit()
    logger.info("Pass 5: %s", stats)
    return stats
|
||||
|
||||
def migration_status(self) -> dict:
    """Summarise how many controls are linked to obligations and patterns.

    Returns:
        Dict with the absolute counts (``total_controls``,
        ``has_obligation``, ``has_pattern``, ``fully_linked``,
        ``deprecated``) and three coverage percentages rounded to one
        decimal. The percentage denominator is at least 1, so an empty
        table yields 0.0 instead of a division error.
    """
    counts = self.db.execute(text("""
        SELECT
            count(*) AS total,
            count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
            count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
            count(*) FILTER (
                WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            ) AS fully_linked,
            count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
        FROM canonical_controls
    """)).fetchone()

    total, with_obligation, with_pattern, linked, deprecated = counts
    denominator = max(total, 1)

    def _percent(count):
        return round(count / denominator * 100, 1)

    return {
        "total_controls": total,
        "has_obligation": with_obligation,
        "has_pattern": with_pattern,
        "fully_linked": linked,
        "deprecated": deprecated,
        "coverage_obligation_pct": _percent(with_obligation),
        "coverage_pattern_pct": _percent(with_pattern),
        "coverage_full_pct": _percent(linked),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_regulation_article(
    citation: Optional[str], metadata: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Derive (regulation_code, article) from a control's citation/metadata.

    Both arguments may be JSON strings or already-decoded objects. The
    citation is consulted first; metadata is only consulted when the
    citation produced no regulation code. Inputs that fail to decode are
    ignored.

    Returns:
        Tuple ``(regulation_code, article)``; either element may be None.
    """
    from services.obligation_extractor import _normalize_regulation

    regulation: Optional[str] = None
    found_article: Optional[str] = None

    # Citation takes precedence (JSON string or dict).
    if citation:
        try:
            parsed = citation if not isinstance(citation, str) else json.loads(citation)
            if isinstance(parsed, dict):
                found_article = parsed.get("article") or parsed.get("source_article")
                # The regulation is taken from the "source" field.
                origin = parsed.get("source", "")
                if origin:
                    regulation = _normalize_regulation(origin)
        except (json.JSONDecodeError, TypeError):
            pass

    # Fall back to metadata only while the regulation is still unknown.
    if metadata and not regulation:
        try:
            parsed_meta = metadata if not isinstance(metadata, str) else json.loads(metadata)
            if isinstance(parsed_meta, dict):
                origin = parsed_meta.get("source_regulation", "")
                if origin:
                    regulation = _normalize_regulation(origin)
                if not found_article:
                    found_article = parsed_meta.get("source_article")
        except (json.JSONDecodeError, TypeError):
            pass

    return regulation, found_article
|
||||
213
control-pipeline/services/rag_client.py
Normal file
213
control-pipeline/services/rag_client.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
Compliance RAG Client — Proxy to Go SDK RAG Search.
|
||||
|
||||
Lightweight HTTP client that queries the Go AI Compliance SDK's
|
||||
POST /sdk/v1/rag/search endpoint. This avoids needing embedding
|
||||
models or direct Qdrant access in Python.
|
||||
|
||||
Error-tolerant: RAG failures never break the calling function.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)

# Base URL of the Go SDK; override with the SDK_URL environment variable.
SDK_URL = os.environ.get("SDK_URL", "http://ai-compliance-sdk:8090")
# HTTP timeout for a single RAG search request, in seconds.
RAG_SEARCH_TIMEOUT = 15.0
|
||||
|
||||
|
||||
@dataclass
class RAGSearchResult:
    """One hit (or scrolled chunk) returned from the compliance corpus."""

    # Chunk text as returned by the SDK.
    text: str
    # Regulation identifiers copied from the SDK response.
    regulation_code: str
    regulation_name: str
    regulation_short: str
    category: str
    # Location of the chunk inside the regulation.
    article: str
    paragraph: str
    source_url: str
    # Relevance score reported by the search endpoint.
    score: float
    # Collection the result was requested from.
    collection: str = ""
|
||||
|
||||
|
||||
class ComplianceRAGClient:
    """
    RAG client that proxies search requests to the Go SDK.

    Usage:
        client = get_rag_client()
        results = await client.search("DSGVO Art. 35", collection="bp_compliance_recht")
        context_str = client.format_for_prompt(results)
    """

    def __init__(self, base_url: str = SDK_URL):
        """Store the search endpoint derived from ``base_url``."""
        # Drop trailing slashes so the joined endpoint never contains "//".
        self._search_url = f"{base_url.rstrip('/')}/sdk/v1/rag/search"

    @staticmethod
    def _build_result(raw: dict, collection: str, score: float) -> RAGSearchResult:
        """Convert one raw SDK item into a RAGSearchResult (missing keys → "")."""
        return RAGSearchResult(
            text=raw.get("text", ""),
            regulation_code=raw.get("regulation_code", ""),
            regulation_name=raw.get("regulation_name", ""),
            regulation_short=raw.get("regulation_short", ""),
            category=raw.get("category", ""),
            article=raw.get("article", ""),
            paragraph=raw.get("paragraph", ""),
            source_url=raw.get("source_url", ""),
            score=score,
            collection=collection,
        )

    async def search(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search the RAG corpus via Go SDK.

        Args:
            query: Free-text search query.
            collection: Collection name sent to the SDK.
            regulations: Optional regulation filter; omitted from the
                request when empty.
            top_k: Number of results requested.

        Returns an empty list on any error (never raises).
        """
        payload = {
            "query": query,
            "collection": collection,
            "top_k": top_k,
        }
        if regulations:
            payload["regulations"] = regulations

        try:
            async with httpx.AsyncClient(timeout=RAG_SEARCH_TIMEOUT) as client:
                resp = await client.post(self._search_url, json=payload)

            if resp.status_code != 200:
                logger.warning(
                    "RAG search returned %d: %s", resp.status_code, resp.text[:200]
                )
                return []

            data = resp.json()
            return [
                self._build_result(r, collection, r.get("score", 0.0))
                for r in data.get("results", [])
            ]

        except Exception as e:
            # Deliberate best-effort: callers must never fail because of RAG.
            logger.warning("RAG search failed: %s", e)
            return []

    async def search_with_rerank(
        self,
        query: str,
        collection: str = "bp_compliance_ce",
        regulations: Optional[List[str]] = None,
        top_k: int = 5,
    ) -> List[RAGSearchResult]:
        """
        Search with optional cross-encoder re-ranking.

        Fetches max(top_k*4, 20) candidates from RAG, re-ranks them with the
        cross-encoder and returns top_k. Falls back to a regular search if
        the reranker is disabled, and to the unranked candidates if
        re-ranking raises.
        """
        from .reranker import get_reranker

        reranker = get_reranker()
        if reranker is None:
            return await self.search(query, collection, regulations, top_k)

        # Fetch more candidates for re-ranking
        candidates = await self.search(
            query, collection, regulations, top_k=max(top_k * 4, 20)
        )
        if not candidates:
            return []

        texts = [c.text for c in candidates]
        try:
            ranked_indices = reranker.rerank(query, texts, top_k=top_k)
            return [candidates[i] for i in ranked_indices]
        except Exception as e:
            logger.warning("Reranking failed, returning unranked: %s", e)
            return candidates[:top_k]

    async def scroll(
        self,
        collection: str,
        offset: Optional[str] = None,
        limit: int = 100,
    ) -> tuple[List[RAGSearchResult], Optional[str]]:
        """
        Scroll through ALL chunks in a collection (paginated).

        Returns (chunks, next_offset). next_offset is None when done.
        On any error returns ``([], None)`` (never raises).
        """
        # Swap only the final path segment. A plain str.replace would also
        # rewrite a "/search" that happens to be part of the base URL.
        scroll_url = self._search_url.rsplit("/", 1)[0] + "/scroll"
        params = {"collection": collection, "limit": str(limit)}
        if offset:
            params["offset"] = offset

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                resp = await client.get(scroll_url, params=params)

            if resp.status_code != 200:
                logger.warning(
                    "RAG scroll returned %d: %s", resp.status_code, resp.text[:200]
                )
                return [], None

            data = resp.json()
            # Scrolled chunks carry no relevance score.
            results = [
                self._build_result(r, collection, 0.0)
                for r in data.get("chunks", [])
            ]
            next_offset = data.get("next_offset") or None
            return results, next_offset

        except Exception as e:
            logger.warning("RAG scroll failed: %s", e)
            return [], None

    def format_for_prompt(
        self, results: List[RAGSearchResult], max_results: int = 5
    ) -> str:
        """Format search results as Markdown for inclusion in an LLM prompt.

        Returns an empty string when there are no results. At most
        ``max_results`` entries are rendered and each text is cut to 400
        characters (followed by "...").
        """
        if not results:
            return ""

        lines = ["## Relevanter Rechtskontext\n"]
        for position, r in enumerate(results[:max_results], start=1):
            header = f"{position}. **{r.regulation_short}** ({r.regulation_code})"
            if r.article:
                header += f" — {r.article}"
            lines.append(header)
            text = r.text[:400] + "..." if len(r.text) > 400 else r.text
            lines.append(f"   > {text}\n")

        return "\n".join(lines)
|
||||
|
||||
|
||||
# Lazily created process-wide client instance.
_rag_client: Optional[ComplianceRAGClient] = None


def get_rag_client() -> ComplianceRAGClient:
    """Return the shared RAG client, creating it on the first call."""
    global _rag_client
    if _rag_client is not None:
        return _rag_client
    _rag_client = ComplianceRAGClient()
    return _rag_client
|
||||
85
control-pipeline/services/reranker.py
Normal file
85
control-pipeline/services/reranker.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Cross-Encoder Re-Ranking for RAG Search Results.
|
||||
|
||||
Uses BGE Reranker v2 (BAAI/bge-reranker-v2-m3, MIT license) to re-rank
|
||||
search results from Qdrant for improved retrieval quality.
|
||||
|
||||
Lazy-loads the model on first use. Disabled by default (RERANK_ENABLED=false).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)

# Re-ranking is active only when RERANK_ENABLED is "true" (case-insensitive).
RERANK_ENABLED = os.environ.get("RERANK_ENABLED", "false").lower() == "true"
# Cross-encoder model name; override with the RERANK_MODEL environment variable.
RERANK_MODEL = os.environ.get("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
||||
|
||||
|
||||
class Reranker:
    """Cross-encoder reranker backed by sentence-transformers."""

    def __init__(self, model_name: str = RERANK_MODEL):
        self._model_name = model_name
        self._model = None  # created by _ensure_model() on first use

    def _ensure_model(self) -> None:
        """Instantiate the CrossEncoder once; log and re-raise on failure."""
        if self._model is not None:
            return
        try:
            from sentence_transformers import CrossEncoder

            logger.info("Loading reranker model: %s", self._model_name)
            self._model = CrossEncoder(self._model_name)
            logger.info("Reranker model loaded successfully")
        except ImportError:
            logger.error(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
            raise
        except Exception as e:
            logger.error("Failed to load reranker model: %s", e)
            raise

    def rerank(
        self, query: str, texts: list[str], top_k: int = 5
    ) -> list[int]:
        """
        Rank ``texts`` against ``query`` and return the best indices.

        Args:
            query: The search query.
            texts: Candidate texts to score.
            top_k: Number of indices to return.

        Returns:
            Indices into ``texts``, highest score first, cut to ``top_k``.
            An empty ``texts`` list returns ``[]`` without loading the model.
        """
        if not texts:
            return []

        self._ensure_model()

        scores = self._model.predict([[query, candidate] for candidate in texts])

        # Stable descending sort: equal scores keep their original order.
        by_score = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        return [index for index, _ in by_score][:top_k]
|
||||
|
||||
|
||||
# Lazily created process-wide reranker instance.
_reranker: Optional[Reranker] = None


def get_reranker() -> Optional[Reranker]:
    """Return the shared reranker, or None when RERANK_ENABLED is off."""
    global _reranker
    if RERANK_ENABLED:
        if _reranker is None:
            _reranker = Reranker()
        return _reranker
    return None
|
||||
223
control-pipeline/services/similarity_detector.py
Normal file
223
control-pipeline/services/similarity_detector.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Too-Close Similarity Detector — checks whether a candidate text is too similar
|
||||
to a protected source text (copyright / license compliance).
|
||||
|
||||
Five metrics:
|
||||
1. Exact-phrase — longest identical token sequence
|
||||
2. Token overlap — Jaccard similarity of token sets
|
||||
3. 3-gram Jaccard — Jaccard similarity of character 3-grams
|
||||
4. Embedding cosine — via bge-m3 (Ollama or embedding-service)
|
||||
5. LCS ratio — Longest Common Subsequence / max(len_a, len_b)
|
||||
|
||||
Decision:
|
||||
  PASS — no fail and at most 1 warn
  WARN — exactly 2 warns, no fail → human review
  FAIL — any fail threshold reached, or 3+ warns → block, rewrite required
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Thresholds
# ---------------------------------------------------------------------------

# Per-metric limits: a value at or above "warn" is a warning, at or above
# "fail" a failure. max_exact_run counts tokens; the others are ratios.
THRESHOLDS = {
    "max_exact_run": dict(warn=8, fail=12),
    "token_overlap": dict(warn=0.20, fail=0.30),
    "ngram_jaccard": dict(warn=0.10, fail=0.18),
    "embedding_cosine": dict(warn=0.86, fail=0.92),
    "lcs_ratio": dict(warn=0.35, fail=0.50),
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tokenisation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WORD_RE = re.compile(r"\w+", re.UNICODE)
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
return [t.lower() for t in _WORD_RE.findall(text)]
|
||||
|
||||
|
||||
def _char_ngrams(text: str, n: int = 3) -> set[str]:
|
||||
text = text.lower()
|
||||
return {text[i : i + n] for i in range(len(text) - n + 1)} if len(text) >= n else set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metric implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def max_exact_run(tokens_a: list[str], tokens_b: list[str]) -> int:
    """Return the length of the longest contiguous token run shared by a and b.

    Uses the rolling-row longest-common-substring recurrence, so the cost
    is O(len(a) * len(b)) even for highly repetitive input, where
    restarting a scan at every matching pair would degrade to cubic time.

    Args:
        tokens_a: Tokens of the first text.
        tokens_b: Tokens of the second text.

    Returns:
        Number of tokens in the longest identical contiguous sequence;
        0 when either list is empty or nothing matches.
    """
    if not tokens_a or not tokens_b:
        return 0

    best = 0
    width = len(tokens_b) + 1
    # prev[j] = length of the common run ending at the previous token of a
    # and at tokens_b[j - 1].
    prev = [0] * width
    for tok_a in tokens_a:
        curr = [0] * width
        for j, tok_b in enumerate(tokens_b, start=1):
            if tok_a == tok_b:
                run = prev[j - 1] + 1
                curr[j] = run
                if run > best:
                    best = run
        prev = curr
    return best
|
||||
|
||||
|
||||
def token_overlap_jaccard(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Return |A ∩ B| / |A ∪ B| over the distinct tokens of both lists.

    Two empty inputs give 0.0.
    """
    left = set(tokens_a)
    right = set(tokens_b)
    union = left | right
    if not union:
        return 0.0
    shared = left & right
    return len(shared) / len(union)
|
||||
|
||||
|
||||
def ngram_jaccard(text_a: str, text_b: str, n: int = 3) -> float:
    """Return the Jaccard similarity of the lower-cased character n-grams.

    Texts shorter than ``n`` contribute no n-grams; if neither text has
    any, the result is 0.0.
    """

    def _grams(raw: str) -> set[str]:
        lowered = raw.lower()
        if len(lowered) < n:
            return set()
        return {lowered[i : i + n] for i in range(len(lowered) - n + 1)}

    first = _grams(text_a)
    second = _grams(text_b)
    combined = first | second
    if not combined:
        return 0.0
    return len(first & second) / len(combined)
|
||||
|
||||
|
||||
def lcs_ratio(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Return len(longest common subsequence) / max(len_a, len_b).

    Either list being empty gives 0.0.
    """
    len_a = len(tokens_a)
    len_b = len(tokens_b)
    if len_a == 0 or len_b == 0:
        return 0.0

    # Only two DP rows are kept: the finished one and the one being filled.
    above = [0] * (len_b + 1)
    for tok_a in tokens_a:
        row = [0] * (len_b + 1)
        for col, tok_b in enumerate(tokens_b, start=1):
            if tok_a == tok_b:
                row[col] = above[col - 1] + 1
            else:
                row[col] = max(above[col], row[col - 1])
        above = row

    return above[len_b] / max(len_a, len_b)
|
||||
|
||||
|
||||
async def embedding_cosine(text_a: str, text_b: str, embedding_url: str | None = None) -> float:
    """Cosine similarity of the two texts' embeddings from the embedding service.

    POSTs both texts to ``<url>/embed`` and compares the first two vectors
    of the ``embeddings`` field in the response.

    Args:
        text_a: First text.
        text_b: Second text.
        embedding_url: Base URL of the embedding service; defaults to
            ``http://embedding-service:8087``.

    Returns:
        The cosine similarity, or 0.0 when the response holds fewer than
        two embeddings or the request fails for any reason (best-effort:
        this metric never raises).
    """
    url = (embedding_url or "http://embedding-service:8087").rstrip("/")
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{url}/embed",
                json={"texts": [text_a, text_b]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            if len(embeddings) < 2:
                return 0.0
            return _cosine(embeddings[0], embeddings[1])
    except Exception as exc:
        # Include the cause: HTTP errors and malformed responses land here
        # too, not only connection failures.
        logger.warning(
            "Embedding service unreachable, skipping cosine check: %s", exc
        )
        return 0.0
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decision engine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class SimilarityReport:
    """Metric values plus the aggregate verdict of one similarity check."""

    # Raw metric values.
    max_exact_run: int
    token_overlap: float
    ngram_jaccard: float
    embedding_cosine: float
    lcs_ratio: float
    # Aggregate verdict: "PASS", "WARN" or "FAIL".
    status: str
    # Verdict per metric name.
    details: dict
|
||||
|
||||
|
||||
def _classify(value: float | int, metric: str) -> str:
    """Map a metric value to "FAIL", "WARN" or "PASS" using THRESHOLDS[metric].

    The fail limit is checked before the warn limit; both are inclusive.
    """
    limits = THRESHOLDS[metric]
    for level in ("fail", "warn"):
        if value >= limits[level]:
            return level.upper()
    return "PASS"
|
||||
|
||||
|
||||
async def check_similarity(
    source_text: str,
    candidate_text: str,
    embedding_url: str | None = None,
) -> SimilarityReport:
    """Compute all five metrics for the text pair and derive a verdict.

    Verdict rules: any metric at FAIL, or more than two at WARN → "FAIL";
    exactly two at WARN → "WARN"; otherwise "PASS".

    Returns:
        SimilarityReport with the metric values (ratios rounded to four
        decimals), the overall status and the per-metric classification.
    """
    source_tokens = _tokenize(source_text)
    candidate_tokens = _tokenize(candidate_text)

    exact = max_exact_run(source_tokens, candidate_tokens)
    overlap = token_overlap_jaccard(source_tokens, candidate_tokens)
    ngrams = ngram_jaccard(source_text, candidate_text)
    cosine = await embedding_cosine(source_text, candidate_text, embedding_url)
    lcs = lcs_ratio(source_tokens, candidate_tokens)

    measured = {
        "max_exact_run": exact,
        "token_overlap": overlap,
        "ngram_jaccard": ngrams,
        "embedding_cosine": cosine,
        "lcs_ratio": lcs,
    }
    details = {name: _classify(value, name) for name, value in measured.items()}

    verdicts = list(details.values())
    fails = verdicts.count("FAIL")
    warns = verdicts.count("WARN")

    if fails > 0 or warns > 2:
        status = "FAIL"
    elif warns == 2:
        status = "WARN"
    else:
        # Zero or one warning still passes.
        status = "PASS"

    return SimilarityReport(
        max_exact_run=exact,
        token_overlap=round(overlap, 4),
        ngram_jaccard=round(ngrams, 4),
        embedding_cosine=round(cosine, 4),
        lcs_ratio=round(lcs, 4),
        status=status,
        details=details,
    )
|
||||
331
control-pipeline/services/v1_enrichment.py
Normal file
331
control-pipeline/services/v1_enrichment.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""V1 Control Enrichment Service — Match Eigenentwicklung controls to regulations.
|
||||
|
||||
Finds regulatory coverage for v1 controls (generation_strategy='ungrouped',
|
||||
pipeline_version=1, no source_citation) by embedding similarity search.
|
||||
|
||||
Reuses embedding + Qdrant helpers from control_dedup.py.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from db.session import SessionLocal
|
||||
from services.control_dedup import (
|
||||
get_embedding,
|
||||
qdrant_search_cross_regulation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)

# Minimum similarity score for a hit to be stored. Deliberately lower than
# the dedup threshold (0.85) because these matches are informational;
# typical top scores for v1 controls are 0.70-0.77.
V1_MATCH_THRESHOLD: float = 0.70
# Maximum number of matches stored per v1 control.
V1_MAX_MATCHES: int = 5
|
||||
|
||||
|
||||
def _is_eigenentwicklung_query() -> str:
|
||||
"""SQL WHERE clause identifying v1 Eigenentwicklung controls."""
|
||||
return """
|
||||
generation_strategy = 'ungrouped'
|
||||
AND (pipeline_version = '1' OR pipeline_version IS NULL)
|
||||
AND source_citation IS NULL
|
||||
AND parent_control_uuid IS NULL
|
||||
AND release_state NOT IN ('rejected', 'merged', 'deprecated')
|
||||
"""
|
||||
|
||||
|
||||
async def count_v1_controls() -> int:
    """Return the number of controls matching the v1 Eigenentwicklung predicate."""
    with SessionLocal() as db:
        count_sql = text(f"""
            SELECT COUNT(*) AS cnt
            FROM canonical_controls
            WHERE {_is_eigenentwicklung_query()}
        """)
        result_row = db.execute(count_sql).fetchone()
        if not result_row:
            return 0
        return result_row.cnt
|
||||
|
||||
|
||||
def _v1_resolve_regulatory_row(db, matched_uuid: str):
    """Return the matched control, or its parent, that has a source_citation; else None."""
    row = db.execute(text("""
        SELECT c.id, c.control_id, c.title, c.source_citation,
               c.severity, c.category, c.parent_control_uuid
        FROM canonical_controls c
        WHERE c.id = CAST(:uuid AS uuid)
    """), {"uuid": matched_uuid}).fetchone()

    if not row:
        return None

    if not row.source_citation and row.parent_control_uuid:
        # Look up the parent; it is only used when it has a source_citation.
        parent_row = db.execute(text("""
            SELECT id, control_id, title, source_citation,
                   severity, category, parent_control_uuid
            FROM canonical_controls
            WHERE id = CAST(:uuid AS uuid)
              AND source_citation IS NOT NULL
        """), {"uuid": str(row.parent_control_uuid)}).fetchone()
        if parent_row:
            row = parent_row

    return row if row.source_citation else None


def _v1_store_matches(db, v1, hits) -> list[dict]:
    """Upsert up to V1_MAX_MATCHES regulatory matches for one v1 control; return them."""
    stored: list[dict] = []
    seen_parents: set[str] = set()

    for hit in hits:
        score = hit.get("score", 0)
        if score < V1_MATCH_THRESHOLD:
            continue

        payload = hit.get("payload", {})
        matched_uuid = payload.get("control_uuid")
        # Skip hits without a control reference and the control itself.
        if not matched_uuid or matched_uuid == str(v1.id):
            continue

        reg_row = _v1_resolve_regulatory_row(db, matched_uuid)
        if reg_row is None:
            continue

        # Deduplicate by the resolved regulatory control.
        parent_key = str(reg_row.id)
        if parent_key in seen_parents:
            continue
        seen_parents.add(parent_key)

        rank = len(stored) + 1
        if rank > V1_MAX_MATCHES:
            break

        # Extract source info (only available when the citation is a dict).
        source_citation = reg_row.source_citation or {}
        is_mapping = isinstance(source_citation, dict)
        matched_source = source_citation.get("source") if is_mapping else None
        matched_article = source_citation.get("article") if is_mapping else None

        # Insert match — link to the regulatory parent (not the atomic child)
        db.execute(text("""
            INSERT INTO v1_control_matches
                (v1_control_uuid, matched_control_uuid, similarity_score,
                 match_rank, matched_source, matched_article, match_method)
            VALUES
                (CAST(:v1_uuid AS uuid), CAST(:matched_uuid AS uuid), :score,
                 :rank, :source, :article, 'embedding')
            ON CONFLICT (v1_control_uuid, matched_control_uuid) DO UPDATE
            SET similarity_score = EXCLUDED.similarity_score,
                match_rank = EXCLUDED.match_rank
        """), {
            "v1_uuid": str(v1.id),
            "matched_uuid": str(reg_row.id),
            "score": round(score, 3),
            "rank": rank,
            "source": matched_source,
            "article": matched_article,
        })

        stored.append({
            "v1_control_id": v1.control_id,
            "v1_title": v1.title,
            "matched_control_id": reg_row.control_id,
            "matched_title": reg_row.title,
            "matched_source": matched_source,
            "matched_article": matched_article,
            "similarity_score": round(score, 3),
            "match_rank": rank,
        })

    return stored


async def enrich_v1_matches(
    dry_run: bool = True,
    batch_size: int = 100,
    offset: int = 0,
) -> dict:
    """Find regulatory matches for v1 Eigenentwicklung controls.

    For each control in the requested page an embedding of
    ``"<title> — <objective>"`` is searched in the ``atomic_controls_dedup``
    collection. Hits at or above V1_MATCH_THRESHOLD are resolved to a control
    with a source_citation (the hit itself or its parent) and upserted into
    ``v1_control_matches``.

    Each control's writes run inside a SAVEPOINT, so a failure for one
    control discards only that control's rows. The rest of the batch stays
    committable and the counters reflect what is actually stored.

    Args:
        dry_run: If True, only count — don't write matches.
        batch_size: Number of v1 controls to process per call.
        offset: Pagination offset (v1 control index).

    Returns:
        Stats dict with counts, sample matches, and pagination info.
    """
    with SessionLocal() as db:
        # 1. Load v1 controls (paginated)
        v1_controls = db.execute(text(f"""
            SELECT id, control_id, title, objective, category
            FROM canonical_controls
            WHERE {_is_eigenentwicklung_query()}
            ORDER BY control_id
            LIMIT :limit OFFSET :offset
        """), {"limit": batch_size, "offset": offset}).fetchall()

        # Count total for pagination
        total_row = db.execute(text(f"""
            SELECT COUNT(*) AS cnt
            FROM canonical_controls
            WHERE {_is_eigenentwicklung_query()}
        """)).fetchone()
        total_v1 = total_row.cnt if total_row else 0

        if not v1_controls:
            return {
                "dry_run": dry_run,
                "processed": 0,
                "total_v1": total_v1,
                "message": "Kein weiterer Batch — alle v1 Controls verarbeitet.",
            }

        if dry_run:
            return {
                "dry_run": True,
                "total_v1": total_v1,
                "offset": offset,
                "batch_size": batch_size,
                "sample_controls": [
                    {
                        "control_id": r.control_id,
                        "title": r.title,
                        "category": r.category,
                    }
                    for r in v1_controls[:20]
                ],
            }

        # 2. Process each v1 control
        processed = 0
        matches_inserted = 0
        errors = []
        sample_matches = []

        for v1 in v1_controls:
            try:
                # Build search text and embed it.
                search_text = f"{v1.title} — {v1.objective}"
                embedding = await get_embedding(search_text)
                if not embedding:
                    errors.append({
                        "control_id": v1.control_id,
                        "error": "Embedding fehlgeschlagen",
                    })
                    continue

                # Search Qdrant (cross-regulation, no pattern filter).
                results = await qdrant_search_cross_regulation(
                    embedding, top_k=20,
                    collection="atomic_controls_dedup",
                )

                # SAVEPOINT per control: a DB error rolls back only this
                # control's rows instead of aborting the whole transaction.
                with db.begin_nested():
                    stored = _v1_store_matches(db, v1, results)

                matches_inserted += len(stored)
                # Collect sample (capped at 20 entries overall).
                remaining = 20 - len(sample_matches)
                if remaining > 0:
                    sample_matches.extend(stored[:remaining])

                processed += 1

            except Exception as e:
                logger.warning("V1 enrichment error for %s: %s", v1.control_id, e)
                errors.append({
                    "control_id": v1.control_id,
                    "error": str(e),
                })

        db.commit()

        # Pagination: a full page means there may be another one.
        next_offset = offset + batch_size if len(v1_controls) == batch_size else None

        return {
            "dry_run": False,
            "offset": offset,
            "batch_size": batch_size,
            "next_offset": next_offset,
            "total_v1": total_v1,
            "processed": processed,
            "matches_inserted": matches_inserted,
            "errors": errors[:10],
            "sample_matches": sample_matches,
        }
|
||||
|
||||
|
||||
async def get_v1_matches(control_uuid: str) -> list[dict]:
    """Return the regulatory controls matched to a single v1 control.

    Reads ``v1_control_matches`` joined to ``canonical_controls`` (on the
    matched control's UUID) and returns the rows ordered by ``match_rank``.

    Args:
        control_uuid: UUID of the v1 control, as text; it is cast to ``uuid``
            inside the query.

    Returns:
        One dict per match, in ``match_rank`` order, combining the matched
        control's details with the stored match metadata.
    """
    match_query = text("""
        SELECT
            m.similarity_score,
            m.match_rank,
            m.matched_source,
            m.matched_article,
            m.match_method,
            c.control_id AS matched_control_id,
            c.title AS matched_title,
            c.objective AS matched_objective,
            c.severity AS matched_severity,
            c.category AS matched_category,
            c.source_citation AS matched_source_citation
        FROM v1_control_matches m
        JOIN canonical_controls c ON c.id = m.matched_control_uuid
        WHERE m.v1_control_uuid = CAST(:uuid AS uuid)
        ORDER BY m.match_rank
    """)

    with SessionLocal() as db:
        # fetchall() materialises the rows, so they stay usable after the
        # session is closed.
        match_rows = db.execute(match_query, {"uuid": control_uuid}).fetchall()

    matches: list[dict] = []
    for row in match_rows:
        entry = {
            "matched_control_id": row.matched_control_id,
            "matched_title": row.matched_title,
            "matched_objective": row.matched_objective,
            "matched_severity": row.matched_severity,
            "matched_category": row.matched_category,
            "matched_source": row.matched_source,
            "matched_article": row.matched_article,
            "matched_source_citation": row.matched_source_citation,
            # Convert the DB value to a plain float for the response payload.
            "similarity_score": float(row.similarity_score),
            "match_rank": row.match_rank,
            "match_method": row.match_method,
        }
        matches.append(entry)
    return matches
|
||||
|
||||
|
||||
async def get_v1_enrichment_stats() -> dict:
    """Return overview statistics for the v1 enrichment.

    Runs four aggregate queries: the number of v1 controls (as selected by
    ``_is_eigenentwicklung_query()``), the number of those controls that have
    at least one row in ``v1_control_matches``, the total number of match
    rows, and the average ``similarity_score`` over all match rows.

    Returns:
        Dict with the keys ``total_v1_controls``, ``v1_with_matches``,
        ``v1_without_matches``, ``total_matches`` and
        ``avg_similarity_score``. The average is rounded to three decimals
        and is ``None`` only when the database returns no average (e.g. no
        match rows exist); an average of exactly 0 is reported as ``0.0``.
    """
    v1_filter = _is_eigenentwicklung_query()

    # The helper's predicate uses unqualified column names. In the joined
    # query below they must be prefixed with the canonical_controls alias "c".
    # NOTE(review): plain substring replacement — assumes none of these names
    # occurs inside another identifier in the predicate; confirm against
    # _is_eigenentwicklung_query().
    qualified_filter = v1_filter
    for column in (
        "release_state",
        "generation_strategy",
        "pipeline_version",
        "source_citation",
        "parent_control_uuid",
    ):
        qualified_filter = qualified_filter.replace(column, f"c.{column}")

    with SessionLocal() as db:
        total_v1 = db.execute(text(f"""
            SELECT COUNT(*) AS cnt FROM canonical_controls
            WHERE {v1_filter}
        """)).fetchone()

        matched_v1 = db.execute(text(f"""
            SELECT COUNT(DISTINCT m.v1_control_uuid) AS cnt
            FROM v1_control_matches m
            JOIN canonical_controls c ON c.id = m.v1_control_uuid
            WHERE {qualified_filter}
        """)).fetchone()

        total_matches = db.execute(text("""
            SELECT COUNT(*) AS cnt FROM v1_control_matches
        """)).fetchone()

        avg_score = db.execute(text("""
            SELECT AVG(similarity_score) AS avg_score FROM v1_control_matches
        """)).fetchone()

    total_count = total_v1.cnt if total_v1 else 0
    matched_count = matched_v1.cnt if matched_v1 else 0
    match_row_count = total_matches.cnt if total_matches else 0

    # Compare against None explicitly: a truthiness check would turn a
    # legitimate average of 0 into None.
    average = avg_score.avg_score if avg_score else None
    avg_similarity = round(float(average), 3) if average is not None else None

    return {
        "total_v1_controls": total_count,
        "v1_with_matches": matched_count,
        "v1_without_matches": total_count - matched_count,
        "total_matches": match_row_count,
        "avg_similarity_score": avg_similarity,
    }
|
||||
0
control-pipeline/tests/__init__.py
Normal file
0
control-pipeline/tests/__init__.py
Normal file
@@ -56,6 +56,7 @@ services:
|
||||
- "8091:8091" # Voice Service (WSS)
|
||||
- "8093:8093" # AI Compliance SDK
|
||||
- "8097:8097" # RAG Service (NEU)
|
||||
- "8098:8098" # Control Pipeline
|
||||
- "8443:8443" # Jitsi Meet
|
||||
- "3008:3008" # Admin Core
|
||||
- "3010:3010" # Portal Dashboard
|
||||
@@ -376,6 +377,50 @@ services:
|
||||
networks:
|
||||
- breakpilot-network
|
||||
|
||||
# =========================================================
|
||||
# CONTROL PIPELINE (Entwickler-only, nicht kundenrelevant)
|
||||
# =========================================================
|
||||
control-pipeline:
|
||||
build:
|
||||
context: ./control-pipeline
|
||||
dockerfile: Dockerfile
|
||||
container_name: bp-core-control-pipeline
|
||||
platform: linux/arm64
|
||||
expose:
|
||||
- "8098"
|
||||
environment:
|
||||
PORT: 8098
|
||||
DATABASE_URL: postgresql://${POSTGRES_USER:-breakpilot}:${POSTGRES_PASSWORD:-breakpilot123}@postgres:5432/${POSTGRES_DB:-breakpilot_db}
|
||||
SCHEMA_SEARCH_PATH: compliance,core,public
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
EMBEDDING_SERVICE_URL: http://embedding-service:8087
|
||||
ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-}
|
||||
CONTROL_GEN_ANTHROPIC_MODEL: ${CONTROL_GEN_ANTHROPIC_MODEL:-claude-sonnet-4-6}
|
||||
DECOMPOSITION_LLM_MODEL: ${DECOMPOSITION_LLM_MODEL:-claude-haiku-4-5-20251001}
|
||||
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
|
||||
CONTROL_GEN_OLLAMA_MODEL: ${CONTROL_GEN_OLLAMA_MODEL:-qwen3.5:35b-a3b}
|
||||
SDK_URL: http://ai-compliance-sdk:8090
|
||||
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
|
||||
ENVIRONMENT: ${ENVIRONMENT:-development}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
depends_on:
|
||||
postgres:
|
||||
condition: service_healthy
|
||||
qdrant:
|
||||
condition: service_healthy
|
||||
embedding-service:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://127.0.0.1:8098/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- breakpilot-network
|
||||
|
||||
embedding-service:
|
||||
build:
|
||||
context: ./embedding-service
|
||||
|
||||
@@ -578,6 +578,33 @@ server {
|
||||
}
|
||||
}
|
||||
|
||||
# =========================================================
|
||||
# CORE: Control Pipeline on port 8098 (Entwickler-only)
|
||||
# =========================================================
|
||||
server {
|
||||
listen 8098 ssl;
|
||||
http2 on;
|
||||
server_name macmini localhost;
|
||||
|
||||
ssl_certificate /etc/nginx/certs/macmini.crt;
|
||||
ssl_certificate_key /etc/nginx/certs/macmini.key;
|
||||
ssl_protocols TLSv1.2 TLSv1.3;
|
||||
ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256;
|
||||
ssl_prefer_server_ciphers off;
|
||||
|
||||
location / {
|
||||
set $upstream_pipeline bp-core-control-pipeline:8098;
|
||||
proxy_pass http://$upstream_pipeline;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto https;
|
||||
proxy_read_timeout 1800s;
|
||||
proxy_send_timeout 1800s;
|
||||
}
|
||||
}
|
||||
|
||||
# =========================================================
|
||||
# CORE: Edu-Search on port 8089
|
||||
# =========================================================
|
||||
|
||||
Reference in New Issue
Block a user