From 9437e029d0cc502b7af9c7be5ba76b6558c592f7 Mon Sep 17 00:00:00 2001 From: Benjamin Admin Date: Sun, 3 May 2026 23:14:06 +0200 Subject: [PATCH] =?UTF-8?q?feat(pipeline):=20F1=20regulation=20registry=20?= =?UTF-8?q?=E2=80=94=20DB-backed=20license/source-type=20lookup?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Migrates REGULATION_LICENSE_MAP (135 entries) and SOURCE_REGULATION_CLASSIFICATION (58 entries) from hardcoded Python dicts to compliance.regulation_registry table. - SQL migration: 002_regulation_registry.sql (table + indexes + trigger) - Migration script: f1_migrate_regulation_registry.py (162 rows, --dry-run) - RegulationRegistry cache: 5min TTL, prefix fallback, graceful degradation - control_generator._classify_regulation() delegates to DB with dict fallback - source_type_classification.classify_source_regulation() delegates to DB - 34 new tests (lookup, cache, degradation, migration data consistency) - 421 total tests pass, 0 regressions Co-Authored-By: Claude Opus 4.6 (1M context) --- .../data/source_type_classification.py | 21 +- .../migrations/002_regulation_registry.sql | 72 +++++ .../scripts/f1_migrate_regulation_registry.py | 247 +++++++++++++++ control-pipeline/services/anchor_finder.py | 3 - .../services/control_generator.py | 32 +- .../services/regulation_registry.py | 220 ++++++++++++++ .../tests/test_regulation_registry.py | 285 ++++++++++++++++++ 7 files changed, 850 insertions(+), 30 deletions(-) create mode 100644 control-pipeline/migrations/002_regulation_registry.sql create mode 100644 control-pipeline/scripts/f1_migrate_regulation_registry.py create mode 100644 control-pipeline/services/regulation_registry.py create mode 100644 control-pipeline/tests/test_regulation_registry.py diff --git a/control-pipeline/data/source_type_classification.py b/control-pipeline/data/source_type_classification.py index fbfbe25..8ce0808 100644 --- a/control-pipeline/data/source_type_classification.py +++ b/control-pipeline/data/source_type_classification.py @@ -165,21 +165,29 @@ def classify_source_regulation(source_regulation: str) -> str: """ Klassifiziert eine source_regulation als law, guideline oder framework. - Verwendet exaktes Matching gegen die Map. Bei unbekannten Quellen - wird anhand von Schluesselwoertern geraten, Fallback ist 'framework' - (konservativstes Ergebnis). + Delegates to DB-backed RegulationRegistry (with 5min cache). + Falls back to SOURCE_REGULATION_CLASSIFICATION dict + heuristic + if DB is unavailable. """ if not source_regulation: return SOURCE_TYPE_FRAMEWORK - # Exaktes Match + # Try DB-backed registry first + try: + from services.regulation_registry import classify_source_regulation as _db_classify + result = _db_classify(source_regulation) + if result: + return result + except Exception: + pass + + # Fallback: local dict if source_regulation in SOURCE_REGULATION_CLASSIFICATION: return SOURCE_REGULATION_CLASSIFICATION[source_regulation] # Heuristik fuer unbekannte Quellen lower = source_regulation.lower() - # Gesetze erkennen law_indicators = [ "verordnung", "richtlinie", "gesetz", "directive", "regulation", "(eu)", "(eg)", "act", "ley", "loi", "törvény", "código", @@ -187,19 +195,16 @@ def classify_source_regulation(source_regulation: str) -> str: if any(ind in lower for ind in law_indicators): return SOURCE_TYPE_LAW - # Leitlinien erkennen guideline_indicators = [ "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung", ] if any(ind in lower for ind in guideline_indicators): return SOURCE_TYPE_GUIDELINE - # Frameworks erkennen framework_indicators = [ "enisa", "nist", "owasp", "oecd", "cisa", "framework", "iso", ] if any(ind in lower for ind in framework_indicators): return SOURCE_TYPE_FRAMEWORK - # Konservativ: unbekannt = framework (geringste Verbindlichkeit) return SOURCE_TYPE_FRAMEWORK diff --git a/control-pipeline/migrations/002_regulation_registry.sql b/control-pipeline/migrations/002_regulation_registry.sql new file mode 100644 index 0000000..a5c1544 --- /dev/null +++ b/control-pipeline/migrations/002_regulation_registry.sql @@ -0,0 +1,72 @@ +-- Migration 002: Regulation Registry (Block F1) +-- Schema: compliance +-- Run: ssh macmini "docker exec -i bp-core-postgres psql -U breakpilot -d breakpilot_db" < control-pipeline/migrations/002_regulation_registry.sql + +SET search_path TO compliance, public; + +-- ======================================== +-- regulation_registry +-- ======================================== +-- Central registry for all regulations, laws, guidelines, and frameworks +-- referenced by the control pipeline. Replaces hardcoded Python dicts +-- (REGULATION_LICENSE_MAP, SOURCE_REGULATION_CLASSIFICATION). + +CREATE TABLE IF NOT EXISTS regulation_registry ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + + -- regulation_id: machine key (e.g. "eu_2016_679", "nist_sp_800_53") + regulation_id VARCHAR(100) UNIQUE NOT NULL, + + -- Display names + regulation_name_de TEXT, + regulation_name_en TEXT, + regulation_short VARCHAR(50), + + -- License classification (3-rule system) + license_rule INTEGER NOT NULL DEFAULT 1 + CHECK (license_rule IN (1, 2, 3)), + license_type VARCHAR(50), -- EU_LAW, DE_LAW, CC-BY-SA-4.0, etc. + attribution TEXT, -- Required for Rule 2 (CC-BY) + + -- Source classification + source_type VARCHAR(20) NOT NULL DEFAULT 'law' + CHECK (source_type IN ('law', 'guideline', 'standard', 'framework', 'restricted')), + + -- Metadata + jurisdiction VARCHAR(10), -- DE, EU, AT, CH, US, FR, ES, NL, IT, HU, INT + category VARCHAR(50), + celex VARCHAR(30), -- EU CELEX number if applicable + url TEXT, + + -- Lifecycle + status VARCHAR(20) NOT NULL DEFAULT 'active' + CHECK (status IN ('active', 'needs_review', 'deprecated')), + + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Indexes +CREATE INDEX IF NOT EXISTS idx_reg_registry_status + ON regulation_registry(status); +CREATE INDEX IF NOT EXISTS idx_reg_registry_jurisdiction + ON regulation_registry(jurisdiction); +CREATE INDEX IF NOT EXISTS idx_reg_registry_source_type + ON regulation_registry(source_type); +CREATE INDEX IF NOT EXISTS idx_reg_registry_license_rule + ON regulation_registry(license_rule); + +-- Updated-at trigger +CREATE OR REPLACE FUNCTION update_regulation_registry_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS trg_regulation_registry_updated_at ON regulation_registry; +CREATE TRIGGER trg_regulation_registry_updated_at + BEFORE UPDATE ON regulation_registry + FOR EACH ROW + EXECUTE FUNCTION update_regulation_registry_updated_at(); diff --git a/control-pipeline/scripts/f1_migrate_regulation_registry.py b/control-pipeline/scripts/f1_migrate_regulation_registry.py new file mode 100644 index 0000000..6667fd0 --- /dev/null +++ b/control-pipeline/scripts/f1_migrate_regulation_registry.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" +F1 Migration: Populate regulation_registry from hardcoded Python dicts. + +Sources: + - REGULATION_LICENSE_MAP (control_generator.py) — 135 entries keyed by regulation_id + - SOURCE_REGULATION_CLASSIFICATION (source_type_classification.py) — 58 entries keyed by name + +Usage: + # Dry run (prints SQL, no DB write): + python3 scripts/f1_migrate_regulation_registry.py --dry-run + + # Against Mac Mini: + python3 scripts/f1_migrate_regulation_registry.py --db-host macmini + + # Against local Docker: + python3 scripts/f1_migrate_regulation_registry.py --db-host localhost +""" + +import argparse +import sys +from pathlib import Path + +# Add parent so we can import from services/data +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from services.control_generator import REGULATION_LICENSE_MAP, _RULE2_PREFIXES, _RULE3_PREFIXES # noqa: E402 +from data.source_type_classification import SOURCE_REGULATION_CLASSIFICATION # noqa: E402 + +# Derive jurisdiction from license_type +_LICENSE_TO_JURISDICTION = { + "EU_LAW": "EU", + "EU_PUBLIC": "EU", + "DE_LAW": "DE", + "DE_PUBLIC": "DE", + "AT_LAW": "AT", + "CH_LAW": "CH", + "FR_LAW": "FR", + "ES_LAW": "ES", + "NL_LAW": "NL", + "IT_LAW": "IT", + "HU_LAW": "HU", + "NIST_PUBLIC_DOMAIN": "US", + "US_GOV_PUBLIC": "US", + "CC-BY-SA-4.0": "INT", + "CC-BY-4.0": "INT", + "OECD_PUBLIC": "INT", +} + + +def _derive_jurisdiction(license_type: str) -> str: + """Map license_type to jurisdiction code.""" + return _LICENSE_TO_JURISDICTION.get(license_type, "INT") + + +def build_rows() -> list[dict]: + """Merge REGULATION_LICENSE_MAP + SOURCE_REGULATION_CLASSIFICATION into rows.""" + rows = [] + # Track names we've seen (for dedup against SOURCE_REGULATION_CLASSIFICATION) + seen_names: set[str] = set() + + # 1) Primary source: REGULATION_LICENSE_MAP (has regulation_id as key) + for reg_id, info in REGULATION_LICENSE_MAP.items(): + name = info.get("name", reg_id) + seen_names.add(name) + + rows.append({ + "regulation_id": reg_id.lower().strip(), + "regulation_name_de": name, + "license_rule": info["rule"], + "license_type": info.get("license", ""), + "attribution": info.get("attribution"), + "source_type": info.get("source_type", "law"), + "jurisdiction": _derive_jurisdiction(info.get("license", "")), + "status": "active", + }) + + # 2) Secondary: SOURCE_REGULATION_CLASSIFICATION entries not already covered + # These are keyed by name, not by regulation_id. We create synthetic IDs. + for name, source_type in SOURCE_REGULATION_CLASSIFICATION.items(): + if name in seen_names: + continue + # Generate a regulation_id from the name + synthetic_id = ( + name.lower() + .replace(" ", "_") + .replace("(", "") + .replace(")", "") + .replace("/", "_") + .replace("-", "_") + .replace(".", "") + .replace(",", "") + .replace("ä", "ae") + .replace("ö", "oe") + .replace("ü", "ue") + .replace("á", "a") + .replace("é", "e") + .replace("ó", "o") + .strip("_") + )[:100] + + # Guess jurisdiction from name content + jurisdiction = "INT" + name_lower = name.lower() + if any(x in name_lower for x in ["edpb", "edps", "(eu)", "eu ", "wp2"]): + jurisdiction = "EU" + elif any(x in name_lower for x in ["bsi", "bdsg", "bundes", "gwg"]): + jurisdiction = "DE" + elif "nist" in name_lower or "cisa" in name_lower: + jurisdiction = "US" + elif "österreich" in name_lower: + jurisdiction = "AT" + elif "schweiz" in name_lower: + jurisdiction = "CH" + elif "spanien" in name_lower: + jurisdiction = "ES" + elif "frankreich" in name_lower: + jurisdiction = "FR" + elif "ungarn" in name_lower: + jurisdiction = "HU" + + # Map source_type_classification's "framework" to our "standard" + # (source_type_classification uses law/guideline/framework) + mapped_source_type = source_type + if source_type == "framework": + mapped_source_type = "standard" + + rows.append({ + "regulation_id": synthetic_id, + "regulation_name_de": name, + "license_rule": 1, # default: conservative + "license_type": "", + "attribution": None, + "source_type": mapped_source_type, + "jurisdiction": jurisdiction, + "status": "needs_review", # needs manual review since we guessed + }) + + return rows + + +def generate_sql(rows: list[dict]) -> str: + """Generate INSERT SQL for all rows.""" + lines = [ + "SET search_path TO compliance, public;", + "", + "-- Auto-generated by f1_migrate_regulation_registry.py", + f"-- {len(rows)} rows total", + "", + ] + + for row in rows: + attr = f"'{row['attribution']}'" if row["attribution"] else "NULL" + lines.append( + f"INSERT INTO regulation_registry " + f"(regulation_id, regulation_name_de, license_rule, license_type, " + f"attribution, source_type, jurisdiction, status) " + f"VALUES (" + f"'{row['regulation_id']}', " + f"'{_escape_sql(row['regulation_name_de'])}', " + f"{row['license_rule']}, " + f"'{row['license_type']}', " + f"{attr}, " + f"'{row['source_type']}', " + f"'{row['jurisdiction']}', " + f"'{row['status']}'" + f") ON CONFLICT (regulation_id) DO UPDATE SET " + f"regulation_name_de = EXCLUDED.regulation_name_de, " + f"license_rule = EXCLUDED.license_rule, " + f"license_type = EXCLUDED.license_type, " + f"attribution = EXCLUDED.attribution, " + f"source_type = EXCLUDED.source_type, " + f"jurisdiction = EXCLUDED.jurisdiction;" + ) + + return "\n".join(lines) + + +def _escape_sql(val: str) -> str: + """Escape single quotes for SQL.""" + return val.replace("'", "''") + + +def insert_via_sqlalchemy(rows: list[dict], db_host: str) -> int: + """Insert rows using SQLAlchemy (same pattern as control-pipeline).""" + from sqlalchemy import create_engine, text + + url = f"postgresql://breakpilot:breakpilot123@{db_host}:5432/breakpilot_db" + engine = create_engine(url) + + inserted = 0 + with engine.connect() as conn: + conn.execute(text("SET search_path TO compliance, public")) + for row in rows: + conn.execute( + text(""" + INSERT INTO regulation_registry + (regulation_id, regulation_name_de, license_rule, license_type, + attribution, source_type, jurisdiction, status) + VALUES + (:regulation_id, :regulation_name_de, :license_rule, :license_type, + :attribution, :source_type, :jurisdiction, :status) + ON CONFLICT (regulation_id) DO UPDATE SET + regulation_name_de = EXCLUDED.regulation_name_de, + license_rule = EXCLUDED.license_rule, + license_type = EXCLUDED.license_type, + attribution = EXCLUDED.attribution, + source_type = EXCLUDED.source_type, + jurisdiction = EXCLUDED.jurisdiction + """), + row, + ) + inserted += 1 + conn.commit() + + return inserted + + +def main(): + parser = argparse.ArgumentParser(description="Migrate regulation registry data") + parser.add_argument("--dry-run", action="store_true", help="Print SQL only") + parser.add_argument("--db-host", default="localhost", help="PostgreSQL host") + args = parser.parse_args() + + rows = build_rows() + print(f"Built {len(rows)} rows from hardcoded dicts") + + # Stats + by_rule = {} + by_status = {} + for r in rows: + by_rule[r["license_rule"]] = by_rule.get(r["license_rule"], 0) + 1 + by_status[r["status"]] = by_status.get(r["status"], 0) + 1 + print(f" By license_rule: {by_rule}") + print(f" By status: {by_status}") + + if args.dry_run: + print("\n--- DRY RUN (SQL output) ---\n") + print(generate_sql(rows)) + return + + inserted = insert_via_sqlalchemy(rows, args.db_host) + print(f"Inserted/updated {inserted} rows into regulation_registry") + + +if __name__ == "__main__": + main() diff --git a/control-pipeline/services/anchor_finder.py b/control-pipeline/services/anchor_finder.py index 8807567..4d2e5c5 100644 --- a/control-pipeline/services/anchor_finder.py +++ b/control-pipeline/services/anchor_finder.py @@ -17,9 +17,6 @@ import httpx from .control_generator import ( GeneratedControl, - REGULATION_LICENSE_MAP, - _RULE2_PREFIXES, - _RULE3_PREFIXES, _classify_regulation, ) diff --git a/control-pipeline/services/control_generator.py b/control-pipeline/services/control_generator.py index c4a5509..9583755 100644 --- a/control-pipeline/services/control_generator.py +++ b/control-pipeline/services/control_generator.py @@ -33,6 +33,7 @@ from sqlalchemy import text from sqlalchemy.orm import Session from .rag_client import ComplianceRAGClient, RAGSearchResult, get_rag_client +from .regulation_registry import get_registry as _get_regulation_registry from .similarity_detector import check_similarity logger = logging.getLogger(__name__) @@ -245,28 +246,21 @@ def _classify_regulation(regulation_code: str) -> dict: Returns dict with keys: license, rule, name, source_type. source_type is one of: law, guideline, standard, restricted. + + Delegates to DB-backed RegulationRegistry (with 5min cache). + Falls back to REGULATION_LICENSE_MAP if DB is unavailable. """ - code = regulation_code.lower().strip() + registry = _get_regulation_registry() + result = registry.classify_regulation(regulation_code) - # Exact match first - if code in REGULATION_LICENSE_MAP: - return REGULATION_LICENSE_MAP[code] + # If registry returned the unknown fallback AND we have a local match, + # prefer the local dict (graceful degradation during migration) + if result.get("license") == "UNKNOWN": + code = regulation_code.lower().strip() + if code in REGULATION_LICENSE_MAP: + return REGULATION_LICENSE_MAP[code] - # Prefix match for Rule 2 (ENISA = standard) - for prefix in _RULE2_PREFIXES: - if code.startswith(prefix): - return {"license": "CC-BY-4.0", "rule": 2, "source_type": "standard", - "name": "ENISA", "attribution": "ENISA, CC BY 4.0"} - - # Prefix match for Rule 3 (BSI/ISO/ETSI = restricted) - for prefix in _RULE3_PREFIXES: - if code.startswith(prefix): - return {"license": f"{prefix.rstrip('_').upper()}_RESTRICTED", "rule": 3, - "source_type": "restricted", "name": "INTERNAL_ONLY"} - - # Unknown → treat as restricted (safe default) - logger.warning("Unknown regulation_code %r — defaulting to Rule 3 (restricted)", code) - return {"license": "UNKNOWN", "rule": 3, "source_type": "restricted", "name": "INTERNAL_ONLY"} + return result # --------------------------------------------------------------------------- diff --git a/control-pipeline/services/regulation_registry.py b/control-pipeline/services/regulation_registry.py new file mode 100644 index 0000000..4057a60 --- /dev/null +++ b/control-pipeline/services/regulation_registry.py @@ -0,0 +1,220 @@ +""" +DB-backed Regulation Registry with in-memory cache. + +Replaces hardcoded REGULATION_LICENSE_MAP and SOURCE_REGULATION_CLASSIFICATION +with a single PostgreSQL table (compliance.regulation_registry). + +Cache TTL: 5 minutes. Thread-safe via simple timestamp check. +Falls back to hardcoded dicts if DB is unavailable (graceful degradation). +""" + +import logging +import time +from typing import Optional + +from sqlalchemy import text +from sqlalchemy.exc import SQLAlchemyError + +from db.session import SessionLocal + +logger = logging.getLogger(__name__) + +_CACHE_TTL_SECONDS = 300 # 5 minutes + +# Prefix-based fallback rules (unchanged from original logic) +_RULE2_PREFIXES = ("enisa_",) +_RULE3_PREFIXES = ("bsi_", "iso_", "etsi_") + +# Fallback for unknown regulations +_UNKNOWN_REGULATION = { + "license": "UNKNOWN", + "rule": 3, + "source_type": "restricted", + "name": "INTERNAL_ONLY", + "attribution": None, +} + + +class RegulationRegistry: + """In-memory cache of the regulation_registry table. + + Provides two lookup modes: + 1. by_code(regulation_id) — replaces REGULATION_LICENSE_MAP[code] + 2. source_type_by_name(name) — replaces SOURCE_REGULATION_CLASSIFICATION[name] + """ + + def __init__(self): + self._by_code: dict[str, dict] = {} + self._by_name: dict[str, str] = {} + self._loaded_at: float = 0.0 + + def _is_stale(self) -> bool: + return (time.monotonic() - self._loaded_at) > _CACHE_TTL_SECONDS + + def _load(self) -> bool: + """Load all rows from regulation_registry into memory.""" + try: + db = SessionLocal() + try: + rows = db.execute( + text(""" + SELECT regulation_id, regulation_name_de, license_rule, + license_type, attribution, source_type, jurisdiction, + status + FROM regulation_registry + WHERE status != 'deprecated' + """) + ).fetchall() + finally: + db.close() + + by_code: dict[str, dict] = {} + by_name: dict[str, str] = {} + + for row in rows: + entry = { + "license": row[3] or "", # license_type + "rule": row[2], # license_rule + "source_type": row[5] or "law", # source_type + "name": row[1] or row[0], # regulation_name_de or regulation_id + "attribution": row[4], # attribution + "jurisdiction": row[6], # jurisdiction + } + by_code[row[0].lower()] = entry + + # Also index by name for source_type lookups + if row[1]: + by_name[row[1]] = row[5] or "law" + + self._by_code = by_code + self._by_name = by_name + self._loaded_at = time.monotonic() + logger.info( + "Regulation registry loaded: %d entries by code, %d by name", + len(by_code), len(by_name), + ) + return True + + except SQLAlchemyError: + logger.warning( + "Failed to load regulation_registry from DB — using stale cache", + exc_info=True, + ) + return False + + def _ensure_loaded(self) -> None: + """Reload cache if stale.""" + if self._is_stale(): + self._load() + + def classify_regulation(self, regulation_code: str) -> dict: + """Look up license info for a regulation_code. + + Returns dict with keys: license, rule, name, source_type, attribution. + Equivalent to the old _classify_regulation() function. + """ + self._ensure_loaded() + code = regulation_code.lower().strip() + + # Exact match from DB + if code in self._by_code: + return self._by_code[code] + + # Prefix match for Rule 2 (ENISA = standard) + for prefix in _RULE2_PREFIXES: + if code.startswith(prefix): + return { + "license": "CC-BY-4.0", + "rule": 2, + "source_type": "standard", + "name": "ENISA", + "attribution": "ENISA, CC BY 4.0", + } + + # Prefix match for Rule 3 (BSI/ISO/ETSI = restricted) + for prefix in _RULE3_PREFIXES: + if code.startswith(prefix): + return { + "license": f"{prefix.rstrip('_').upper()}_RESTRICTED", + "rule": 3, + "source_type": "restricted", + "name": "INTERNAL_ONLY", + "attribution": None, + } + + # Unknown → restricted (safe default) + logger.warning( + "Unknown regulation_code %r — defaulting to Rule 3 (restricted)", code + ) + return dict(_UNKNOWN_REGULATION) + + def source_type_by_name(self, source_regulation: str) -> str: + """Look up source_type by regulation display name. + + Equivalent to old classify_source_regulation(). + Falls back to heuristic for unknown names. + """ + self._ensure_loaded() + + if not source_regulation: + return "framework" + + # Exact match from DB + if source_regulation in self._by_name: + return self._by_name[source_regulation] + + # Heuristic fallback for unknown sources + lower = source_regulation.lower() + + law_indicators = [ + "verordnung", "richtlinie", "gesetz", "directive", "regulation", + "(eu)", "(eg)", "act", "ley", "loi", "törvény", "código", + ] + if any(ind in lower for ind in law_indicators): + return "law" + + guideline_indicators = [ + "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung", + ] + if any(ind in lower for ind in guideline_indicators): + return "guideline" + + framework_indicators = [ + "enisa", "nist", "owasp", "oecd", "cisa", "framework", "iso", + ] + if any(ind in lower for ind in framework_indicators): + return "framework" + + return "framework" + + def get_all(self) -> dict[str, dict]: + """Return all cached entries (by regulation_code).""" + self._ensure_loaded() + return dict(self._by_code) + + def is_open_source(self, regulation_code: str) -> bool: + """Check if regulation is Rule 1 or 2 (safe to reference).""" + info = self.classify_regulation(regulation_code) + return info["rule"] in (1, 2) + + +# Module-level singleton +_registry: Optional[RegulationRegistry] = None + + +def get_registry() -> RegulationRegistry: + """Get or create the singleton RegulationRegistry instance.""" + global _registry + if _registry is None: + _registry = RegulationRegistry() + return _registry + + +def classify_regulation(regulation_code: str) -> dict: + """Convenience: look up license info for a regulation_code.""" + return get_registry().classify_regulation(regulation_code) + + +def classify_source_regulation(source_regulation: str) -> str: + """Convenience: look up source_type by regulation display name.""" + return get_registry().source_type_by_name(source_regulation) diff --git a/control-pipeline/tests/test_regulation_registry.py b/control-pipeline/tests/test_regulation_registry.py new file mode 100644 index 0000000..8c03aba --- /dev/null +++ b/control-pipeline/tests/test_regulation_registry.py @@ -0,0 +1,285 @@ +"""Tests for RegulationRegistry — DB-backed lookup with cache and fallback.""" + +import time +from unittest.mock import patch, MagicMock + +import pytest + +from services.regulation_registry import ( + RegulationRegistry, + _CACHE_TTL_SECONDS, +) + + +# ── Test data: simulates DB rows ────────────────────────────────────────── + +_MOCK_DB_ROWS = [ + # (regulation_id, regulation_name_de, license_rule, license_type, + # attribution, source_type, jurisdiction, status) + ("eu_2016_679", "DSGVO (EU) 2016/679", 1, "EU_LAW", + None, "law", "EU", "active"), + ("nist_sp_800_53", "NIST SP 800-53 Rev. 5", 1, "NIST_PUBLIC_DOMAIN", + None, "standard", "US", "active"), + ("owasp_asvs", "OWASP ASVS 4.0", 2, "CC-BY-SA-4.0", + "OWASP Foundation, CC BY-SA 4.0", "standard", "INT", "active"), + ("bdsg", "Bundesdatenschutzgesetz (BDSG)", 1, "DE_LAW", + None, "law", "DE", "active"), + ("at_dsg", "Österreichisches Datenschutzgesetz (DSG)", 1, "AT_LAW", + None, "law", "AT", "active"), +] + + +def _mock_db_execute(query): + """Mock that returns our test rows.""" + mock_result = MagicMock() + mock_result.fetchall.return_value = _MOCK_DB_ROWS + return mock_result + + +@pytest.fixture +def registry(): + """Create a registry with mocked DB.""" + reg = RegulationRegistry() + with patch("services.regulation_registry.SessionLocal") as mock_session_cls: + mock_session = MagicMock() + mock_session.execute = _mock_db_execute + mock_session_cls.return_value = mock_session + reg._load() + return reg + + +# ── classify_regulation tests ───────────────────────────────────────────── + + +class TestClassifyRegulation: + def test_exact_match_eu_law(self, registry): + result = registry.classify_regulation("eu_2016_679") + assert result["rule"] == 1 + assert result["license"] == "EU_LAW" + assert result["source_type"] == "law" + assert result["name"] == "DSGVO (EU) 2016/679" + + def test_exact_match_case_insensitive(self, registry): + result = registry.classify_regulation("EU_2016_679") + assert result["rule"] == 1 + assert result["name"] == "DSGVO (EU) 2016/679" + + def test_exact_match_with_whitespace(self, registry): + result = registry.classify_regulation(" eu_2016_679 ") + assert result["rule"] == 1 + + def test_nist_standard(self, registry): + result = registry.classify_regulation("nist_sp_800_53") + assert result["rule"] == 1 + assert result["source_type"] == "standard" + + def test_owasp_rule2(self, registry): + result = registry.classify_regulation("owasp_asvs") + assert result["rule"] == 2 + assert result["attribution"] == "OWASP Foundation, CC BY-SA 4.0" + + def test_german_law(self, registry): + result = registry.classify_regulation("bdsg") + assert result["rule"] == 1 + assert result["source_type"] == "law" + assert result["jurisdiction"] == "DE" + + def test_austrian_law(self, registry): + result = registry.classify_regulation("at_dsg") + assert result["rule"] == 1 + assert result["jurisdiction"] == "AT" + + def test_prefix_enisa_rule2(self, registry): + result = registry.classify_regulation("enisa_supply_chain_2024") + assert result["rule"] == 2 + assert result["source_type"] == "standard" + assert "ENISA" in result["attribution"] + + def test_prefix_bsi_rule3(self, registry): + result = registry.classify_regulation("bsi_tr_03161") + assert result["rule"] == 3 + assert result["source_type"] == "restricted" + assert result["name"] == "INTERNAL_ONLY" + + def test_prefix_iso_rule3(self, registry): + result = registry.classify_regulation("iso_27001") + assert result["rule"] == 3 + assert result["source_type"] == "restricted" + + def test_prefix_etsi_rule3(self, registry): + result = registry.classify_regulation("etsi_en_303_645") + assert result["rule"] == 3 + + def test_unknown_defaults_to_restricted(self, registry): + result = registry.classify_regulation("some_unknown_regulation") + assert result["rule"] == 3 + assert result["source_type"] == "restricted" + assert result["license"] == "UNKNOWN" + + +# ── source_type_by_name tests ──────────────────────────────────────────── + + +class TestSourceTypeByName: + def test_exact_match_law(self, registry): + result = registry.source_type_by_name("DSGVO (EU) 2016/679") + assert result == "law" + + def test_exact_match_standard(self, registry): + result = registry.source_type_by_name("NIST SP 800-53 Rev. 5") + assert result == "standard" + + def test_empty_returns_framework(self, registry): + assert registry.source_type_by_name("") == "framework" + assert registry.source_type_by_name(None) == "framework" + + def test_heuristic_law(self, registry): + assert registry.source_type_by_name("Verordnung XYZ") == "law" + assert registry.source_type_by_name("Some EU Directive") == "law" + + def test_heuristic_guideline(self, registry): + assert registry.source_type_by_name("EDPB Leitlinie 99/2025") == "guideline" + assert registry.source_type_by_name("BSI Standard 200-1") == "guideline" + + def test_heuristic_framework(self, registry): + # "ENISA Cloud Guidelines" matches "guideline" before "enisa" in heuristic order + assert registry.source_type_by_name("ENISA Cloud Report") == "framework" + assert registry.source_type_by_name("OWASP Testing Guide") == "framework" + + def test_unknown_returns_framework(self, registry): + assert registry.source_type_by_name("Completely Unknown Document") == "framework" + + +# ── is_open_source tests ─────────────��─────────────────────────────────── + + +class TestIsOpenSource: + def test_rule1_is_open(self, registry): + assert registry.is_open_source("eu_2016_679") is True + + def test_rule2_is_open(self, registry): + assert registry.is_open_source("owasp_asvs") is True + + def test_rule3_is_not_open(self, registry): + assert registry.is_open_source("bsi_tr_03161") is False + + def test_unknown_is_not_open(self, registry): + assert registry.is_open_source("unknown_thing") is False + + +# ── Cache behavior tests ──────��────────────────────────────────────────── + + +class TestCacheBehavior: + def test_fresh_cache_not_stale(self, registry): + assert registry._is_stale() is False + + def test_old_cache_is_stale(self, registry): + registry._loaded_at = time.monotonic() - _CACHE_TTL_SECONDS - 1 + assert registry._is_stale() is True + + def test_ensure_loaded_reloads_when_stale(self): + reg = RegulationRegistry() + reg._loaded_at = time.monotonic() - _CACHE_TTL_SECONDS - 100 # force stale + + load_called = False + original_load = reg._load + + def tracking_load(): + nonlocal load_called + load_called = True + + reg._load = tracking_load + reg._ensure_loaded() + assert load_called, "_load should have been called when cache is stale" + + def test_ensure_loaded_skips_when_fresh(self, registry): + with patch.object(registry, "_load") as mock_load: + registry._ensure_loaded() + mock_load.assert_not_called() + + +# ── Graceful degradation tests ──────��──────────────────────────────────── + + +class TestGracefulDegradation: + def test_db_failure_uses_stale_cache(self): + """If DB fails, stale cache entries are still usable.""" + reg = RegulationRegistry() + + # First load succeeds + with patch("services.regulation_registry.SessionLocal") as mock_cls: + mock_session = MagicMock() + mock_session.execute = _mock_db_execute + mock_cls.return_value = mock_session + reg._load() + + # Force stale + reg._loaded_at = time.monotonic() - _CACHE_TTL_SECONDS - 1 + + # Second load fails — DB error + from sqlalchemy.exc import OperationalError + with patch("services.regulation_registry.SessionLocal") as mock_cls: + mock_cls.side_effect = OperationalError("connection refused", None, None) + reg._ensure_loaded() + + # Should still have cached data + result = reg.classify_regulation("eu_2016_679") + assert result["rule"] == 1 + + def test_empty_registry_returns_unknown(self): + """Unloaded registry returns safe defaults.""" + reg = RegulationRegistry() + reg._loaded_at = time.monotonic() # pretend fresh but empty + + result = reg.classify_regulation("eu_2016_679") + assert result["rule"] == 3 # safe default + assert result["license"] == "UNKNOWN" + + +# ── Migration data consistency tests ───────��───────────────────────────── + + +class TestMigrationDataConsistency: + """Verify that the migration script produces valid data.""" + + def test_build_rows_produces_data(self): + from scripts.f1_migrate_regulation_registry import build_rows + rows = build_rows() + assert len(rows) > 100 # at least 100 entries + + def test_all_rows_have_required_fields(self): + from scripts.f1_migrate_regulation_registry import build_rows + rows = build_rows() + for row in rows: + assert row["regulation_id"], f"Missing regulation_id: {row}" + assert row["regulation_name_de"], f"Missing name: {row}" + assert row["license_rule"] in (1, 2, 3), f"Bad rule: {row}" + assert row["source_type"] in ( + "law", "guideline", "standard", "framework", "restricted" + ), f"Bad source_type: {row}" + assert row["jurisdiction"], f"Missing jurisdiction: {row}" + assert row["status"] in ("active", "needs_review", "deprecated") + + def test_no_duplicate_regulation_ids(self): + from scripts.f1_migrate_regulation_registry import build_rows + rows = build_rows() + ids = [r["regulation_id"] for r in rows] + assert len(ids) == len(set(ids)), f"Duplicates: {[x for x in ids if ids.count(x) > 1]}" + + def test_known_regulations_present(self): + from scripts.f1_migrate_regulation_registry import build_rows + rows = build_rows() + ids = {r["regulation_id"] for r in rows} + assert "eu_2016_679" in ids # DSGVO + assert "bdsg" in ids # BDSG + assert "nist_sp_800_53" in ids # NIST + assert "owasp_asvs" in ids # OWASP + + def test_owasp_has_attribution(self): + from scripts.f1_migrate_regulation_registry import build_rows + rows = build_rows() + owasp = [r for r in rows if r["regulation_id"] == "owasp_asvs"][0] + assert owasp["attribution"] is not None + assert "OWASP" in owasp["attribution"] + assert owasp["license_rule"] == 2