"""Pipeline Adapter — New 10-Stage Pipeline Integration.

Bridges the existing 7-stage control_generator pipeline with the new
multi-layer components (ObligationExtractor, PatternMatcher, ControlComposer).

New pipeline flow:
    chunk → license_classify
          → obligation_extract (Stage 4 — NEW)
          → pattern_match (Stage 5 — NEW)
          → control_compose (Stage 6 — replaces old Stage 3)
          → harmonize → anchor → store + crosswalk → mark processed

Can be used in two modes:
    1. INLINE: Called from _process_batch() to enrich the pipeline
    2. STANDALONE: Process chunks directly through new stages

Part of the Multi-Layer Control Architecture (Phase 7 of 8).
"""

import hashlib
import json
import logging
from dataclasses import dataclass, field
from typing import Optional

from sqlalchemy import text
from sqlalchemy.orm import Session

from services.control_composer import ComposedControl, ControlComposer
from services.obligation_extractor import ObligationExtractor, ObligationMatch
from services.pattern_matcher import PatternMatcher, PatternMatchResult

logger = logging.getLogger(__name__)


@dataclass
class PipelineChunk:
    """Input chunk for the new pipeline stages."""

    text: str
    collection: str = ""
    regulation_code: str = ""
    article: Optional[str] = None
    paragraph: Optional[str] = None
    # Only license rules 1 and 2 forward the raw chunk text to the composer
    # (see PipelineAdapter.process_chunk).
    license_rule: int = 3
    license_info: dict = field(default_factory=dict)
    source_citation: Optional[dict] = None
    # Lazily filled by compute_hash() when left empty.
    chunk_hash: str = ""

    def compute_hash(self) -> str:
        """Return the SHA-256 hex digest of ``text``, caching it in ``chunk_hash``.

        An already-populated ``chunk_hash`` is returned unchanged.
        """
        if not self.chunk_hash:
            self.chunk_hash = hashlib.sha256(self.text.encode()).hexdigest()
        return self.chunk_hash


@dataclass
class PipelineResult:
    """Result of processing a chunk through the new pipeline."""

    chunk: PipelineChunk
    obligation: ObligationMatch = field(default_factory=ObligationMatch)
    pattern_result: PatternMatchResult = field(default_factory=PatternMatchResult)
    control: Optional[ComposedControl] = None
    crosswalk_written: bool = False
    # Stringified exception when any stage failed, otherwise None.
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize the result into a plain dict, delegating to each part's ``to_dict``."""
        return {
            "chunk_hash": self.chunk.chunk_hash,
            "obligation": self.obligation.to_dict() if self.obligation else None,
            "pattern": self.pattern_result.to_dict() if self.pattern_result else None,
            "control": self.control.to_dict() if self.control else None,
            "crosswalk_written": self.crosswalk_written,
            "error": self.error,
        }


