Files
breakpilot-core/control-pipeline/services/pipeline_adapter.py
Benjamin Admin e3ab428b91 feat: control-pipeline Service aus Compliance-Repo migriert
Control-Pipeline (Pass 0a/0b, BatchDedup, Generator) als eigenstaendiger
Service in Core, damit Compliance-Repo unabhaengig refakturiert werden kann.
Schreibt weiterhin ins compliance-Schema der shared PostgreSQL.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-09 14:40:47 +02:00

671 lines
25 KiB
Python

"""Pipeline Adapter — New 10-Stage Pipeline Integration.
Bridges the existing 7-stage control_generator pipeline with the new
multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer).
New pipeline flow:
chunk → license_classify
→ obligation_extract (Stage 4 — NEW)
→ pattern_match (Stage 5 — NEW)
→ control_compose (Stage 6 — replaces old Stage 3)
→ harmonize → anchor → store + crosswalk → mark processed
Can be used in two modes:
1. INLINE: Called from _process_batch() to enrich the pipeline
2. STANDALONE: Process chunks directly through new stages
Part of the Multi-Layer Control Architecture (Phase 7 of 8).
"""
import hashlib
import json
import logging
from dataclasses import dataclass, field
from typing import Optional
from sqlalchemy import text
from sqlalchemy.orm import Session
from services.control_composer import ComposedControl, ControlComposer
from services.obligation_extractor import ObligationExtractor, ObligationMatch
from services.pattern_matcher import PatternMatcher, PatternMatchResult
logger = logging.getLogger(__name__)
@dataclass
class PipelineChunk:
    """Input chunk for the new pipeline stages."""

    text: str
    collection: str = ""
    regulation_code: str = ""
    article: Optional[str] = None
    paragraph: Optional[str] = None
    license_rule: int = 3
    license_info: dict = field(default_factory=dict)
    source_citation: Optional[dict] = None
    chunk_hash: str = ""

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of ``text``, caching it in ``chunk_hash``."""
        if self.chunk_hash:
            return self.chunk_hash
        self.chunk_hash = hashlib.sha256(self.text.encode()).hexdigest()
        return self.chunk_hash
@dataclass
class PipelineResult:
    """Result of processing a chunk through the new pipeline."""

    chunk: PipelineChunk
    obligation: ObligationMatch = field(default_factory=ObligationMatch)
    pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult)
    control: Optional[ComposedControl] = None
    crosswalk_written: bool = False
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize to a plain dict; absent sub-results serialize as None."""
        obligation_dict = self.obligation.to_dict() if self.obligation else None
        pattern_dict = self.pattern_result.to_dict() if self.pattern_result else None
        control_dict = self.control.to_dict() if self.control else None
        return {
            "chunk_hash": self.chunk.chunk_hash,
            "obligation": obligation_dict,
            "pattern": pattern_dict,
            "control": control_dict,
            "crosswalk_written": self.crosswalk_written,
            "error": self.error,
        }
