"""Obligation Extractor — 3-Tier Chunk-to-Obligation Linking.

Maps RAG chunks to obligations from the v2 obligation framework using three
tiers (fastest first):

    Tier 1: EXACT MATCH — regulation_code + article → obligation_id (~40%)
    Tier 2: EMBEDDING   — chunk text vs. obligation descriptions   (~30%)
    Tier 3: LLM EXTRACT — local Ollama extracts obligation text     (~25%)

Part of the Multi-Layer Control Architecture (Phase 4 of 8).
"""

import json
import logging
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_MODEL = os.getenv("CONTROL_GEN_OLLAMA_MODEL", "qwen3.5:35b-a3b")
LLM_TIMEOUT = float(os.getenv("CONTROL_GEN_LLM_TIMEOUT", "180"))

# Embedding similarity thresholds for Tier 2
EMBEDDING_MATCH_THRESHOLD = 0.80
EMBEDDING_CANDIDATE_THRESHOLD = 0.60  # not referenced within this module

# Tier 2 ranking bonus for obligations of the chunk's own regulation. It only
# decides which qualifying candidate wins; it never affects the threshold
# check or the reported confidence.
_SAME_REGULATION_BONUS = 0.05

# ---------------------------------------------------------------------------
# Regulation code mapping: RAG chunk codes → obligation file regulation IDs
# ---------------------------------------------------------------------------

_REGULATION_CODE_TO_ID = {
    # DSGVO
    "eu_2016_679": "dsgvo",
    "dsgvo": "dsgvo",
    "gdpr": "dsgvo",
    # AI Act
    "eu_2024_1689": "ai_act",
    "ai_act": "ai_act",
    "aiact": "ai_act",
    # NIS2
    "eu_2022_2555": "nis2",
    "nis2": "nis2",
    "bsig": "nis2",
    # BDSG
    "bdsg": "bdsg",
    # TTDSG
    "ttdsg": "ttdsg",
    # DSA
    "eu_2022_2065": "dsa",
    "dsa": "dsa",
    # Data Act
    "eu_2023_2854": "data_act",
    "data_act": "data_act",
    # EU Machinery
    "eu_2023_1230": "eu_machinery",
    "eu_machinery": "eu_machinery",
    # DORA
    "eu_2022_2554": "dora",
    "dora": "dora",
}

# Fallback for codes that extend a CELEX-style family code with a suffix
# (anything starting with the prefix maps to the family's regulation ID).
_REGULATION_CODE_PREFIXES = (
    ("eu_2016_679", "dsgvo"),
    ("eu_2024_1689", "ai_act"),
    ("eu_2022_2555", "nis2"),
    ("eu_2022_2065", "dsa"),
    ("eu_2023_2854", "data_act"),
    ("eu_2023_1230", "eu_machinery"),
    ("eu_2022_2554", "dora"),
)


@dataclass
class ObligationMatch:
    """Result of obligation extraction for a single chunk."""

    obligation_id: Optional[str] = None
    obligation_title: Optional[str] = None
    obligation_text: Optional[str] = None
    method: str = "none"  # exact_match | embedding_match | llm_extracted | inferred
    confidence: float = 0.0
    regulation_id: Optional[str] = None  # e.g. "dsgvo"

    def to_dict(self) -> dict:
        """Return the match as a plain dict with one key per field."""
        return {
            "obligation_id": self.obligation_id,
            "obligation_title": self.obligation_title,
            "obligation_text": self.obligation_text,
            "method": self.method,
            "confidence": self.confidence,
            "regulation_id": self.regulation_id,
        }


@dataclass
class _ObligationEntry:
    """Internal representation of a loaded obligation."""

    id: str
    title: str
    description: str
    regulation_id: str
    articles: list[str] = field(default_factory=list)  # normalized: ["art. 30", "§ 38"]
    embedding: list[float] = field(default_factory=list)