class PipelineAdapter:
    """Integrates ObligationExtractor + PatternMatcher + ControlComposer.

    Usage::

        adapter = PipelineAdapter(db)
        await adapter.initialize()
        result = await adapter.process_chunk(PipelineChunk(
            text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            license_rule=1,
        ))
    """

    def __init__(self, db: Optional[Session] = None):
        # db is optional: without it the adapter can still process chunks,
        # but write_crosswalk() becomes a no-op returning False.
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._composer = ControlComposer()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize all sub-components (idempotent)."""
        if self._initialized:
            return
        await self._extractor.initialize()
        await self._matcher.initialize()
        self._initialized = True
        logger.info("PipelineAdapter initialized")

    async def process_chunk(self, chunk: PipelineChunk) -> PipelineResult:
        """Process a single chunk through the new 3-stage pipeline.

        Stage 4: Obligation Extract
        Stage 5: Pattern Match
        Stage 6: Control Compose

        Exceptions from any stage are caught, logged and recorded in
        ``result.error``; partially filled results are still returned.
        """
        if not self._initialized:
            await self.initialize()

        chunk.compute_hash()
        result = PipelineResult(chunk=chunk)

        try:
            # Stage 4: Obligation Extract
            result.obligation = await self._extractor.extract(
                chunk_text=chunk.text,
                regulation_code=chunk.regulation_code,
                article=chunk.article,
                paragraph=chunk.paragraph,
            )

            # Stage 5: Pattern Match — fall back from obligation text to its
            # title, then to the first 500 chars of the raw chunk.
            obligation_text = (
                result.obligation.obligation_text
                or result.obligation.obligation_title
                or chunk.text[:500]
            )
            result.pattern_result = await self._matcher.match(
                obligation_text=obligation_text,
                regulation_id=result.obligation.regulation_id,
            )

            # Stage 6: Control Compose — the raw chunk text is only handed to
            # the composer for license rules 1 and 2.
            result.control = await self._composer.compose(
                obligation=result.obligation,
                pattern_result=result.pattern_result,
                chunk_text=chunk.text if chunk.license_rule in (1, 2) else None,
                license_rule=chunk.license_rule,
                source_citation=chunk.source_citation,
                regulation_code=chunk.regulation_code,
            )

        except Exception as e:
            logger.error("Pipeline processing failed: %s", e)
            result.error = str(e)

        return result

    async def process_batch(self, chunks: list[PipelineChunk]) -> list[PipelineResult]:
        """Process multiple chunks sequentially through the pipeline."""
        results = []
        for chunk in chunks:
            results.append(await self.process_chunk(chunk))
        return results

    def write_crosswalk(self, result: PipelineResult, control_uuid: str) -> bool:
        """Write obligation_extraction + crosswalk_matrix rows for a processed chunk.

        Called AFTER the control is stored in canonical_controls.

        Args:
            result: A result from process_chunk() that has a composed control.
            control_uuid: UUID (as string) of the stored canonical control.

        Returns:
            True if all rows were written and committed; False if there is no
            DB session, no control, or the write failed (the transaction is
            rolled back in that case).
        """
        if self.db is None or not result.control:
            return False

        chunk = result.chunk
        obligation = result.obligation
        pattern = result.pattern_result

        try:
            # 1. Write obligation_extraction row
            self.db.execute(
                text("""
                    INSERT INTO obligation_extractions (
                        chunk_hash, collection, regulation_code,
                        article, paragraph, obligation_id,
                        obligation_text, confidence, extraction_method,
                        pattern_id, pattern_match_score, control_uuid
                    ) VALUES (
                        :chunk_hash, :collection, :regulation_code,
                        :article, :paragraph, :obligation_id,
                        :obligation_text, :confidence, :extraction_method,
                        :pattern_id, :pattern_match_score,
                        CAST(:control_uuid AS uuid)
                    )
                """),
                {
                    "chunk_hash": chunk.chunk_hash,
                    "collection": chunk.collection,
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "obligation_text": (
                        obligation.obligation_text[:2000]
                        if obligation and obligation.obligation_text
                        else None
                    ),
                    "confidence": obligation.confidence if obligation else 0,
                    "extraction_method": obligation.method if obligation else "none",
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "pattern_match_score": pattern.confidence if pattern else 0,
                    "control_uuid": control_uuid,
                },
            )

            # 2. Write crosswalk_matrix row
            self.db.execute(
                text("""
                    INSERT INTO crosswalk_matrix (
                        regulation_code, article, paragraph,
                        obligation_id, pattern_id,
                        master_control_id, master_control_uuid,
                        confidence, source
                    ) VALUES (
                        :regulation_code, :article, :paragraph,
                        :obligation_id, :pattern_id,
                        :master_control_id,
                        CAST(:master_control_uuid AS uuid),
                        :confidence, :source
                    )
                """),
                {
                    "regulation_code": chunk.regulation_code,
                    "article": chunk.article,
                    "paragraph": chunk.paragraph,
                    "obligation_id": obligation.obligation_id if obligation else None,
                    "pattern_id": pattern.pattern_id if pattern else None,
                    "master_control_id": result.control.control_id,
                    "master_control_uuid": control_uuid,
                    # The crosswalk is only as strong as its weakest link.
                    "confidence": min(
                        obligation.confidence if obligation else 0,
                        pattern.confidence if pattern else 0,
                    ),
                    "source": "auto",
                },
            )

            # 3. Update canonical_controls with pattern_id + obligation_ids.
            # Empty values are sent as SQL NULL so COALESCE keeps the existing
            # column value; json.dumps() of an empty/None list would otherwise
            # produce '[]' / 'null' and silently overwrite existing links.
            pattern_id = result.control.pattern_id or None
            obligation_ids = result.control.obligation_ids
            if pattern_id or obligation_ids:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = COALESCE(:pattern_id, pattern_id),
                            obligation_ids = COALESCE(:obligation_ids, obligation_ids)
                        WHERE id = CAST(:control_uuid AS uuid)
                    """),
                    {
                        "pattern_id": pattern_id,
                        "obligation_ids": (
                            json.dumps(obligation_ids) if obligation_ids else None
                        ),
                        "control_uuid": control_uuid,
                    },
                )

            self.db.commit()
            result.crosswalk_written = True
            return True

        except Exception as e:
            logger.error("Failed to write crosswalk: %s", e)
            self.db.rollback()
            return False

    def stats(self) -> dict:
        """Return component statistics."""
        return {
            "extractor": self._extractor.stats(),
            "matcher": self._matcher.stats(),
            "initialized": self._initialized,
        }


