"""
MC Embedding Match — semantic fallback for the regex-based doc_check.

The Sonnet classifier filtered MCs to `check_type='text'` (matchable
against doc text). But the regex matcher is still too strict — BMW writes
"Speicherdauer 2 Jahre", the MC pattern expects "\\d+\\s*(Tag|Jahr)".

We catch these via BGE-M3 embeddings + cosine similarity:
  1. Embed the MC's check_question (once, cached in sidecar)
  2. Embed the doc text in 50-word chunks
  3. cosine(MC, max(chunks)) ≥ threshold → MC passes via "semantic"

This recovers ~50% of failed MCs at BMW-scale (estimated).

Embeddings come from bp-core-embedding-service (BGE-M3, 1024-dim,
multilingual). Sidecar SQLite stores 1024 × 4 bytes = 4KB per MC.
"""

from __future__ import annotations

import contextlib
import logging
import math
import os
import re
import sqlite3
import struct
from typing import Iterable

import httpx

logger = logging.getLogger(__name__)

EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://embedding-service:8087")
SIDECAR_DB = os.getenv("MC_CLASS_DB", "/data/mc_classification.db")
DIM = 1024  # BGE-M3
SIMILARITY_THRESHOLD = float(os.getenv("MC_EMBEDDING_THRESHOLD", "0.55"))
CHUNK_SIZE_WORDS = 50
CHUNK_STRIDE = 30  # overlap so multi-sentence MCs aren't cut

# Short mandatory fields (Impressum: HRB no., VAT ID, postal address) get
# drowned out in 50-word chunks. We ADDITIONALLY scan the doc with finer
# 15-word windows and a looser threshold for Impressum/AVV doc types.
SHORT_FIELD_CHUNK_WORDS = 15
SHORT_FIELD_STRIDE = 8
SHORT_FIELD_THRESHOLD = float(os.getenv("MC_EMBEDDING_THRESHOLD_SHORT", "0.50"))
SHORT_FIELD_DOC_TYPES = {"impressum", "avv"}

# Doc-type-specific threshold overrides — calibrated against the BMW v7
# run: at 0.55, DSE+Cookie sat systematically at 93% (over-firing because
# 8000-word texts vaguely match everything). 0.60 brings that down to the
# real ~80%. Impressum has only 6 real MCs + short-field rescue → 0.50 ok.
# Doc types not listed here fall back to SIMILARITY_THRESHOLD.
THRESHOLD_OVERRIDE = {
    "impressum": 0.50,
    "avv": 0.55,
    "dse": 0.60,
    "cookie": 0.60,
    "widerruf": 0.58,
    "loeschkonzept": 0.55,
    "dsfa": 0.55,
}

# Upper bound on ids per "IN (...)" query so a large MC set cannot exceed
# SQLite's bound-parameter limit.
_SQLITE_IN_BATCH = 500


def _ensure_schema() -> None:
    """Add the ``embedding`` BLOB column to mc_classification if missing.

    Best-effort: any failure is logged as a warning and swallowed.
    """
    try:
        # sqlite3's own context manager only commits / rolls back; closing()
        # is what actually releases the connection.
        with contextlib.closing(sqlite3.connect(SIDECAR_DB)) as conn, conn:
            cols = [r[1] for r in conn.execute("PRAGMA table_info(mc_classification)")]
            if "embedding" not in cols:
                conn.execute("ALTER TABLE mc_classification ADD COLUMN embedding BLOB")
                logger.info("Added embedding column to mc_classification")
    except Exception as e:
        logger.warning("Embedding schema migration skipped: %s", e)


def _vec_to_blob(v: list[float]) -> bytes:
    """Pack a float vector into a blob of native 4-byte floats."""
    return struct.pack(f"{len(v)}f", *v)


def _blob_to_vec(b: bytes) -> list[float]:
    """Unpack a blob written by ``_vec_to_blob`` back into a float list.

    An empty/None blob yields ``[]``. Trailing bytes that do not form a
    complete 4-byte float are ignored instead of raising ``struct.error``;
    the resulting wrong-length vector then scores 0.0 in ``_cosine``.
    """
    if not b:
        return []
    n = len(b) // 4
    return list(struct.unpack(f"{n}f", b[:n * 4]))


