#!/usr/bin/env python3 """Match Fachmann ground-truth measures against the IACE measure library. For each unique measure string in testdata/ground_truth_bremse.json this script computes the best match against the IACE library by combining four signals: 1. Token-Jaccard on a normalized token set (handles word-order + length mismatch) 2. Longest contiguous substring ratio (catches partial copies) 3. Norm-reference overlap (e.g. shared "EN 60204-1 Ziff. 6.2" between GT and library) 4. Length-adjusted SequenceMatcher ratio as a fallback for short fragments The combined score is the maximum of the four signals so that a strong hit on any single dimension lifts the entry out of the gap bucket. The previous version returned 0.40 for matches like "Potentialausgleich zwischen Robodrill / ..." vs M475 because the GT string was 5x longer than the library name; the new score catches these via token-Jaccard and substring. Outputs both a markdown report and a JSON file for programmatic consumption. """ from __future__ import annotations import argparse import json import pathlib import re import sys from difflib import SequenceMatcher ROOT = pathlib.Path(__file__).resolve().parents[1] GT = ROOT / "ai-compliance-sdk/internal/iace/testdata/ground_truth_bremse.json" MEASURE_DIR = ROOT / "ai-compliance-sdk/internal/iace" # Lightweight Go struct line parser. Each library entry is a single line. ENTRY_RE = re.compile(r'\{ID:\s*"(M\d+)"[^}]*\}', re.DOTALL) FIELD_RE = re.compile(r'(\w+):\s*"([^"]*)"') LIST_RE = re.compile(r'(\w+):\s*\[\]string\{([^}]*)\}') # Tokens that are too generic to count for similarity. STOPWORDS = { "der", "die", "das", "den", "dem", "des", "ein", "eine", "einer", "einem", "und", "oder", "an", "auf", "in", "im", "mit", "fuer", "fur", "zu", "zur", "zum", "bei", "von", "vom", "ist", "sind", "wird", "werden", "durch", "nicht", "kein", "keine", "alle", "alles", "auch", "nur", "ueber", "ueb", "the", "and", "or", "of", "to", "in", "on", "by", "for", "is", "are", "with", } # Section-header markers in GT that are not real measures. HEADER_PATTERNS = [ re.compile(r"^[A-ZÄÖÜ /\-_0-9]+$"), # ALLCAPS section title re.compile(r"^>\s"), # quoted header (e.g. "> Im Folgenden:") re.compile(r"^\d+\.\s+[A-ZÄÖÜ]"), # "1. Foo" enumerated header re.compile(r"^==>"), # "==> keine ..." comment re.compile(r"^Kein\s+st(ä|ae)ndiger\s+Arbeitsplatz"), ] NORM_RE = re.compile(r"(?:EN|IEC|ISO|DIN|TRBS|TRGS|ASR|DGUV|OSHA|VDE|VDI)[\s\-]?[A-Z]?\d[\w\-./]*", re.IGNORECASE) def norm(s: str) -> str: s = s.lower() s = (s.replace("ä", "ae").replace("ö", "oe").replace("ü", "ue") .replace("ß", "ss").replace("é", "e")) s = re.sub(r"[^a-z0-9 ]+", " ", s) s = re.sub(r"\s+", " ", s).strip() return s def tokens(s: str) -> set[str]: return {t for t in norm(s).split() if len(t) > 2 and t not in STOPWORDS} def is_header(s: str) -> bool: s = s.strip() if not s or len(s) < 8 or s.endswith(":"): return True return any(p.search(s) for p in HEADER_PATTERNS) def norm_refs(s: str) -> set[str]: return {re.sub(r"\s+", " ", m.group(0).lower().strip()) for m in NORM_RE.finditer(s)} def load_library() -> list[dict]: out: list[dict] = [] for f in sorted(MEASURE_DIR.glob("measures_library*.go")): text = f.read_text(encoding="utf-8") for m in ENTRY_RE.finditer(text): blob = m.group(0) fields = dict(FIELD_RE.findall(blob)) lists = {k: [s.strip().strip('"') for s in v.split('",') if s.strip()] for k, v in LIST_RE.findall(blob)} if "ID" not in fields: continue examples = lists.get("Examples", []) norm_list = lists.get("NormReferences", []) haystack = " ".join([fields.get("Name", ""), fields.get("Description", ""), *examples, *norm_list]) out.append({ "ID": fields["ID"], "Name": fields.get("Name", ""), "Description": fields.get("Description", ""), "HazardCategory": fields.get("HazardCategory", ""), "Examples": examples, "NormReferences": norm_list, "file": f.name, "_haystack_norm": norm(haystack), "_tokens": tokens(haystack), "_norm_refs": norm_refs(" ".join(norm_list)), }) return out def best_match(needle: str, lib: list[dict]) -> tuple[float, dict | None, dict]: n_str = norm(needle) n_tokens = tokens(needle) n_refs = norm_refs(needle) if not n_str: return 0.0, None, {} best: tuple[float, dict | None, dict] = (0.0, None, {}) for entry in lib: # Signal 1: Jaccard on normalized tokens. if n_tokens and entry["_tokens"]: inter = len(n_tokens & entry["_tokens"]) union = len(n_tokens | entry["_tokens"]) jaccard = inter / union if union else 0.0 # Token-recall bonus: if all GT tokens appear in library haystack # the GT string is "covered" even if library is much broader. recall = inter / len(n_tokens) if n_tokens else 0.0 else: jaccard = recall = 0.0 # Signal 2: substring containment ratio (catches verbatim fragments). contain_ratio = 0.0 if len(n_str) >= 12: sm = SequenceMatcher(None, n_str, entry["_haystack_norm"]) mb = sm.find_longest_match(0, len(n_str), 0, len(entry["_haystack_norm"])) contain_ratio = mb.size / len(n_str) if len(n_str) else 0.0 # Signal 3: norm-reference overlap. norm_overlap = 0.0 if n_refs and entry["_norm_refs"]: norm_overlap = len(n_refs & entry["_norm_refs"]) / len(n_refs) # Signal 4: classic SequenceMatcher ratio (length-tolerant via shorter side). seq_ratio = SequenceMatcher(None, n_str, entry["_haystack_norm"]).ratio() score = max(jaccard, recall * 0.9, contain_ratio, norm_overlap, seq_ratio) if score > best[0]: best = (score, entry, { "jaccard": round(jaccard, 3), "token_recall": round(recall, 3), "substring": round(contain_ratio, 3), "norm_overlap": round(norm_overlap, 3), "seq_ratio": round(seq_ratio, 3), }) return best def collect_gt_measures(gt_path: pathlib.Path) -> dict[str, list[str]]: """Return {measure_string -> [entry_nr,...]} (deduped per nr), filtered.""" data = json.loads(gt_path.read_text(encoding="utf-8")) bucket: dict[str, set[str]] = {} for e in data["entries"]: for m in e.get("measures", []): m = m.strip() if is_header(m): continue bucket.setdefault(m, set()).add(e["nr"]) return {k: sorted(v) for k, v in bucket.items()} def collect_gt_by_group(gt_path: pathlib.Path) -> dict[str, list[tuple[str, str]]]: """Return {hazard_group -> [(nr, measure), ...]}.""" data = json.loads(gt_path.read_text(encoding="utf-8")) out: dict[str, list[tuple[str, str]]] = {} for e in data["entries"]: group = e.get("hazard_group", "Unknown") for m in e.get("measures", []): m = m.strip() if is_header(m): continue out.setdefault(group, []).append((e["nr"], m)) return out def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--json", type=pathlib.Path, default=None, help="Write JSON output to this path (in addition to stdout markdown)") ap.add_argument("--gap-threshold", type=float, default=0.45) ap.add_argument("--weak-threshold", type=float, default=0.65) args = ap.parse_args() lib = load_library() print(f"Library entries parsed: {len(lib)}", file=sys.stderr) gt = collect_gt_measures(GT) print(f"GT measure strings (filtered): {len(gt)}", file=sys.stderr) rows: list[dict] = [] for measure, nrs in gt.items(): score, entry, signals = best_match(measure, lib) rows.append({ "score": round(score, 3), "gt_nrs": nrs, "gt_measure": measure, "match_id": entry["ID"] if entry else None, "match_name": entry["Name"] if entry else None, "match_category": entry["HazardCategory"] if entry else None, "signals": signals, }) rows.sort(key=lambda r: r["score"]) GAP, WEAK = args.gap_threshold, args.weak_threshold n_gap = sum(1 for r in rows if r["score"] < GAP) n_weak = sum(1 for r in rows if GAP <= r["score"] < WEAK) n_ok = sum(1 for r in rows if r["score"] >= WEAK) if args.json: args.json.write_text(json.dumps({ "total": len(rows), "gap_count": n_gap, "weak_count": n_weak, "ok_count": n_ok, "gap_threshold": GAP, "weak_threshold": WEAK, "rows": rows, }, ensure_ascii=False, indent=2), encoding="utf-8") print(f"JSON written: {args.json}", file=sys.stderr) print(f"# GT-Measure-Coverage Report\n") print(f"- Total filtered GT measures: **{len(rows)}**") print(f"- Gaps (score < {GAP}): **{n_gap}**") print(f"- Weak matches ({GAP} <= score < {WEAK}): **{n_weak}**") print(f"- Strong matches (score >= {WEAK}): **{n_ok}**\n") def section(title: str, lo: float, hi: float) -> None: print(f"## {title}\n") print("| Score | GT-Nr. | Best Match | Signals | GT Massnahme |") print("|-------|--------|-----------|---------|--------------|") for r in rows: if not (lo <= r["score"] < hi): continue mid = f"{r['match_id']} — {r['match_name']}" if r["match_id"] else "—" m_short = r["gt_measure"].replace("|", "\\|") if len(m_short) > 120: m_short = m_short[:117] + "..." sig = r["signals"] sigstr = f"j={sig.get('jaccard',0)} sub={sig.get('substring',0)} n={sig.get('norm_overlap',0)}" print(f"| {r['score']:.2f} | {','.join(r['gt_nrs'])} | {mid} | {sigstr} | {m_short} |") print() section(f"Gaps (score < {GAP})", 0.0, GAP) section(f"Weak Matches ({GAP} - {WEAK})", GAP, WEAK) section(f"Strong Matches (>= {WEAK})", WEAK, 1.01) return 0 if __name__ == "__main__": sys.exit(main())