"""Pattern Matcher — Obligation-to-Control-Pattern Linking. Maps obligations (from the ObligationExtractor) to control patterns using two tiers: Tier 1: KEYWORD MATCH — obligation_match_keywords from patterns (~70%) Tier 2: EMBEDDING — cosine similarity with domain bonus (~25%) Part of the Multi-Layer Control Architecture (Phase 5 of 8). """ import logging import os from dataclasses import dataclass, field from pathlib import Path from typing import Optional import yaml from compliance.services.obligation_extractor import ( _cosine_sim, _get_embedding, _get_embeddings_batch, ) logger = logging.getLogger(__name__) # Minimum keyword score to accept a match (at least 2 keyword hits) KEYWORD_MATCH_MIN_HITS = 2 # Embedding threshold for Tier 2 EMBEDDING_PATTERN_THRESHOLD = 0.75 # Domain bonus when regulation maps to the pattern's domain DOMAIN_BONUS = 0.10 # Map regulation IDs to pattern domains that are likely relevant _REGULATION_DOMAIN_AFFINITY = { "dsgvo": ["DATA", "COMP", "GOV"], "bdsg": ["DATA", "COMP"], "ttdsg": ["DATA"], "ai_act": ["AI", "COMP", "DATA"], "nis2": ["SEC", "INC", "NET", "LOG", "CRYP"], "dsa": ["DATA", "COMP"], "data_act": ["DATA", "COMP"], "eu_machinery": ["SEC", "COMP"], "dora": ["SEC", "INC", "FIN", "COMP"], } @dataclass class ControlPattern: """Python representation of a control pattern from YAML.""" id: str name: str name_de: str domain: str category: str description: str objective_template: str rationale_template: str requirements_template: list[str] = field(default_factory=list) test_procedure_template: list[str] = field(default_factory=list) evidence_template: list[str] = field(default_factory=list) severity_default: str = "medium" implementation_effort_default: str = "m" obligation_match_keywords: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list) composable_with: list[str] = field(default_factory=list) open_anchor_refs: list[dict] = field(default_factory=list) @dataclass class PatternMatchResult: """Result of pattern matching.""" pattern: Optional[ControlPattern] = None pattern_id: Optional[str] = None method: str = "none" # keyword | embedding | combined | none confidence: float = 0.0 keyword_hits: int = 0 total_keywords: int = 0 embedding_score: float = 0.0 domain_bonus_applied: bool = False composable_patterns: list[str] = field(default_factory=list) def to_dict(self) -> dict: return { "pattern_id": self.pattern_id, "method": self.method, "confidence": round(self.confidence, 3), "keyword_hits": self.keyword_hits, "total_keywords": self.total_keywords, "embedding_score": round(self.embedding_score, 3), "domain_bonus_applied": self.domain_bonus_applied, "composable_patterns": self.composable_patterns, } class PatternMatcher: """Links obligations to control patterns using keyword + embedding matching. Usage:: matcher = PatternMatcher() await matcher.initialize() result = await matcher.match( obligation_text="Fuehrung eines Verarbeitungsverzeichnisses...", regulation_id="dsgvo", ) print(result.pattern_id) # e.g. "CP-COMP-001" print(result.confidence) # e.g. 0.85 """ def __init__(self): self._patterns: list[ControlPattern] = [] self._by_id: dict[str, ControlPattern] = {} self._by_domain: dict[str, list[ControlPattern]] = {} self._keyword_index: dict[str, list[str]] = {} # keyword → [pattern_ids] self._pattern_embeddings: list[list[float]] = [] self._pattern_ids: list[str] = [] self._initialized = False async def initialize(self) -> None: """Load patterns from YAML and compute embeddings.""" if self._initialized: return self._load_patterns() self._build_keyword_index() await self._compute_embeddings() self._initialized = True logger.info( "PatternMatcher initialized: %d patterns, %d keywords, %d embeddings", len(self._patterns), len(self._keyword_index), sum(1 for e in self._pattern_embeddings if e), ) async def match( self, obligation_text: str, regulation_id: Optional[str] = None, top_n: int = 1, ) -> PatternMatchResult: """Match obligation text to the best control pattern. Args: obligation_text: The obligation description to match against. regulation_id: Source regulation (for domain bonus). top_n: Number of top results to consider for composability. Returns: PatternMatchResult with the best match. """ if not self._initialized: await self.initialize() if not obligation_text or not self._patterns: return PatternMatchResult() # Tier 1: Keyword matching keyword_result = self._tier1_keyword(obligation_text, regulation_id) # Tier 2: Embedding matching embedding_result = await self._tier2_embedding(obligation_text, regulation_id) # Combine scores: prefer keyword match, boost with embedding if available best = self._combine_results(keyword_result, embedding_result) # Attach composable patterns if best.pattern: best.composable_patterns = [ pid for pid in best.pattern.composable_with if pid in self._by_id ] return best async def match_top_n( self, obligation_text: str, regulation_id: Optional[str] = None, n: int = 3, ) -> list[PatternMatchResult]: """Return top-N pattern matches sorted by confidence descending.""" if not self._initialized: await self.initialize() if not obligation_text or not self._patterns: return [] keyword_scores = self._keyword_scores(obligation_text, regulation_id) embedding_scores = await self._embedding_scores(obligation_text, regulation_id) # Merge scores all_pattern_ids = set(keyword_scores.keys()) | set(embedding_scores.keys()) results: list[PatternMatchResult] = [] for pid in all_pattern_ids: pattern = self._by_id.get(pid) if not pattern: continue kw_score = keyword_scores.get(pid, (0, 0, 0.0)) # (hits, total, score) emb_score = embedding_scores.get(pid, (0.0, False)) # (score, bonus_applied) kw_hits, kw_total, kw_confidence = kw_score emb_confidence, bonus_applied = emb_score # Combined confidence: max of keyword and embedding, with boost if both if kw_confidence > 0 and emb_confidence > 0: combined = max(kw_confidence, emb_confidence) + 0.05 method = "combined" elif kw_confidence > 0: combined = kw_confidence method = "keyword" else: combined = emb_confidence method = "embedding" results.append(PatternMatchResult( pattern=pattern, pattern_id=pid, method=method, confidence=min(combined, 1.0), keyword_hits=kw_hits, total_keywords=kw_total, embedding_score=emb_confidence, domain_bonus_applied=bonus_applied, composable_patterns=[ p for p in pattern.composable_with if p in self._by_id ], )) # Sort by confidence descending results.sort(key=lambda r: r.confidence, reverse=True) return results[:n] # ----------------------------------------------------------------------- # Tier 1: Keyword Match # ----------------------------------------------------------------------- def _tier1_keyword( self, obligation_text: str, regulation_id: Optional[str] ) -> Optional[PatternMatchResult]: """Match by counting keyword hits in the obligation text.""" scores = self._keyword_scores(obligation_text, regulation_id) if not scores: return None # Find best match best_pid = max(scores, key=lambda pid: scores[pid][2]) hits, total, confidence = scores[best_pid] if hits < KEYWORD_MATCH_MIN_HITS: return None pattern = self._by_id.get(best_pid) if not pattern: return None # Check domain bonus bonus_applied = False if regulation_id and self._domain_matches(pattern.domain, regulation_id): confidence = min(confidence + DOMAIN_BONUS, 1.0) bonus_applied = True return PatternMatchResult( pattern=pattern, pattern_id=best_pid, method="keyword", confidence=confidence, keyword_hits=hits, total_keywords=total, domain_bonus_applied=bonus_applied, ) def _keyword_scores( self, text: str, regulation_id: Optional[str] ) -> dict[str, tuple[int, int, float]]: """Compute keyword match scores for all patterns. Returns dict: pattern_id → (hits, total_keywords, confidence). """ text_lower = text.lower() hits_by_pattern: dict[str, int] = {} for keyword, pattern_ids in self._keyword_index.items(): if keyword in text_lower: for pid in pattern_ids: hits_by_pattern[pid] = hits_by_pattern.get(pid, 0) + 1 result: dict[str, tuple[int, int, float]] = {} for pid, hits in hits_by_pattern.items(): pattern = self._by_id.get(pid) if not pattern: continue total = len(pattern.obligation_match_keywords) confidence = hits / total if total > 0 else 0.0 result[pid] = (hits, total, confidence) return result # ----------------------------------------------------------------------- # Tier 2: Embedding Match # ----------------------------------------------------------------------- async def _tier2_embedding( self, obligation_text: str, regulation_id: Optional[str] ) -> Optional[PatternMatchResult]: """Match by embedding similarity against pattern objective_templates.""" scores = await self._embedding_scores(obligation_text, regulation_id) if not scores: return None best_pid = max(scores, key=lambda pid: scores[pid][0]) emb_score, bonus_applied = scores[best_pid] if emb_score < EMBEDDING_PATTERN_THRESHOLD: return None pattern = self._by_id.get(best_pid) if not pattern: return None return PatternMatchResult( pattern=pattern, pattern_id=best_pid, method="embedding", confidence=min(emb_score, 1.0), embedding_score=emb_score, domain_bonus_applied=bonus_applied, ) async def _embedding_scores( self, obligation_text: str, regulation_id: Optional[str] ) -> dict[str, tuple[float, bool]]: """Compute embedding similarity scores for all patterns. Returns dict: pattern_id → (score, domain_bonus_applied). """ if not self._pattern_embeddings: return {} chunk_embedding = await _get_embedding(obligation_text[:2000]) if not chunk_embedding: return {} result: dict[str, tuple[float, bool]] = {} for i, pat_emb in enumerate(self._pattern_embeddings): if not pat_emb: continue pid = self._pattern_ids[i] pattern = self._by_id.get(pid) if not pattern: continue score = _cosine_sim(chunk_embedding, pat_emb) # Domain bonus bonus_applied = False if regulation_id and self._domain_matches(pattern.domain, regulation_id): score += DOMAIN_BONUS bonus_applied = True result[pid] = (score, bonus_applied) return result # ----------------------------------------------------------------------- # Score combination # ----------------------------------------------------------------------- def _combine_results( self, keyword_result: Optional[PatternMatchResult], embedding_result: Optional[PatternMatchResult], ) -> PatternMatchResult: """Combine keyword and embedding results into the best match.""" if not keyword_result and not embedding_result: return PatternMatchResult() if not keyword_result: return embedding_result if not embedding_result: return keyword_result # Both matched — check if they agree if keyword_result.pattern_id == embedding_result.pattern_id: # Same pattern: boost confidence combined_confidence = min( max(keyword_result.confidence, embedding_result.confidence) + 0.05, 1.0, ) return PatternMatchResult( pattern=keyword_result.pattern, pattern_id=keyword_result.pattern_id, method="combined", confidence=combined_confidence, keyword_hits=keyword_result.keyword_hits, total_keywords=keyword_result.total_keywords, embedding_score=embedding_result.embedding_score, domain_bonus_applied=( keyword_result.domain_bonus_applied or embedding_result.domain_bonus_applied ), ) # Different patterns: pick the one with higher confidence if keyword_result.confidence >= embedding_result.confidence: return keyword_result return embedding_result # ----------------------------------------------------------------------- # Domain affinity # ----------------------------------------------------------------------- @staticmethod def _domain_matches(pattern_domain: str, regulation_id: str) -> bool: """Check if a pattern's domain has affinity with a regulation.""" affine_domains = _REGULATION_DOMAIN_AFFINITY.get(regulation_id, []) return pattern_domain in affine_domains # ----------------------------------------------------------------------- # Initialization helpers # ----------------------------------------------------------------------- def _load_patterns(self) -> None: """Load control patterns from YAML files.""" patterns_dir = _find_patterns_dir() if not patterns_dir: logger.warning("Control patterns directory not found") return for yaml_file in sorted(patterns_dir.glob("*.yaml")): if yaml_file.name.startswith("_"): continue try: with open(yaml_file) as f: data = yaml.safe_load(f) if not data or "patterns" not in data: continue for p in data["patterns"]: pattern = ControlPattern( id=p["id"], name=p["name"], name_de=p["name_de"], domain=p["domain"], category=p["category"], description=p["description"], objective_template=p["objective_template"], rationale_template=p["rationale_template"], requirements_template=p.get("requirements_template", []), test_procedure_template=p.get("test_procedure_template", []), evidence_template=p.get("evidence_template", []), severity_default=p.get("severity_default", "medium"), implementation_effort_default=p.get("implementation_effort_default", "m"), obligation_match_keywords=p.get("obligation_match_keywords", []), tags=p.get("tags", []), composable_with=p.get("composable_with", []), open_anchor_refs=p.get("open_anchor_refs", []), ) self._patterns.append(pattern) self._by_id[pattern.id] = pattern domain_list = self._by_domain.setdefault(pattern.domain, []) domain_list.append(pattern) except Exception as e: logger.error("Failed to load %s: %s", yaml_file.name, e) logger.info("Loaded %d patterns from %s", len(self._patterns), patterns_dir) def _build_keyword_index(self) -> None: """Build reverse index: keyword → [pattern_ids].""" for pattern in self._patterns: for kw in pattern.obligation_match_keywords: lower_kw = kw.lower() if lower_kw not in self._keyword_index: self._keyword_index[lower_kw] = [] self._keyword_index[lower_kw].append(pattern.id) async def _compute_embeddings(self) -> None: """Compute embeddings for all pattern objective templates.""" if not self._patterns: return self._pattern_ids = [p.id for p in self._patterns] texts = [ f"{p.name_de}: {p.objective_template}" for p in self._patterns ] logger.info("Computing embeddings for %d patterns...", len(texts)) self._pattern_embeddings = await _get_embeddings_batch(texts) valid = sum(1 for e in self._pattern_embeddings if e) logger.info("Got %d/%d valid pattern embeddings", valid, len(texts)) # ----------------------------------------------------------------------- # Public helpers # ----------------------------------------------------------------------- def get_pattern(self, pattern_id: str) -> Optional[ControlPattern]: """Get a pattern by its ID.""" return self._by_id.get(pattern_id.upper()) def get_patterns_by_domain(self, domain: str) -> list[ControlPattern]: """Get all patterns for a domain.""" return self._by_domain.get(domain.upper(), []) def stats(self) -> dict: """Return matcher statistics.""" return { "total_patterns": len(self._patterns), "domains": list(self._by_domain.keys()), "keywords": len(self._keyword_index), "embeddings_valid": sum(1 for e in self._pattern_embeddings if e), "initialized": self._initialized, } def _find_patterns_dir() -> Optional[Path]: """Locate the control_patterns directory.""" candidates = [ Path(__file__).resolve().parent.parent.parent.parent / "ai-compliance-sdk" / "policies" / "control_patterns", Path("/app/ai-compliance-sdk/policies/control_patterns"), Path("ai-compliance-sdk/policies/control_patterns"), ] for p in candidates: if p.is_dir(): return p return None