class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Wraps the three new pipeline stages behind one async interface and,
    when a DB session is supplied, persists linkage rows after the
    control has been stored elsewhere.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()
        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        # db is optional: chunk processing works without it, but
        # write_crosswalk() returns False when no session is available.
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize all sub-components (idempotent)."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Process a single chunk through the new 3-stage pipeline.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Never raises: any stage failure is logged, recorded in
        ``result.error``, and the partially-filled result is returned.
        """
        if not self._initialized:
            await self.initialize()
        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)
        try:
            # Stage 4: Obligation Extract
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )
            # Stage 5: Pattern Match — fall back to the first 500 chars of
            # the raw chunk when extraction produced neither text nor title.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )
            # Stage 6: Control Compose — raw chunk text is only forwarded
            # for license rules 1/2 (presumably the rules permitting
            # verbatim text reuse — TODO confirm against license config).
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )
        except Exception as e:
            logger.error("Pipeline processing failed: %s", e)
            result.error = str(e)
        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process multiple chunks sequentially through the pipeline."""
        results = []
        for chunk in chunks:
            result = await self.process_chunk(chunk)
            results.append(result)
        return results

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extraction + crosswalk_matrix rows for a processed chunk.

        Called AFTER the control is stored in canonical_controls.

        Args:
            result: Fully processed pipeline result; must carry a control.
            control_uuid: UUID of the already-stored canonical_controls row.

        Returns:
            True when all rows were written and committed; False when there
            is no DB session / no control, or the transaction failed (the
            session is rolled back in that case).
        """
        if not self.db or not result.control:
            return False
        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result
        try:
            # 1. Write obligation_extraction row
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    # Truncated to 2000 chars — presumably the column's
                    # capacity; confirm against the schema.
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation.confidence if obligation else 0,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "pattern_match_score": pattern.confidence if pattern else 0,
                    "control_uuid": control_uuid,
                },
            )
            # 2. Write crosswalk_matrix row
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "master_control_id": result.control.control_id,
                    "master_control_uuid": control_uuid,
                    # Conservative link confidence: the weaker of the two
                    # stage scores (0 when either stage produced nothing).
                    "confidence": min(
                        obligation.confidence if obligation else 0,
                        pattern.confidence if pattern else 0,
                    ),
                    "source": "auto",
                },
            )
            # 3. Update canonical_controls with pattern_id + obligation_ids.
            # COALESCE prefers the new value and only keeps the existing
            # column value when the bound parameter is NULL.
            if result.control.pattern_id or result.control.obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.control.pattern_id,
                        "obligation_ids": json.dumps(result.control.obligation_ids),
                        "control_uuid": control_uuid,
                    },
                )
            # All three statements commit as one transaction.
            self.db.commit()
            result.crosswalk_written = True
            return True
        except Exception as e:
            logger.error("Failed to write crosswalk: %s", e)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }
# ---------------------------------------------------------------------------
# Migration Passes — Backfill existing 4,800+ controls
# ---------------------------------------------------------------------------
class MigrationPasses:
    """Non-destructive migration passes for existing controls.

    Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
    Pass 2: Pattern Classification (keyword-based matching)
    Pass 3: Quality Triage (categorize by linkage completeness)
    Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
    Pass 5: Deduplication (mark duplicate controls)

    Usage::

        migration = MigrationPasses(db)
        await migration.initialize()
        result = await migration.run_pass1_obligation_linkage(limit=100)
        result = await migration.run_pass2_pattern_classification(limit=100)
        result = migration.run_pass3_quality_triage()
        result = migration.run_pass4_crosswalk_backfill()
        result = migration.run_pass5_deduplication()
    """

    def __init__(self, db: Session):
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize extractors (loads obligations + patterns). Idempotent.

        NOTE(review): reaches into the components' private loaders instead
        of their async ``initialize()`` — presumably to skip setup the
        migration passes don't need; confirm.
        """
        if self._initialized:
            return
        self._extractor._load_obligations()
        self._matcher._load_patterns()
        self._matcher._build_keyword_index()
        self._initialized = True

    # -------------------------------------------------------------------
    # Pass 1: Obligation Linkage (deterministic)
    # -------------------------------------------------------------------
    async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
        """Link existing controls to obligations via source_citation article.

        For each control with source_citation → extract regulation + article
        → look up in obligation framework → set obligation_ids.

        Args:
            limit: Max number of controls to process; 0 means no limit.

        Returns:
            Counters: total / linked / no_match / no_citation.
        """
        if not self._initialized:
            await self.initialize()
        query = """
            SELECT id, control_id, source_citation, generation_metadata
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
        """
        params: dict = {}
        if limit > 0:
            # Bind the limit as a parameter instead of f-string interpolation
            # so a non-integer argument can never alter the SQL statement.
            query += " LIMIT :limit"
            params["limit"] = limit
        rows = self.db.execute(text(query), params).fetchall()
        stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}
        for row in rows:
            control_uuid = str(row[0])
            citation = row[2]
            metadata = row[3]
            # Extract regulation + article from citation or metadata
            reg_code, article = _extract_regulation_article(citation, metadata)
            if not reg_code:
                stats["no_citation"] += 1
                continue
            # Tier 1: Exact match (deterministic lookup)
            match = self._extractor._tier1_exact(reg_code, article or "")
            if match and match.obligation_id:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET obligation_ids = :obl_ids
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "obl_ids": json.dumps([match.obligation_id]),
                        "uuid": control_uuid,
                    },
                )
                stats["linked"] += 1
            else:
                stats["no_match"] += 1
        self.db.commit()
        logger.info("Pass 1: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 2: Pattern Classification (keyword-based)
    # -------------------------------------------------------------------
    async def run_pass2_pattern_classification(self, limit: int = 0) -> dict:
        """Classify existing controls into patterns via keyword matching.

        For each control without pattern_id → keyword-match title+objective
        against pattern library → assign best match.

        Args:
            limit: Max number of controls to process; 0 means no limit.

        Returns:
            Counters: total / classified / no_match.
        """
        if not self._initialized:
            await self.initialize()
        query = """
            SELECT id, control_id, title, objective
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (pattern_id IS NULL OR pattern_id = '')
        """
        params: dict = {}
        if limit > 0:
            # Parameterized LIMIT (see Pass 1) instead of string interpolation.
            query += " LIMIT :limit"
            params["limit"] = limit
        rows = self.db.execute(text(query), params).fetchall()
        stats = {"total": len(rows), "classified": 0, "no_match": 0}
        for row in rows:
            control_uuid = str(row[0])
            title = row[2] or ""
            objective = row[3] or ""
            # Keyword match against title + objective; require at least two
            # keyword hits so patterns aren't assigned on weak evidence.
            match_text = f"{title} {objective}"
            result = self._matcher._tier1_keyword(match_text, None)
            if result and result.pattern_id and result.keyword_hits >= 2:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = :pattern_id
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.pattern_id,
                        "uuid": control_uuid,
                    },
                )
                stats["classified"] += 1
            else:
                stats["no_match"] += 1
        self.db.commit()
        logger.info("Pass 2: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 3: Quality Triage
    # -------------------------------------------------------------------
    def run_pass3_quality_triage(self) -> dict:
        """Categorize controls by linkage completeness.

        Sets generation_metadata.triage_status:
        - "review": has both obligation_id + pattern_id
        - "needs_obligation": has pattern_id but no obligation_id
        - "needs_pattern": has obligation_id but no pattern_id
        - "legacy_unlinked": has neither

        Returns:
            Mapping category → number of rows updated.
        """
        # The four predicates are mutually exclusive, so each non-deprecated
        # control receives exactly one triage_status.
        categories = {
            "review": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"review"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_obligation": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_obligation"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_pattern": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_pattern"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
            "legacy_unlinked": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"legacy_unlinked"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
        }
        stats = {}
        for category, sql in categories.items():
            result = self.db.execute(text(sql))
            stats[category] = result.rowcount
        self.db.commit()
        logger.info("Pass 3: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 4: Crosswalk Backfill
    # -------------------------------------------------------------------
    def run_pass4_crosswalk_backfill(self) -> dict:
        """Create crosswalk_matrix rows for controls with obligation + pattern.

        Emits one row per obligation id in the control's obligation_ids
        JSON array; the NOT EXISTS guard skips rows that already exist for
        the same (master_control_uuid, obligation_id) pair.
        """
        result = self.db.execute(text("""
            INSERT INTO crosswalk_matrix (
                regulation_code, obligation_id, pattern_id,
                master_control_id, master_control_uuid,
                confidence, source
            )
            SELECT
                COALESCE(
                    (generation_metadata::jsonb->>'source_regulation'),
                    ''
                ) AS regulation_code,
                obl.value::text AS obligation_id,
                cc.pattern_id,
                cc.control_id,
                cc.id,
                0.80,
                'migrated'
            FROM canonical_controls cc,
                jsonb_array_elements_text(
                    COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
              AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
              AND NOT EXISTS (
                  SELECT 1 FROM crosswalk_matrix cw
                  WHERE cw.master_control_uuid = cc.id
                    AND cw.obligation_id = obl.value::text
              )
        """))
        rows_inserted = result.rowcount
        self.db.commit()
        logger.info("Pass 4: %d crosswalk rows inserted", rows_inserted)
        return {"rows_inserted": rows_inserted}

    # -------------------------------------------------------------------
    # Pass 5: Deduplication
    # -------------------------------------------------------------------
    def run_pass5_deduplication(self) -> dict:
        """Mark duplicate controls (same obligation + same pattern).

        Groups controls by (obligation_id, pattern_id), keeps the one with
        highest evidence_confidence (or newest), marks rest as deprecated.

        Returns:
            Counters: groups_found / controls_deprecated.
        """
        # Find groups with duplicates; array_agg orders each group so the
        # first element is the keeper.
        groups = self.db.execute(text("""
            SELECT cc.pattern_id,
                   obl.value::text AS obligation_id,
                   array_agg(cc.id ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC) AS ids,
                   count(*) AS cnt
            FROM canonical_controls cc,
                jsonb_array_elements_text(
                    COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
            GROUP BY cc.pattern_id, obl.value::text
            HAVING count(*) > 1
        """)).fetchall()
        stats = {"groups_found": len(groups), "controls_deprecated": 0}
        for group in groups:
            ids = group[2]  # Array of UUIDs, first is the keeper
            if len(ids) <= 1:
                continue
            # Keep first (highest confidence), deprecate rest
            for dep_id in ids[1:]:
                res = self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET release_state = 'deprecated',
                            generation_metadata = jsonb_set(
                                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
                            )
                        WHERE id = CAST(:uuid AS uuid)
                          AND release_state != 'deprecated'
                    """),
                    {"uuid": str(dep_id)},
                )
                # Count only rows actually flipped: a control with several
                # obligation_ids appears in multiple groups, and the second
                # UPDATE matches zero rows due to the release_state guard.
                stats["controls_deprecated"] += res.rowcount
        self.db.commit()
        logger.info("Pass 5: %s", stats)
        return stats

    def migration_status(self) -> dict:
        """Return overall migration progress (counts + coverage percentages)."""
        row = self.db.execute(text("""
            SELECT
                count(*) AS total,
                count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
                count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
                count(*) FILTER (
                    WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                    AND pattern_id IS NOT NULL AND pattern_id != ''
                ) AS fully_linked,
                count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
            FROM canonical_controls
        """)).fetchone()
        # max(total, 1) guards the percentages against an empty table.
        return {
            "total_controls": row[0],
            "has_obligation": row[1],
            "has_pattern": row[2],
            "fully_linked": row[3],
            "deprecated": row[4],
            "coverage_obligation_pct": round(row[1] / max(row[0], 1) * 100, 1),
            "coverage_pattern_pct": round(row[2] / max(row[0], 1) * 100, 1),
            "coverage_full_pct": round(row[3] / max(row[0], 1) * 100, 1),
        }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_regulation_article(
    citation: Optional[str], metadata: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Pull ``(regulation_code, article)`` out of a control's citation/metadata.

    Both inputs may be JSON strings or already-parsed dicts; malformed values
    are silently ignored. The citation is consulted first; metadata is only
    examined when the citation yielded no regulation code.
    """
    from services.obligation_extractor import _normalize_regulation

    reg_code: Optional[str] = None
    article: Optional[str] = None

    # Primary source: the citation (JSON string or dict).
    if citation:
        try:
            parsed = json.loads(citation) if isinstance(citation, str) else citation
            if isinstance(parsed, dict):
                article = parsed.get("article") or parsed.get("source_article")
                # The regulation code is derived from the "source" field.
                src = parsed.get("source", "")
                if src:
                    reg_code = _normalize_regulation(src)
        except (json.JSONDecodeError, TypeError):
            pass

    # Fallback: generation metadata, only when no regulation code yet.
    if metadata and not reg_code:
        try:
            parsed = json.loads(metadata) if isinstance(metadata, str) else metadata
            if isinstance(parsed, dict):
                raw_reg = parsed.get("source_regulation", "")
                if raw_reg:
                    reg_code = _normalize_regulation(raw_reg)
                if not article:
                    article = parsed.get("source_article")
        except (json.JSONDecodeError, TypeError):
            pass

    return reg_code, article