""" DB-backed Regulation Registry with in-memory cache. Replaces hardcoded REGULATION_LICENSE_MAP and SOURCE_REGULATION_CLASSIFICATION with a single PostgreSQL table (compliance.regulation_registry). Cache TTL: 5 minutes. Thread-safe via simple timestamp check. Falls back to hardcoded dicts if DB is unavailable (graceful degradation). """ import logging import time from typing import Optional from sqlalchemy import text from sqlalchemy.exc import SQLAlchemyError from db.session import SessionLocal logger = logging.getLogger(__name__) _CACHE_TTL_SECONDS = 300 # 5 minutes # Prefix-based fallback rules (unchanged from original logic) _RULE2_PREFIXES = ("enisa_",) _RULE3_PREFIXES = ("bsi_", "iso_", "etsi_") # Fallback for unknown regulations _UNKNOWN_REGULATION = { "license": "UNKNOWN", "rule": 3, "source_type": "restricted", "name": "INTERNAL_ONLY", "attribution": None, } class RegulationRegistry: """In-memory cache of the regulation_registry table. Provides two lookup modes: 1. by_code(regulation_id) — replaces REGULATION_LICENSE_MAP[code] 2. source_type_by_name(name) — replaces SOURCE_REGULATION_CLASSIFICATION[name] """ def __init__(self): self._by_code: dict[str, dict] = {} self._by_name: dict[str, str] = {} self._loaded_at: float = 0.0 def _is_stale(self) -> bool: return (time.monotonic() - self._loaded_at) > _CACHE_TTL_SECONDS def _load(self) -> bool: """Load all rows from regulation_registry into memory.""" try: db = SessionLocal() try: rows = db.execute( text(""" SELECT regulation_id, regulation_name_de, license_rule, license_type, attribution, source_type, jurisdiction, status FROM regulation_registry WHERE status != 'deprecated' """) ).fetchall() finally: db.close() by_code: dict[str, dict] = {} by_name: dict[str, str] = {} for row in rows: entry = { "license": row[3] or "", # license_type "rule": row[2], # license_rule "source_type": row[5] or "law", # source_type "name": row[1] or row[0], # regulation_name_de or regulation_id "attribution": row[4], # attribution "jurisdiction": row[6], # jurisdiction } by_code[row[0].lower()] = entry # Also index by name for source_type lookups if row[1]: by_name[row[1]] = row[5] or "law" self._by_code = by_code self._by_name = by_name self._loaded_at = time.monotonic() logger.info( "Regulation registry loaded: %d entries by code, %d by name", len(by_code), len(by_name), ) return True except SQLAlchemyError: logger.warning( "Failed to load regulation_registry from DB — using stale cache", exc_info=True, ) return False def _ensure_loaded(self) -> None: """Reload cache if stale.""" if self._is_stale(): self._load() def classify_regulation(self, regulation_code: str) -> dict: """Look up license info for a regulation_code. Returns dict with keys: license, rule, name, source_type, attribution. Equivalent to the old _classify_regulation() function. """ self._ensure_loaded() code = regulation_code.lower().strip() # Exact match from DB if code in self._by_code: return self._by_code[code] # Prefix match for Rule 2 (ENISA = standard) for prefix in _RULE2_PREFIXES: if code.startswith(prefix): return { "license": "CC-BY-4.0", "rule": 2, "source_type": "standard", "name": "ENISA", "attribution": "ENISA, CC BY 4.0", } # Prefix match for Rule 3 (BSI/ISO/ETSI = restricted) for prefix in _RULE3_PREFIXES: if code.startswith(prefix): return { "license": f"{prefix.rstrip('_').upper()}_RESTRICTED", "rule": 3, "source_type": "restricted", "name": "INTERNAL_ONLY", "attribution": None, } # Unknown → restricted (safe default) logger.warning( "Unknown regulation_code %r — defaulting to Rule 3 (restricted)", code ) return dict(_UNKNOWN_REGULATION) def source_type_by_name(self, source_regulation: str) -> str: """Look up source_type by regulation display name. Equivalent to old classify_source_regulation(). Falls back to heuristic for unknown names. """ self._ensure_loaded() if not source_regulation: return "framework" # Exact match from DB if source_regulation in self._by_name: return self._by_name[source_regulation] # Heuristic fallback for unknown sources lower = source_regulation.lower() law_indicators = [ "verordnung", "richtlinie", "gesetz", "directive", "regulation", "(eu)", "(eg)", "act", "ley", "loi", "törvény", "código", ] if any(ind in lower for ind in law_indicators): return "law" guideline_indicators = [ "edpb", "leitlinie", "guideline", "wp2", "bsi", "empfehlung", ] if any(ind in lower for ind in guideline_indicators): return "guideline" framework_indicators = [ "enisa", "nist", "owasp", "oecd", "cisa", "framework", "iso", ] if any(ind in lower for ind in framework_indicators): return "framework" return "framework" def get_all(self) -> dict[str, dict]: """Return all cached entries (by regulation_code).""" self._ensure_loaded() return dict(self._by_code) def is_open_source(self, regulation_code: str) -> bool: """Check if regulation is Rule 1 or 2 (safe to reference).""" info = self.classify_regulation(regulation_code) return info["rule"] in (1, 2) # Module-level singleton _registry: Optional[RegulationRegistry] = None def get_registry() -> RegulationRegistry: """Get or create the singleton RegulationRegistry instance.""" global _registry if _registry is None: _registry = RegulationRegistry() return _registry def classify_regulation(regulation_code: str) -> dict: """Convenience: look up license info for a regulation_code.""" return get_registry().classify_regulation(regulation_code) def classify_source_regulation(source_regulation: str) -> str: """Convenience: look up source_type by regulation display name.""" return get_registry().source_type_by_name(source_regulation)