EMBED_BATCH_SIZE = 32


async def _embed_texts(texts: list[str], timeout: float = 120.0) -> list[list[float]]:
    """Call the central embedding-service in batches; returns one vector per input.

    BGE-M3 hangs / times out on >100 inputs at once on a CPU-only host.
    We chunk into 32er batches and collect.

    The result always has exactly ``len(texts)`` entries. Inputs whose
    sub-batch failed (HTTP error, undecodable JSON, or a response carrying
    the wrong number of vectors) are represented by ``[]`` so callers can
    align results with inputs by index.
    """
    if not texts:
        return []
    out: list[list[float]] = []
    async with httpx.AsyncClient(timeout=timeout) as client:
        for i in range(0, len(texts), EMBED_BATCH_SIZE):
            batch = texts[i:i + EMBED_BATCH_SIZE]
            vecs = None
            try:
                r = await client.post(
                    f"{EMBEDDING_URL}/embed",
                    json={"texts": batch},
                )
                r.raise_for_status()
                payload = r.json()
                if isinstance(payload, dict):
                    vecs = payload.get("embeddings")
            except (httpx.HTTPError, ValueError) as e:
                # ValueError covers a body that is not valid JSON.
                logger.warning("Embed sub-batch [%d-%d] failed: %s %s",
                               i, i + len(batch), type(e).__name__, e)
            else:
                if not isinstance(vecs, list) or len(vecs) != len(batch):
                    # A short or missing list would shift every later vector
                    # onto the wrong input — discard the whole sub-batch.
                    logger.warning(
                        "Embed sub-batch [%d-%d] returned %d vectors for %d inputs; discarding",
                        i, i + len(batch),
                        len(vecs) if isinstance(vecs, list) else 0, len(batch),
                    )
                    vecs = None
            # Pad with empty vectors so caller can still align by index
            out.extend(vecs if vecs is not None else [[] for _ in batch])
    return out


async def ensure_mc_embeddings(batch_size: int = 64, force: bool = False) -> int:
    """One-shot: embed every text-MC missing an embedding.

    Embeds the title + (rough) check_question for each MC to give BGE-M3
    enough context. Title alone is too terse for the model to discriminate
    against full-paragraph doc text.

    Idempotent — only fills NULL rows unless force=True. Safe to call on
    every run.

    Args:
        batch_size: Number of MCs embedded and written per round; values
            below 1 are treated as 1.
        force: Re-embed every text-MC, not only rows with a NULL embedding.

    Returns:
        Number of rows for which a valid DIM-sized embedding was stored.
    """
    _ensure_schema()

    # Pull check_question from the PG source table once per call (needs
    # context that's not in the sidecar). On failure we fall back to the
    # sidecar title alone.
    pg_lookup: dict = {}
    try:
        import psycopg2
        # closing() guarantees the connection is released even if the query raises.
        with contextlib.closing(psycopg2.connect(os.environ["DATABASE_URL"])) as pg:
            with pg.cursor() as cur:
                cur.execute("SELECT control_id, doc_type, title, check_question "
                            "FROM compliance.doc_check_controls")
                pg_rows = cur.fetchall()
        pg_lookup = {(r[0], r[1] or ""): (r[2] or "", r[3] or "") for r in pg_rows}
    except Exception as e:
        logger.warning("ensure_mc_embeddings PG load failed: %s", e)

    try:
        with contextlib.closing(sqlite3.connect(SIDECAR_DB)) as conn:
            where = ("WHERE check_type = 'text'"
                     + ("" if force else " AND embedding IS NULL"))
            rows = conn.execute(
                f"SELECT control_id, doc_type, title FROM mc_classification {where}"
            ).fetchall()
    except Exception as e:
        logger.warning("ensure_mc_embeddings query failed: %s", e)
        return 0

    if not rows:
        return 0

    logger.info("Embedding %d text-MCs (force=%s) via %s ...",
                len(rows), force, EMBEDDING_URL)

    step = max(1, batch_size)
    done = 0
    for i in range(0, len(rows), step):
        batch = rows[i:i + step]
        # Compose "title — check_question" so the embedding captures both
        # the topic (title) and the concrete check phrasing (question).
        # That helps BMW's actual policy language land in the same vector
        # neighbourhood as our control wording.
        texts: list[str] = []
        for cid, dt, t in batch:
            title_text, question = pg_lookup.get((cid, dt or ""), (t or "", ""))
            combined = f"{title_text}. {question}".strip()
            texts.append(combined[:600])
        try:
            embs = await _embed_texts(texts)
        except Exception as e:
            logger.warning("Embed batch failed (i=%d): %s", i, e)
            continue

        # Keep only usable vectors; failed inputs come back as [].
        updates = [
            (_vec_to_blob(vec), cid, dt)
            for (cid, dt, _t), vec in zip(batch, embs)
            if vec and len(vec) == DIM
        ]
        if updates:
            with contextlib.closing(sqlite3.connect(SIDECAR_DB)) as conn, conn:
                # "IS ?" instead of "= ?" so rows whose doc_type is NULL are
                # matched too; for non-NULL values it behaves like "=".
                conn.executemany(
                    "UPDATE mc_classification SET embedding = ? "
                    "WHERE control_id = ? AND doc_type IS ?",
                    updates,
                )
        # Count what was actually written, not what was attempted.
        done += len(updates)
        logger.info("ensure_mc_embeddings: filled %d/%d", done, len(rows))
    return done