class ObligationExtractor:
    """3-Tier obligation extraction from RAG chunks.

    Usage::

        extractor = ObligationExtractor()
        await extractor.initialize()  # loads obligations + embeddings
        match = await extractor.extract(
            chunk_text="...",
            regulation_code="eu_2016_679",
            article="Art. 30",
            paragraph="Abs. 1",
        )
    """

    def __init__(self):
        # "<regulation_id>/<normalized article>" → obligation IDs in load order,
        # e.g. "dsgvo/art. 30" → ["DSGVO-OBL-001"]
        self._article_lookup: dict[str, list[str]] = {}
        self._obligations: dict[str, _ObligationEntry] = {}  # id → entry
        # Parallel lists: _obligation_embeddings[i] belongs to _obligation_ids[i];
        # an empty embedding marks an obligation whose embedding request failed.
        self._obligation_embeddings: list[list[float]] = []
        self._obligation_ids: list[str] = []
        self._initialized = False

    async def initialize(self) -> None:
        """Load all obligations from the v2 JSON files and compute embeddings.

        Subsequent calls are no-ops once initialization has completed.
        """
        if self._initialized:
            return

        self._load_obligations()
        await self._compute_embeddings()
        self._initialized = True
        logger.info(
            "ObligationExtractor initialized: %d obligations, %d article lookups, %d embeddings",
            len(self._obligations),
            len(self._article_lookup),
            sum(1 for e in self._obligation_embeddings if e),
        )

    async def extract(
        self,
        chunk_text: str,
        regulation_code: str,
        article: Optional[str] = None,
        paragraph: Optional[str] = None,
    ) -> ObligationMatch:
        """Extract the obligation for one chunk using the 3-tier strategy.

        Tiers are tried in order and the first hit wins: exact article lookup
        (only when *article* is given), embedding similarity, then LLM
        extraction. Calls :meth:`initialize` lazily on first use.

        Args:
            chunk_text: Text of the RAG chunk.
            regulation_code: RAG regulation code, e.g. ``"eu_2016_679"``.
            article: Article reference such as ``"Art. 30"``; enables Tier 1.
            paragraph: Accepted for API compatibility; not used by any tier.

        Returns:
            An ObligationMatch. Tier 3 always produces one; its confidence is
            0.0 when the LLM call returned no output.
        """
        if not self._initialized:
            await self.initialize()

        reg_id = _normalize_regulation(regulation_code)

        # Tier 1: Exact match via article lookup
        if article:
            match = self._tier1_exact(reg_id, article)
            if match:
                return match

        # Tier 2: Embedding similarity
        match = await self._tier2_embedding(chunk_text, reg_id)
        if match:
            return match

        # Tier 3: LLM extraction
        return await self._tier3_llm(chunk_text, regulation_code, article)

    # -----------------------------------------------------------------------
    # Tier 1: Exact Match
    # -----------------------------------------------------------------------

    def _tier1_exact(self, reg_id: Optional[str], article: str) -> Optional[ObligationMatch]:
        """Look up an obligation by regulation ID + normalized article."""
        if not reg_id:
            return None

        key = f"{reg_id}/{_normalize_article(article)}"
        obl_ids = self._article_lookup.get(key)
        if not obl_ids:
            return None

        # Several obligations can cite the same article; take the first one
        # registered (manifest/file order).
        entry = self._obligations.get(obl_ids[0])
        if not entry:
            return None

        return ObligationMatch(
            obligation_id=entry.id,
            obligation_title=entry.title,
            obligation_text=entry.description,
            method="exact_match",
            confidence=1.0,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 2: Embedding Match
    # -----------------------------------------------------------------------

    async def _tier2_embedding(
        self, chunk_text: str, reg_id: Optional[str]
    ) -> Optional[ObligationMatch]:
        """Find the nearest obligation by embedding similarity.

        Only obligations whose raw cosine similarity reaches
        ``EMBEDDING_MATCH_THRESHOLD`` are eligible. Among eligible candidates,
        obligations of the chunk's own regulation receive
        ``_SAME_REGULATION_BONUS`` for ranking, so they win close calls; a
        bonused candidate below the threshold can no longer shadow a
        qualifying one. The reported confidence is the raw similarity.
        """
        if not self._obligation_embeddings:
            return None

        chunk_embedding = await _get_embedding(chunk_text[:2000])
        if not chunk_embedding:
            return None

        best_id: Optional[str] = None
        best_raw = 0.0
        best_ranked = float("-inf")
        for obl_id, obl_emb in zip(self._obligation_ids, self._obligation_embeddings):
            if not obl_emb:
                continue  # embedding request failed for this obligation
            raw = _cosine_sim(chunk_embedding, obl_emb)
            if raw < EMBEDDING_MATCH_THRESHOLD:
                continue

            entry = self._obligations.get(obl_id)
            ranked = raw
            if entry and reg_id and entry.regulation_id == reg_id:
                ranked += _SAME_REGULATION_BONUS
            if ranked > best_ranked:  # strict: earlier obligation wins ties
                best_id, best_raw, best_ranked = obl_id, raw, ranked

        if best_id is None:
            return None

        entry = self._obligations.get(best_id)
        return ObligationMatch(
            obligation_id=entry.id if entry else best_id,
            obligation_title=entry.title if entry else None,
            obligation_text=entry.description if entry else None,
            method="embedding_match",
            confidence=round(min(best_raw, 1.0), 3),
            regulation_id=entry.regulation_id if entry else reg_id,
        )

    # -----------------------------------------------------------------------
    # Tier 3: LLM Extraction
    # -----------------------------------------------------------------------

    async def _tier3_llm(
        self, chunk_text: str, regulation_code: str, article: Optional[str]
    ) -> ObligationMatch:
        """Use the local LLM to extract the obligation from the chunk.

        Always returns a match with ``method="llm_extracted"``. Confidence is
        0.0 when the LLM returned nothing. It is 0.60 otherwise; in that case,
        if the response has no usable ``obligation_text`` string, the first
        500 characters of the raw response are used instead.
        """
        prompt = f"""Analysiere den folgenden Gesetzestext und extrahiere die zentrale rechtliche Pflicht.

Text:
{chunk_text[:3000]}

Quelle: {regulation_code} {article or ''}

Antworte NUR als JSON:
{{
  "obligation_text": "Die zentrale Pflicht in einem Satz",
  "actor": "Wer muss handeln (z.B. Verantwortlicher, Auftragsverarbeiter)",
  "action": "Was muss getan werden",
  "normative_strength": "muss|soll|kann"
}}"""

        system_prompt = (
            "Du bist ein Rechtsexperte fuer EU-Datenschutz- und Digitalrecht. "
            "Extrahiere die zentrale rechtliche Pflicht aus Gesetzestexten. "
            "Antworte ausschliesslich als JSON."
        )

        reg_id = _normalize_regulation(regulation_code)
        result_text = await _llm_ollama(prompt, system_prompt)
        if not result_text:
            return ObligationMatch(
                method="llm_extracted",
                confidence=0.0,
                regulation_id=reg_id,
            )

        parsed = _parse_json(result_text)
        obligation_text = parsed.get("obligation_text")
        if not isinstance(obligation_text, str) or not obligation_text.strip():
            # Model ignored the schema (missing/empty/non-string field): keep a
            # truncated raw answer rather than dropping the chunk.
            obligation_text = result_text[:500]

        return ObligationMatch(
            obligation_id=None,
            obligation_title=None,
            obligation_text=obligation_text,
            method="llm_extracted",
            confidence=0.60,
            regulation_id=reg_id,
        )

    # -----------------------------------------------------------------------
    # Initialization helpers
    # -----------------------------------------------------------------------

    def _load_obligations(self) -> None:
        """Load all obligation files listed in the v2 framework manifest.

        Populates ``_obligations`` and ``_article_lookup``. Missing or
        unreadable files are logged and skipped rather than aborting startup.
        """
        v2_dir = _find_obligations_dir()
        if not v2_dir:
            logger.warning("Obligations v2 directory not found — Tier 1 disabled")
            return

        manifest_path = v2_dir / "_manifest.json"
        if not manifest_path.exists():
            logger.warning("Manifest not found at %s", manifest_path)
            return

        manifest = _read_json(manifest_path)
        if manifest is None:
            return
        regulations = manifest.get("regulations", [])

        for reg_info in regulations:
            reg_id = reg_info["id"]
            reg_file = v2_dir / reg_info["file"]
            if not reg_file.exists():
                logger.warning("Regulation file not found: %s", reg_file)
                continue

            data = _read_json(reg_file)
            if data is None:
                continue

            for obl in data.get("obligations", []):
                obl_id = obl["id"]
                entry = _ObligationEntry(
                    id=obl_id,
                    title=obl.get("title", ""),
                    description=obl.get("description", ""),
                    regulation_id=reg_id,
                )

                # Build article lookup from legal_basis
                for basis in obl.get("legal_basis", []):
                    norm_art = _normalize_article(basis.get("article", ""))
                    # Skip empty references and articles already registered
                    # for this obligation (e.g. "Art. 30 Abs. 1" + "Abs. 2").
                    if not norm_art or norm_art in entry.articles:
                        continue
                    self._article_lookup.setdefault(f"{reg_id}/{norm_art}", []).append(obl_id)
                    entry.articles.append(norm_art)

                self._obligations[obl_id] = entry

        logger.info(
            "Loaded %d obligations from %d regulations",
            len(self._obligations),
            len(regulations),
        )

    async def _compute_embeddings(self) -> None:
        """Compute embeddings for all obligation titles + descriptions."""
        if not self._obligations:
            return

        self._obligation_ids = list(self._obligations.keys())
        texts = [
            f"{self._obligations[oid].title}: {self._obligations[oid].description}"
            for oid in self._obligation_ids
        ]

        logger.info("Computing embeddings for %d obligations...", len(texts))
        # _get_embeddings_batch returns exactly one entry per text, which
        # keeps this list aligned with _obligation_ids.
        self._obligation_embeddings = await _get_embeddings_batch(texts)
        valid = sum(1 for e in self._obligation_embeddings if e)
        logger.info("Got %d/%d valid embeddings", valid, len(texts))

    # -----------------------------------------------------------------------
    # Stats
    # -----------------------------------------------------------------------

    def stats(self) -> dict:
        """Return initialization statistics (regulations sorted by ID)."""
        return {
            "total_obligations": len(self._obligations),
            "article_lookups": len(self._article_lookup),
            "embeddings_valid": sum(1 for e in self._obligation_embeddings if e),
            "regulations": sorted(
                {e.regulation_id for e in self._obligations.values()}
            ),
            "initialized": self._initialized,
        }


# ---------------------------------------------------------------------------
# Module-level helpers (reusable by other modules)
# ---------------------------------------------------------------------------


def _normalize_regulation(regulation_code: str) -> Optional[str]:
    """Map a RAG regulation_code to an obligation framework regulation ID.

    Returns None for empty or unknown codes.
    """
    if not regulation_code:
        return None

    code = regulation_code.lower().strip()

    # Direct lookup
    if code in _REGULATION_CODE_TO_ID:
        return _REGULATION_CODE_TO_ID[code]

    # Prefix matching for families
    for prefix, reg_id in _REGULATION_CODE_PREFIXES:
        if code.startswith(prefix):
            return reg_id

    return None


# Subdivision marker followed by "." and/or whitespace; the marker and
# everything after it are dropped ("Art. 6 Abs. 1 lit. a DSGVO" → "Art. 6").
_ARTICLE_SUBDIVISION_RE = re.compile(
    r"\s+(Abs|Absatz|para|paragraph|lit|Satz)(?:\.\s*|\s+).*$", re.IGNORECASE
)
# Bracketed paragraph number and everything after it ("Article 30(1)" → "Article 30").
_ARTICLE_BRACKET_PARAGRAPH_RE = re.compile(r"\s*\(\d+[a-z]?\).*$", re.IGNORECASE)
# Trailing law name ("§ 38 BDSG" → "§ 38").
_ARTICLE_LAW_SUFFIX_RE = re.compile(
    r"\s+(DSGVO|GDPR|BDSG|TTDSG|DSA|NIS-?2|DORA|AI.?Act|Data.?Act)\s*$",
    re.IGNORECASE,
)
# Any article spelling at the start: "Article", "Artikel", "Art.", "Art",
# followed by whitespace or directly by the number ("Art.30").
_ARTICLE_PREFIX_RE = re.compile(r"^(?:Article|Artikel|Art\.?)(?=[\s\d])\s*", re.IGNORECASE)
# Section sign with or without a following space ("§38" → "§ 38").
_SECTION_SIGN_RE = re.compile(r"^§\s*")


def _normalize_article(article: str) -> str:
    """Normalize article references for consistent lookup.

    The same normalization is applied when building the lookup and when
    querying it, so only consistency between the two matters.

    Examples:
        "Art. 30"              → "art. 30"
        "§ 38 BDSG"            → "§ 38"
        "Article 10"           → "art. 10"
        "Art. 30 Abs. 1"       → "art. 30"
        "Artikel 35"           → "art. 35"
        "Article 30(1) GDPR"   → "art. 30"
        "Art.30" / "Art 30"    → "art. 30"
        "§38"                  → "§ 38"
    """
    if not article:
        return ""

    # Collapse runs of whitespace and trim both ends.
    s = " ".join(article.split())
    # Drop subdivisions first so a law name in the middle
    # ("Art. 30 DSGVO Abs. 1") ends up trailing and is removed next.
    s = _ARTICLE_SUBDIVISION_RE.sub("", s)
    s = _ARTICLE_BRACKET_PARAGRAPH_RE.sub("", s)
    s = _ARTICLE_LAW_SUFFIX_RE.sub("", s)
    # Canonical prefixes: "Art. <n>" and "§ <n>"
    s = _ARTICLE_PREFIX_RE.sub("Art. ", s)
    s = _SECTION_SIGN_RE.sub("§ ", s)
    return s.lower().strip()


def _cosine_sim(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Returns 0.0 for empty vectors, vectors of different lengths, or a
    zero-norm vector.
    """
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(x * x for x in b) ** 0.5
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def _find_obligations_dir() -> Optional[Path]:
    """Locate the obligations v2 directory (first candidate with a manifest)."""
    candidates = [
        Path(__file__).resolve().parent.parent.parent.parent
        / "ai-compliance-sdk" / "policies" / "obligations" / "v2",
        Path("/app/ai-compliance-sdk/policies/obligations/v2"),
        Path("ai-compliance-sdk/policies/obligations/v2"),
    ]
    for p in candidates:
        if p.is_dir() and (p / "_manifest.json").exists():
            return p
    return None


def _read_json(path: Path) -> Optional[dict]:
    """Read a UTF-8 JSON object from *path*.

    Returns None (after logging a warning) when the file cannot be read, is
    not valid JSON/UTF-8, or does not contain a JSON object.
    """
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        logger.warning("Failed to read %s: %s", path, e)
        return None
    if not isinstance(data, dict):
        logger.warning("Expected a JSON object in %s, got %s", path, type(data).__name__)
        return None
    return data


async def _get_embedding(text: str) -> list[float]:
    """Get the embedding vector for a single text.

    Best-effort: returns [] on any failure (Tier 2 is then skipped).
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{EMBEDDING_URL}/embed",
                json={"texts": [text]},
            )
            resp.raise_for_status()
            embeddings = resp.json().get("embeddings", [])
            return embeddings[0] if embeddings else []
    except Exception as e:
        logger.debug("Single embedding request failed: %s", e)
        return []


async def _get_embeddings_batch(
    texts: list[str], batch_size: int = 32
) -> list[list[float]]:
    """Get embeddings for multiple texts in batches.

    The result always has exactly one entry per input text, in input order.
    An empty list marks a text whose batch failed, or whose batch came back
    with a different number of embeddings than texts. Callers index the
    result in parallel with their own ID lists, so misalignment must never
    occur.
    """
    all_embeddings: list[list[float]] = []
    if not texts:
        return all_embeddings

    # One client for all batches so the connection can be reused.
    async with httpx.AsyncClient(timeout=30.0) as client:
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            try:
                resp = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                if len(embeddings) != len(batch):
                    raise ValueError(
                        f"expected {len(batch)} embeddings, got {len(embeddings)}"
                    )
            except Exception as e:
                logger.warning("Batch embedding failed for %d texts: %s", len(batch), e)
                embeddings = [[] for _ in batch]
            all_embeddings.extend(embeddings)

    return all_embeddings


async def _llm_ollama(prompt: str, system_prompt: Optional[str] = None) -> str:
    """Call local Ollama's chat API in JSON mode and return the reply text.

    Best-effort: returns "" on a non-200 status or any exception.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "model": OLLAMA_MODEL,
        "messages": messages,
        "stream": False,
        "format": "json",
        "options": {"num_predict": 512},
        "think": False,
    }

    try:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(f"{OLLAMA_URL}/api/chat", json=payload)
            if resp.status_code != 200:
                logger.error(
                    "Ollama chat failed %d: %s", resp.status_code, resp.text[:300]
                )
                return ""
            data = resp.json()
            return data.get("message", {}).get("content", "")
    except Exception as e:
        logger.warning("Ollama call failed: %s", e)
        return ""


# First flat (non-nested) JSON object inside a string.
_JSON_OBJECT_RE = re.compile(r"\{[^{}]*\}", re.DOTALL)


def _parse_json(text: str) -> dict:
    """Extract a JSON object from LLM response text.

    Tries the whole text first, then the first flat ``{...}`` block within
    it. Always returns a dict (empty if nothing usable was found), so callers
    can safely use ``.get`` even if the model answered with a JSON array or
    scalar.
    """
    # Try direct parse
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Try extracting JSON block
    match = _JSON_OBJECT_RE.search(text)
    if match:
        try:
            parsed = json.loads(match.group())
        except json.JSONDecodeError:
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}