# ---------------------------------------------------------------------------
# Migration Passes — Backfill existing 4,800+ controls
# ---------------------------------------------------------------------------


class MigrationPasses:
    """Non-destructive migration passes for existing controls.

    Pass 1: Obligation Linkage (deterministic, article→obligation lookup)
    Pass 2: Pattern Classification (keyword-based matching)
    Pass 3: Quality Triage (categorize by linkage completeness)
    Pass 4: Crosswalk Backfill (write crosswalk rows for linked controls)
    Pass 5: Deduplication (mark duplicate controls)

    Usage::

        migration = MigrationPasses(db)
        await migration.initialize()

        result = await migration.run_pass1_obligation_linkage(limit=100)
        result = await migration.run_pass2_pattern_classification(limit=100)
        result = migration.run_pass3_quality_triage()
        result = migration.run_pass4_crosswalk_backfill()
        result = migration.run_pass5_deduplication()
    """

    def __init__(self, db: Session):
        self.db = db
        self._extractor = ObligationExtractor()
        self._matcher = PatternMatcher()
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize extractors (loads obligations + patterns); idempotent."""
        if self._initialized:
            return
        # NOTE(review): relies on private loader methods of the sub-components
        # rather than their async initialize(); keep in sync with those classes.
        self._extractor._load_obligations()
        self._matcher._load_patterns()
        self._matcher._build_keyword_index()
        self._initialized = True

    # -------------------------------------------------------------------
    # Pass 1: Obligation Linkage (deterministic)
    # -------------------------------------------------------------------

    async def run_pass1_obligation_linkage(self, limit: int = 0) -> dict:
        """Link existing controls to obligations via source_citation article.

        For each control with source_citation → extract regulation + article
        → look up in obligation framework → set obligation_ids.

        Args:
            limit: Maximum number of controls to process; 0 or less means all.

        Returns:
            Counters: total, linked, no_match, no_citation.
        """
        if not self._initialized:
            await self.initialize()

        query = """
            SELECT id, control_id, source_citation, generation_metadata
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (obligation_ids IS NULL OR obligation_ids = '[]')
        """
        params: dict = {}
        if limit > 0:
            # Bound parameter instead of string interpolation into SQL.
            query += " LIMIT :limit"
            params["limit"] = int(limit)

        rows = self.db.execute(text(query), params).fetchall()

        stats = {"total": len(rows), "linked": 0, "no_match": 0, "no_citation": 0}

        for row in rows:
            control_uuid = str(row[0])
            citation = row[2]
            metadata = row[3]

            # Extract regulation + article from citation or metadata
            reg_code, article = _extract_regulation_article(citation, metadata)
            if not reg_code:
                stats["no_citation"] += 1
                continue

            # Tier 1: Exact match
            match = self._extractor._tier1_exact(reg_code, article or "")
            if match and match.obligation_id:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET obligation_ids = :obl_ids
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "obl_ids": json.dumps([match.obligation_id]),
                        "uuid": control_uuid,
                    },
                )
                stats["linked"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 1: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 2: Pattern Classification (keyword-based)
    # -------------------------------------------------------------------

    async def run_pass2_pattern_classification(
        self, limit: int = 0, *, min_keyword_hits: int = 2
    ) -> dict:
        """Classify existing controls into patterns via keyword matching.

        For each control without pattern_id → keyword-match title+objective
        against pattern library → assign best match.

        Args:
            limit: Maximum number of controls to process; 0 or less means all.
            min_keyword_hits: Minimum keyword hits the best match needs before
                its pattern_id is assigned (default 2, the previous fixed value).

        Returns:
            Counters: total, classified, no_match.
        """
        if not self._initialized:
            await self.initialize()

        query = """
            SELECT id, control_id, title, objective
            FROM canonical_controls
            WHERE release_state NOT IN ('deprecated')
              AND (pattern_id IS NULL OR pattern_id = '')
        """
        params: dict = {}
        if limit > 0:
            # Bound parameter instead of string interpolation into SQL.
            query += " LIMIT :limit"
            params["limit"] = int(limit)

        rows = self.db.execute(text(query), params).fetchall()

        stats = {"total": len(rows), "classified": 0, "no_match": 0}

        for row in rows:
            control_uuid = str(row[0])
            title = row[2] or ""
            objective = row[3] or ""

            # Keyword match
            match_text = f"{title} {objective}"
            result = self._matcher._tier1_keyword(match_text, None)

            if result and result.pattern_id and result.keyword_hits >= min_keyword_hits:
                self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET pattern_id = :pattern_id
                        WHERE id = CAST(:uuid AS uuid)
                    """),
                    {
                        "pattern_id": result.pattern_id,
                        "uuid": control_uuid,
                    },
                )
                stats["classified"] += 1
            else:
                stats["no_match"] += 1

        self.db.commit()
        logger.info("Pass 2: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 3: Quality Triage
    # -------------------------------------------------------------------

    def run_pass3_quality_triage(self) -> dict:
        """Categorize controls by linkage completeness.

        Sets generation_metadata.triage_status:
          - "review": has both obligation_id + pattern_id
          - "needs_obligation": has pattern_id but no obligation_id
          - "needs_pattern": has obligation_id but no pattern_id
          - "legacy_unlinked": has neither

        Returns:
            Mapping of category name → number of rows updated.
        """
        categories = {
            "review": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"review"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_obligation": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_obligation"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND pattern_id IS NOT NULL AND pattern_id != ''
            """,
            "needs_pattern": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"needs_pattern"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND obligation_ids IS NOT NULL AND obligation_ids != '[]'
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
            "legacy_unlinked": """
                UPDATE canonical_controls
                SET generation_metadata = jsonb_set(
                    COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                    '{triage_status}', '"legacy_unlinked"'
                )
                WHERE release_state NOT IN ('deprecated')
                  AND (obligation_ids IS NULL OR obligation_ids = '[]')
                  AND (pattern_id IS NULL OR pattern_id = '')
            """,
        }

        stats = {}
        for category, sql in categories.items():
            result = self.db.execute(text(sql))
            stats[category] = result.rowcount

        self.db.commit()
        logger.info("Pass 3: %s", stats)
        return stats

    # -------------------------------------------------------------------
    # Pass 4: Crosswalk Backfill
    # -------------------------------------------------------------------

    def run_pass4_crosswalk_backfill(self) -> dict:
        """Create crosswalk_matrix rows for controls with obligation + pattern.

        Only creates rows that don't already exist. DISTINCT guards against a
        control listing the same obligation twice: NOT EXISTS cannot see rows
        inserted earlier by the same statement.

        Returns:
            ``{"rows_inserted": <count>}``.
        """
        result = self.db.execute(text("""
            INSERT INTO crosswalk_matrix (
                regulation_code, obligation_id, pattern_id,
                master_control_id, master_control_uuid,
                confidence, source
            )
            SELECT DISTINCT
                COALESCE(
                    (generation_metadata::jsonb->>'source_regulation'),
                    ''
                ) AS regulation_code,
                obl.value::text AS obligation_id,
                cc.pattern_id,
                cc.control_id,
                cc.id,
                0.80,
                'migrated'
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
              AND cc.obligation_ids IS NOT NULL AND cc.obligation_ids != '[]'
              AND NOT EXISTS (
                  SELECT 1 FROM crosswalk_matrix cw
                  WHERE cw.master_control_uuid = cc.id
                    AND cw.obligation_id = obl.value::text
              )
        """))

        rows_inserted = result.rowcount
        self.db.commit()
        logger.info("Pass 4: %d crosswalk rows inserted", rows_inserted)
        return {"rows_inserted": rows_inserted}

    # -------------------------------------------------------------------
    # Pass 5: Deduplication
    # -------------------------------------------------------------------

    def run_pass5_deduplication(self) -> dict:
        """Mark duplicate controls (same obligation + same pattern).

        Groups controls by (obligation_id, pattern_id), keeps the one with
        highest evidence_confidence (or newest), marks rest as deprecated.

        Because a control can carry several obligations, it can appear in
        several groups. A control chosen as survivor of one group is never
        deprecated by a later group (so every group keeps an active control),
        and an already-deprecated control is never chosen as a survivor.

        Returns:
            Counters: groups_found, controls_deprecated (rows actually updated).
        """
        # Find groups with duplicates. ids are aggregated as text so the
        # driver returns a plain list of strings.
        groups = self.db.execute(text("""
            SELECT
                cc.pattern_id,
                obl.value::text AS obligation_id,
                array_agg(
                    cc.id::text
                    ORDER BY cc.evidence_confidence DESC NULLS LAST, cc.created_at DESC
                ) AS ids,
                count(DISTINCT cc.id) AS cnt
            FROM canonical_controls cc,
                 jsonb_array_elements_text(
                     COALESCE(cc.obligation_ids::jsonb, '[]'::jsonb)
                 ) AS obl(value)
            WHERE cc.release_state NOT IN ('deprecated')
              AND cc.pattern_id IS NOT NULL AND cc.pattern_id != ''
            GROUP BY cc.pattern_id, obl.value::text
            HAVING count(DISTINCT cc.id) > 1
        """)).fetchall()

        stats = {"groups_found": len(groups), "controls_deprecated": 0}
        kept: set[str] = set()
        deprecated: set[str] = set()

        for group in groups:
            # Order-preserving de-dup: a control listing the same obligation
            # twice would otherwise appear twice and deprecate itself.
            ids = [
                uid
                for uid in dict.fromkeys(str(i) for i in group[2])
                if uid not in deprecated
            ]
            if len(ids) <= 1:
                continue

            # Keep first (highest confidence), deprecate rest
            keeper, *others = ids
            kept.add(keeper)
            for dep_id in others:
                if dep_id in kept:
                    # Survivor of an earlier group; leave it active.
                    continue
                res = self.db.execute(
                    text("""
                        UPDATE canonical_controls
                        SET release_state = 'deprecated',
                            generation_metadata = jsonb_set(
                                COALESCE(generation_metadata::jsonb, '{}'::jsonb),
                                '{deprecated_reason}', '"duplicate_same_obligation_pattern"'
                            )
                        WHERE id = CAST(:uuid AS uuid)
                          AND release_state != 'deprecated'
                    """),
                    {"uuid": dep_id},
                )
                deprecated.add(dep_id)
                stats["controls_deprecated"] += res.rowcount

        self.db.commit()
        logger.info("Pass 5: %s", stats)
        return stats

    def migration_status(self) -> dict:
        """Return overall migration progress (counts and coverage percentages).

        Percentages are relative to all controls, including deprecated ones.
        """
        row = self.db.execute(text("""
            SELECT
                count(*) AS total,
                count(*) FILTER (WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]') AS has_obligation,
                count(*) FILTER (WHERE pattern_id IS NOT NULL AND pattern_id != '') AS has_pattern,
                count(*) FILTER (
                    WHERE obligation_ids IS NOT NULL AND obligation_ids != '[]'
                      AND pattern_id IS NOT NULL AND pattern_id != ''
                ) AS fully_linked,
                count(*) FILTER (WHERE release_state = 'deprecated') AS deprecated
            FROM canonical_controls
        """)).fetchone()

        # max(..., 1) avoids division by zero on an empty table.
        total = max(row[0], 1)
        return {
            "total_controls": row[0],
            "has_obligation": row[1],
            "has_pattern": row[2],
            "fully_linked": row[3],
            "deprecated": row[4],
            "coverage_obligation_pct": round(row[1] / total * 100, 1),
            "coverage_pattern_pct": round(row[2] / total * 100, 1),
            "coverage_full_pct": round(row[3] / total * 100, 1),
        }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _extract_regulation_article(
    citation: Optional[str], metadata: Optional[str]
) -> tuple[Optional[str], Optional[str]]:
    """Extract regulation_code and article from control's citation/metadata.

    Both arguments may be JSON strings or already-decoded dicts; anything
    else (or malformed JSON) is ignored. The citation is consulted first
    (``article``/``source_article`` and ``source``); generation metadata
    (``source_regulation``/``source_article``) fills in whatever is missing,
    but only when the citation yielded no regulation.

    Returns:
        ``(regulation_code, article)``; either element may be None.
    """
    from services.obligation_extractor import _normalize_regulation

    reg_code = None
    article = None

    # Try citation first (JSON string or dict)
    if citation:
        try:
            c = json.loads(citation) if isinstance(citation, str) else citation
            if isinstance(c, dict):
                article = c.get("article") or c.get("source_article")
                # Try to get regulation from source field
                source = c.get("source", "")
                if source:
                    reg_code = _normalize_regulation(source)
        except (json.JSONDecodeError, TypeError):
            pass

    # Try metadata
    if metadata and not reg_code:
        try:
            m = json.loads(metadata) if isinstance(metadata, str) else metadata
            if isinstance(m, dict):
                src_reg = m.get("source_regulation", "")
                if src_reg:
                    reg_code = _normalize_regulation(src_reg)
                if not article:
                    article = m.get("source_article")
        except (json.JSONDecodeError, TypeError):
            pass

    return reg_code, article