def _chunk_text(text: str, size: int = CHUNK_SIZE_WORDS,
                stride: int = CHUNK_STRIDE) -> list[str]:
    """Sliding-window chunking — overlap helps catch MCs that span 2 sentences.

    Splits ``text`` on whitespace and returns windows of up to ``size``
    words, advancing ``stride`` words each step. A text of at most ``size``
    words yields one chunk; empty/None text yields ``[]``.

    Raises:
        ValueError: if ``size`` or ``stride`` is below 1 (a non-positive
            stride would never terminate).
    """
    if size < 1 or stride < 1:
        raise ValueError("size and stride must be >= 1")
    words = re.findall(r"\S+", text or "")
    if len(words) <= size:
        return [" ".join(words)] if words else []
    # NOTE: trailing windows may be shorter than `size`; kept as-is because
    # the thresholds were calibrated against this chunking.
    return [" ".join(words[i:i + size]) for i in range(0, len(words), stride)]


def _cosine(a: list[float], b: list[float]) -> float:
    """Plain Python cosine — fast enough for our scale, no numpy import.

    Returns 0.0 when either vector is empty, the lengths differ, or either
    vector has zero norm.
    """
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0 or nb == 0:
        return 0.0
    return dot / (na * nb)


def _load_mc_embeddings(control_ids: Iterable[str],
                        doc_type: str | None) -> dict[str, list[float]]:
    """Read cached embeddings of text-MCs from the sidecar, keyed by control_id.

    Only rows with ``check_type='text'`` and a non-NULL embedding are read;
    when ``doc_type`` is truthy the rows are additionally restricted to it.
    """
    ids = list(control_ids)
    found: dict[str, list[float]] = {}
    with contextlib.closing(sqlite3.connect(SIDECAR_DB)) as conn:
        for start in range(0, len(ids), _SQLITE_IN_BATCH):
            part = ids[start:start + _SQLITE_IN_BATCH]
            placeholders = ",".join("?" * len(part))
            q = ("SELECT control_id, embedding FROM mc_classification "
                 f"WHERE control_id IN ({placeholders}) "
                 "AND check_type='text' AND embedding IS NOT NULL")
            params = list(part)
            if doc_type:
                q += " AND doc_type = ?"
                params.append(doc_type)
            for cid, blob in conn.execute(q, params):
                found[cid] = _blob_to_vec(blob)
    return found


def _ids_above_threshold(mc_vecs: dict[str, list[float]],
                         chunk_vecs: list[list[float]],
                         threshold: float) -> set[str]:
    """Return ids whose best cosine similarity over all chunks is >= threshold."""
    return {
        cid for cid, mc_vec in mc_vecs.items()
        if max((_cosine(mc_vec, cv) for cv in chunk_vecs), default=0.0) >= threshold
    }


async def embedding_match(
    doc_text: str,
    mc_records: Iterable[dict],
    doc_type: str | None = None,
    threshold: float | None = None,
) -> set[str]:
    """Return the subset of MC control_ids that semantically match doc_text.

    For Impressum/AVV-types we ADDITIONALLY scan the doc with smaller
    15-word windows and a looser threshold so that short mandatory fields
    (HRB, USt-IdNr, postal address) land in their own chunk and aren't
    diluted by 50-word neighbourhoods of unrelated text.

    Args:
        doc_text: Full document text to match against.
        mc_records: Dicts carrying a ``control_id`` key; records without
            one are ignored.
        doc_type: Optional doc type; restricts the sidecar lookup and picks
            the per-type threshold from THRESHOLD_OVERRIDE.
        threshold: Explicit main-pass threshold. ``None`` selects the
            doc-type override or SIMILARITY_THRESHOLD; any number,
            including 0.0, is used as given.

    Returns:
        Set of matched control_ids. Empty on missing input, sidecar lookup
        failure, or when no doc chunk could be embedded.
    """
    if not doc_text or not mc_records:
        return set()
    cid_set = {c.get("control_id") for c in mc_records if c.get("control_id")}
    if not cid_set:
        return set()

    try:
        mc_embeddings = _load_mc_embeddings(cid_set, doc_type)
    except Exception as e:
        logger.warning("embedding lookup failed: %s", e)
        return set()
    if not mc_embeddings:
        return set()

    # "is not None" rather than "or": an explicit 0.0 must not silently fall
    # back to the default.
    if threshold is not None:
        effective_threshold = threshold
    else:
        effective_threshold = THRESHOLD_OVERRIDE.get(
            (doc_type or "").lower(), SIMILARITY_THRESHOLD)

    chunks = _chunk_text(doc_text)
    if not chunks:
        return set()

    try:
        chunk_vecs = await _embed_texts(chunks)
    except Exception as e:
        logger.warning("doc chunk embedding failed: %s %s",
                       type(e).__name__, str(e) or "(empty msg)", exc_info=True)
        return set()
    # Filter empty vectors (failed sub-batches return [] placeholders)
    chunk_vecs = [v for v in chunk_vecs if v and len(v) == DIM]
    if not chunk_vecs:
        logger.warning("doc chunk embedding: no usable vectors (all batches failed)")
        return set()

    matched = _ids_above_threshold(mc_embeddings, chunk_vecs, effective_threshold)

    # Short-field rescue pass for Impressum-type docs: small windows +
    # looser threshold catch one-line mandatory fields that 50-word chunks
    # dilute (HRB-Nr, USt-IdNr, postal address). Only runs for MCs not
    # yet matched in the main pass.
    if (doc_type or "").lower() in SHORT_FIELD_DOC_TYPES:
        unmatched = {cid: vec for cid, vec in mc_embeddings.items()
                     if cid not in matched}
        if unmatched:
            short_chunks = _chunk_text(doc_text,
                                       size=SHORT_FIELD_CHUNK_WORDS,
                                       stride=SHORT_FIELD_STRIDE)
            try:
                short_vecs = await _embed_texts(short_chunks)
            except Exception as e:
                logger.warning("short-chunk embedding failed: %s", e)
                short_vecs = []
            short_vecs = [v for v in short_vecs if v and len(v) == DIM]
            if short_vecs:
                rescued = _ids_above_threshold(unmatched, short_vecs,
                                               SHORT_FIELD_THRESHOLD)
                matched |= rescued
                if rescued:
                    logger.info(
                        "embedding short-field rescue for %s: +%d MCs (threshold %.2f, %d chunks)",
                        doc_type, len(rescued), SHORT_FIELD_THRESHOLD, len(short_chunks),
                    )

    logger.info(
        "embedding match for %s: %d/%d MCs passed semantic threshold (main=%.2f)",
        doc_type or "?", len(matched), len(mc_embeddings), effective_threshold,
    )
    return matched