feat(canonical-controls): Canonical Control Library — rechtssichere Security Controls
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 40s
CI/CD / test-python-backend-compliance (push) Successful in 41s
CI/CD / test-python-document-crawler (push) Successful in 26s
CI/CD / test-python-dsms-gateway (push) Successful in 23s
CI/CD / validate-canonical-controls (push) Successful in 18s
CI/CD / deploy-hetzner (push) Successful in 2m26s
All checks were successful
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Successful in 40s
CI/CD / test-python-backend-compliance (push) Successful in 41s
CI/CD / test-python-document-crawler (push) Successful in 26s
CI/CD / test-python-dsms-gateway (push) Successful in 23s
CI/CD / validate-canonical-controls (push) Successful in 18s
CI/CD / deploy-hetzner (push) Successful in 2m26s
Eigenstaendig formulierte Security Controls mit unabhaengiger Taxonomie und Open-Source-Verankerung (OWASP, NIST, ENISA). Keine BSI-Nomenklatur. - Migration 044: 5 DB-Tabellen (frameworks, controls, sources, licenses, mappings) - 10 Seed Controls mit 39 Open-Source-Referenzen - License Gate: Quellen-Berechtigungspruefung (analysis/excerpt/embeddings/product) - Too-Close-Detektor: 5 Metriken (exact-phrase, token-overlap, ngram, embedding, LCS) - REST API: 8 Endpoints unter /v1/canonical/ - Go Loader mit Multi-Index (ID, domain, severity, framework) - Frontend: Control Library Browser + Provenance Wiki - CI/CD: validate-controls.py Job (schema, no-leak, open-anchors) - 67 Tests (8 Go + 59 Python), alle PASS - MkDocs Dokumentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from .change_request_routes import router as change_request_router
|
||||
from .generation_routes import router as generation_router
|
||||
from .project_routes import router as project_router
|
||||
from .wiki_routes import router as wiki_router
|
||||
from .canonical_control_routes import router as canonical_control_router
|
||||
|
||||
# Include sub-routers
|
||||
router.include_router(audit_router)
|
||||
@@ -67,6 +68,7 @@ router.include_router(change_request_router)
|
||||
router.include_router(generation_router)
|
||||
router.include_router(project_router)
|
||||
router.include_router(wiki_router)
|
||||
router.include_router(canonical_control_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
@@ -101,4 +103,5 @@ __all__ = [
|
||||
"generation_router",
|
||||
"project_router",
|
||||
"wiki_router",
|
||||
"canonical_control_router",
|
||||
]
|
||||
|
||||
332
backend-compliance/compliance/api/canonical_control_routes.py
Normal file
332
backend-compliance/compliance/api/canonical_control_routes.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
FastAPI routes for the Canonical Control Library.
|
||||
|
||||
Provides read-only access to independently authored security controls.
|
||||
All controls are formulated without proprietary nomenclature and anchored
|
||||
in open-source frameworks (OWASP, NIST, ENISA).
|
||||
|
||||
Endpoints:
|
||||
GET /v1/canonical/frameworks — All frameworks
|
||||
GET /v1/canonical/frameworks/{framework_id} — Framework details
|
||||
GET /v1/canonical/frameworks/{framework_id}/controls — Controls of a framework
|
||||
GET /v1/canonical/controls — All controls (filterable)
|
||||
GET /v1/canonical/controls/{control_id} — Single control by control_id
|
||||
GET /v1/canonical/sources — Source registry
|
||||
GET /v1/canonical/licenses — License matrix
|
||||
POST /v1/canonical/controls/{control_id}/similarity-check — Too-close check
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
|
||||
from database import SessionLocal
|
||||
from compliance.services.license_gate import get_license_matrix, get_source_permissions
|
||||
from compliance.services.similarity_detector import check_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/v1/canonical", tags=["canonical-controls"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RESPONSE MODELS
|
||||
# =============================================================================
|
||||
|
||||
class FrameworkResponse(BaseModel):
    """Serialised framework row from ``canonical_control_frameworks``.

    NOTE(review): the handlers below return plain dicts of the same shape;
    presumably this model is the documented contract — confirm it is wired
    up as a ``response_model`` somewhere.
    """

    id: str                           # row UUID, stringified
    framework_id: str                 # stable human-readable key (used in URL paths)
    name: str
    version: str
    description: Optional[str] = None
    owner: Optional[str] = None
    policy_version: Optional[str] = None
    release_state: str                # lifecycle state; value set not visible here
    created_at: str                   # ISO-8601 timestamp string
    updated_at: str                   # ISO-8601 timestamp string
|
||||
|
||||
|
||||
class ControlResponse(BaseModel):
    """Serialised control row from ``canonical_controls``.

    Mirrors the dict built by ``_control_row``; JSON(B) columns surface as
    plain ``dict``/``list`` values.
    """

    id: str                           # row UUID, stringified
    framework_id: str                 # owning framework's UUID, stringified
    control_id: str                   # human-readable key, e.g. AUTH-001
    title: str
    objective: str
    rationale: str
    scope: dict                       # structured applicability description
    requirements: list
    test_procedure: list
    evidence: list
    severity: str
    risk_score: Optional[float] = None
    implementation_effort: Optional[str] = None
    evidence_confidence: Optional[float] = None
    open_anchors: list                # open-source references (OWASP/NIST/ENISA)
    release_state: str
    tags: list
    created_at: str                   # ISO-8601 timestamp string
    updated_at: str                   # ISO-8601 timestamp string
|
||||
|
||||
|
||||
class SimilarityCheckRequest(BaseModel):
    """Request body for the too-close similarity check."""

    source_text: str      # protected source text
    candidate_text: str   # candidate (rewritten) text compared against the source
|
||||
|
||||
|
||||
class SimilarityCheckResponse(BaseModel):
    """Result of the five too-close metrics plus the aggregate verdict."""

    max_exact_run: int        # longest identical token run between the texts
    token_overlap: float      # Jaccard similarity of token sets
    ngram_jaccard: float      # Jaccard similarity of character 3-grams
    embedding_cosine: float   # semantic cosine similarity (0.0 if service down)
    lcs_ratio: float          # LCS length / max(len_a, len_b)
    status: str               # aggregate verdict: PASS, WARN or FAIL
    details: dict             # per-metric PASS/WARN/FAIL classification
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HELPERS
|
||||
# =============================================================================
|
||||
|
||||
def _row_to_dict(row, columns: list[str]) -> dict[str, Any]:
|
||||
"""Generic row → dict converter."""
|
||||
return {col: (getattr(row, col).isoformat() if hasattr(getattr(row, col, None), 'isoformat') else getattr(row, col)) for col in columns}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FRAMEWORKS
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/frameworks")
async def list_frameworks():
    """Return every registered control framework, ordered by name."""
    stmt = text("""
        SELECT id, framework_id, name, version, description,
               owner, policy_version, release_state,
               created_at, updated_at
        FROM canonical_control_frameworks
        ORDER BY name
    """)
    with SessionLocal() as db:
        rows = db.execute(stmt).fetchall()

    frameworks = []
    for row in rows:
        frameworks.append(
            {
                "id": str(row.id),
                "framework_id": row.framework_id,
                "name": row.name,
                "version": row.version,
                "description": row.description,
                "owner": row.owner,
                "policy_version": row.policy_version,
                "release_state": row.release_state,
                "created_at": row.created_at.isoformat() if row.created_at else None,
                "updated_at": row.updated_at.isoformat() if row.updated_at else None,
            }
        )
    return frameworks
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}")
async def get_framework(framework_id: str):
    """Return one framework, addressed by its human-readable framework_id."""
    stmt = text("""
        SELECT id, framework_id, name, version, description,
               owner, policy_version, release_state,
               created_at, updated_at
        FROM canonical_control_frameworks
        WHERE framework_id = :fid
    """)
    with SessionLocal() as db:
        record = db.execute(stmt, {"fid": framework_id}).fetchone()

    if record is None:
        raise HTTPException(status_code=404, detail="Framework not found")

    return {
        "id": str(record.id),
        "framework_id": record.framework_id,
        "name": record.name,
        "version": record.version,
        "description": record.description,
        "owner": record.owner,
        "policy_version": record.policy_version,
        "release_state": record.release_state,
        "created_at": record.created_at.isoformat() if record.created_at else None,
        "updated_at": record.updated_at.isoformat() if record.updated_at else None,
    }
|
||||
|
||||
|
||||
@router.get("/frameworks/{framework_id}/controls")
async def list_framework_controls(
    framework_id: str,
    severity: Optional[str] = Query(None),
    release_state: Optional[str] = Query(None),
):
    """List the controls of one framework, optionally filtered by severity
    and release state."""
    with SessionLocal() as db:
        # The path carries the human-readable framework_id; the controls
        # table references the framework's UUID, so resolve it first.
        framework = db.execute(
            text("SELECT id FROM canonical_control_frameworks WHERE framework_id = :fid"),
            {"fid": framework_id},
        ).fetchone()
        if framework is None:
            raise HTTPException(status_code=404, detail="Framework not found")

        sql = """
            SELECT id, framework_id, control_id, title, objective, rationale,
                   scope, requirements, test_procedure, evidence,
                   severity, risk_score, implementation_effort,
                   evidence_confidence, open_anchors, release_state, tags,
                   created_at, updated_at
            FROM canonical_controls
            WHERE framework_id = :fw_id
        """
        bind: dict[str, Any] = {"fw_id": str(framework.id)}

        # Append each optional filter only when a value was supplied.
        for clause, key, value in (
            (" AND severity = :sev", "sev", severity),
            (" AND release_state = :rs", "rs", release_state),
        ):
            if value:
                sql += clause
                bind[key] = value

        sql += " ORDER BY control_id"
        rows = db.execute(text(sql), bind).fetchall()

    return [_control_row(row) for row in rows]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CONTROLS
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/controls")
async def list_controls(
    severity: Optional[str] = Query(None),
    domain: Optional[str] = Query(None),
    release_state: Optional[str] = Query(None),
):
    """List every canonical control; the optional filters combine with AND."""
    sql = """
        SELECT id, framework_id, control_id, title, objective, rationale,
               scope, requirements, test_procedure, evidence,
               severity, risk_score, implementation_effort,
               evidence_confidence, open_anchors, release_state, tags,
               created_at, updated_at
        FROM canonical_controls
        WHERE 1=1
    """
    bind: dict[str, Any] = {}

    if severity:
        sql += " AND severity = :sev"
        bind["sev"] = severity
    if domain:
        # Prefix match on control_id (e.g. domain=auth matches AUTH-001).
        # NOTE(review): a raw prefix also matches e.g. AUTHX-001 — confirm
        # whether the filter should anchor on the "-" separator.
        sql += " AND LEFT(control_id, LENGTH(:dom)) = :dom"
        bind["dom"] = domain.upper()
    if release_state:
        sql += " AND release_state = :rs"
        bind["rs"] = release_state

    sql += " ORDER BY control_id"

    with SessionLocal() as db:
        rows = db.execute(text(sql), bind).fetchall()

    return [_control_row(row) for row in rows]
|
||||
|
||||
|
||||
@router.get("/controls/{control_id}")
async def get_control(control_id: str):
    """Fetch one canonical control by its control_id (e.g. AUTH-001).

    The lookup is case-insensitive: the path segment is upper-cased before
    matching, since control_ids are stored upper-case.
    """
    stmt = text("""
        SELECT id, framework_id, control_id, title, objective, rationale,
               scope, requirements, test_procedure, evidence,
               severity, risk_score, implementation_effort,
               evidence_confidence, open_anchors, release_state, tags,
               created_at, updated_at
        FROM canonical_controls
        WHERE control_id = :cid
    """)
    with SessionLocal() as db:
        record = db.execute(stmt, {"cid": control_id.upper()}).fetchone()

    if record is None:
        raise HTTPException(status_code=404, detail="Control not found")

    return _control_row(record)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SIMILARITY CHECK
|
||||
# =============================================================================
|
||||
|
||||
@router.post("/controls/{control_id}/similarity-check")
async def similarity_check(control_id: str, body: SimilarityCheckRequest):
    """Run the too-close detector against a source/candidate text pair.

    The control_id is echoed back (upper-cased) alongside the flattened
    SimilarityReport fields.
    """
    report = await check_similarity(body.source_text, body.candidate_text)

    payload: dict[str, Any] = {"control_id": control_id.upper()}
    for field in (
        "max_exact_run",
        "token_overlap",
        "ngram_jaccard",
        "embedding_cosine",
        "lcs_ratio",
        "status",
        "details",
    ):
        payload[field] = getattr(report, field)
    return payload
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SOURCES & LICENSES
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/sources")
async def list_sources():
    """All registered sources together with their usage-permission flags."""
    with SessionLocal() as session:
        return get_source_permissions(session)
|
||||
|
||||
|
||||
@router.get("/licenses")
async def list_licenses():
    """The full license matrix (per-license usage permissions)."""
    with SessionLocal() as session:
        return get_license_matrix(session)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# INTERNAL HELPERS
|
||||
# =============================================================================
|
||||
|
||||
def _control_row(r) -> dict:
    """Map a canonical_controls row to the JSON shape shared by all control
    endpoints.

    Numeric (Decimal) columns are coerced to float, timestamps to ISO-8601
    strings, and a NULL ``tags`` column falls back to an empty list.
    """
    return {
        "id": str(r.id),
        "framework_id": str(r.framework_id),  # owning framework's UUID
        "control_id": r.control_id,           # human-readable key, e.g. AUTH-001
        "title": r.title,
        "objective": r.objective,
        "rationale": r.rationale,
        "scope": r.scope,
        "requirements": r.requirements,
        "test_procedure": r.test_procedure,
        "evidence": r.evidence,
        "severity": r.severity,
        # Decimal → float so the value serialises as a JSON number.
        "risk_score": float(r.risk_score) if r.risk_score is not None else None,
        "implementation_effort": r.implementation_effort,
        "evidence_confidence": float(r.evidence_confidence) if r.evidence_confidence is not None else None,
        "open_anchors": r.open_anchors,
        "release_state": r.release_state,
        "tags": r.tags or [],
        "created_at": r.created_at.isoformat() if r.created_at else None,
        "updated_at": r.updated_at.isoformat() if r.updated_at else None,
    }
|
||||
116
backend-compliance/compliance/services/license_gate.py
Normal file
116
backend-compliance/compliance/services/license_gate.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
License Gate — checks whether a given source may be used for a specific purpose.
|
||||
|
||||
Usage types:
|
||||
- analysis: Read + analyse internally (TDM under UrhG 44b)
|
||||
- store_excerpt: Store verbatim excerpt in vault
|
||||
- ship_embeddings: Ship embeddings in product
|
||||
- ship_in_product: Ship text/content in product
|
||||
|
||||
Policy is driven by the canonical_control_sources table columns:
|
||||
allowed_analysis, allowed_store_excerpt, allowed_ship_embeddings, allowed_ship_in_product
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maps a public usage type to the boolean permission column on
# canonical_control_sources. This doubles as a whitelist:
# check_source_allowed() interpolates ONLY these column names into SQL,
# so unknown usage types can never reach the query.
USAGE_COLUMN_MAP = {
    "analysis": "allowed_analysis",                 # internal read + analysis (TDM, UrhG 44b)
    "store_excerpt": "allowed_store_excerpt",       # verbatim excerpt in vault
    "ship_embeddings": "allowed_ship_embeddings",   # embeddings shipped in product
    "ship_in_product": "allowed_ship_in_product",   # text/content shipped in product
}
|
||||
|
||||
|
||||
def check_source_allowed(db: Session, source_id: str, usage_type: str) -> bool:
    """Check whether *source_id* may be used for *usage_type*.

    Fails closed: returns False for unknown usage types, unknown sources,
    or a disabled permission flag.
    """
    column = USAGE_COLUMN_MAP.get(usage_type)
    if column is None:
        logger.warning("Unknown usage_type=%s", usage_type)
        return False

    # The column name comes from the fixed whitelist above, so the f-string
    # is not an injection risk; the user-supplied source_id stays a bound
    # parameter.
    stmt = text(f"SELECT {column} FROM canonical_control_sources WHERE source_id = :sid")
    record = db.execute(stmt, {"sid": source_id}).fetchone()

    if record is None:
        logger.warning("Source %s not found in registry", source_id)
        return False

    return bool(record[0])
|
||||
|
||||
|
||||
def get_license_matrix(db: Session) -> list[dict[str, Any]]:
    """Return every license with its usage terms, ordered by license_id."""
    stmt = text("""
        SELECT license_id, name, terms_url, commercial_use,
               ai_training_restriction, tdm_allowed_under_44b,
               deletion_required, notes
        FROM canonical_control_licenses
        ORDER BY license_id
    """)

    matrix: list[dict[str, Any]] = []
    for record in db.execute(stmt).fetchall():
        matrix.append(
            {
                "license_id": record.license_id,
                "name": record.name,
                "terms_url": record.terms_url,
                "commercial_use": record.commercial_use,
                "ai_training_restriction": record.ai_training_restriction,
                "tdm_allowed_under_44b": record.tdm_allowed_under_44b,
                "deletion_required": record.deletion_required,
                "notes": record.notes,
            }
        )
    return matrix
|
||||
|
||||
|
||||
def get_source_permissions(db: Session) -> list[dict[str, Any]]:
    """Return all sources joined with their license, including the four
    per-usage permission flags and vault retention settings."""
    stmt = text("""
        SELECT s.source_id, s.title, s.publisher, s.url, s.version_label,
               s.language, s.license_id,
               s.allowed_analysis, s.allowed_store_excerpt,
               s.allowed_ship_embeddings, s.allowed_ship_in_product,
               s.vault_retention_days, s.vault_access_tier,
               l.name AS license_name, l.commercial_use
        FROM canonical_control_sources s
        JOIN canonical_control_licenses l ON l.license_id = s.license_id
        ORDER BY s.source_id
    """)

    sources: list[dict[str, Any]] = []
    for record in db.execute(stmt).fetchall():
        sources.append(
            {
                "source_id": record.source_id,
                "title": record.title,
                "publisher": record.publisher,
                "url": record.url,
                "version_label": record.version_label,
                "language": record.language,
                "license_id": record.license_id,
                "license_name": record.license_name,
                "commercial_use": record.commercial_use,
                "allowed_analysis": record.allowed_analysis,
                "allowed_store_excerpt": record.allowed_store_excerpt,
                "allowed_ship_embeddings": record.allowed_ship_embeddings,
                "allowed_ship_in_product": record.allowed_ship_in_product,
                "vault_retention_days": record.vault_retention_days,
                "vault_access_tier": record.vault_access_tier,
            }
        )
    return sources
|
||||
223
backend-compliance/compliance/services/similarity_detector.py
Normal file
223
backend-compliance/compliance/services/similarity_detector.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Too-Close Similarity Detector — checks whether a candidate text is too similar
|
||||
to a protected source text (copyright / license compliance).
|
||||
|
||||
Five metrics:
|
||||
1. Exact-phrase — longest identical token sequence
|
||||
2. Token overlap — Jaccard similarity of token sets
|
||||
3. 3-gram Jaccard — Jaccard similarity of character 3-grams
|
||||
4. Embedding cosine — via bge-m3 (Ollama or embedding-service)
|
||||
5. LCS ratio — Longest Common Subsequence / max(len_a, len_b)
|
||||
|
||||
Decision:
|
||||
PASS — no fail + max 1 warn
|
||||
WARN — max 2 warn, no fail → human review
|
||||
FAIL — any fail threshold → block, rewrite required
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Per-metric warn/fail limits. A value >= "fail" marks the metric FAIL,
# >= "warn" marks it WARN, otherwise PASS (see _classify()).
THRESHOLDS = {
    "max_exact_run": {"warn": 8, "fail": 12},          # tokens in longest identical run
    "token_overlap": {"warn": 0.20, "fail": 0.30},     # Jaccard of token sets
    "ngram_jaccard": {"warn": 0.10, "fail": 0.18},     # Jaccard of char 3-grams
    "embedding_cosine": {"warn": 0.86, "fail": 0.92},  # semantic cosine similarity
    "lcs_ratio": {"warn": 0.35, "fail": 0.50},         # LCS length / max length
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tokenisation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WORD_RE = re.compile(r"\w+", re.UNICODE)
|
||||
|
||||
|
||||
def _tokenize(text: str) -> list[str]:
|
||||
return [t.lower() for t in _WORD_RE.findall(text)]
|
||||
|
||||
|
||||
def _char_ngrams(text: str, n: int = 3) -> set[str]:
|
||||
text = text.lower()
|
||||
return {text[i : i + n] for i in range(len(text) - n + 1)} if len(text) >= n else set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metric implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def max_exact_run(tokens_a: list[str], tokens_b: list[str]) -> int:
    """Longest contiguous identical token sequence shared by a and b.

    Implemented as the classic longest-common-substring DP:
    O(len_a * len_b) time and O(len_b) memory. The previous version
    rescanned a run from every matching (i, j) pair, which degrades to
    O(len_a * len_b * run_length) on repetitive inputs.
    """
    if not tokens_a or not tokens_b:
        return 0

    best = 0
    width = len(tokens_b) + 1
    prev = [0] * width  # prev[j] = run length ending at (i-1, j-1)
    for tok_a in tokens_a:
        curr = [0] * width
        for j, tok_b in enumerate(tokens_b, start=1):
            if tok_a == tok_b:
                curr[j] = prev[j - 1] + 1
                if curr[j] > best:
                    best = curr[j]
        prev = curr
    return best
|
||||
|
||||
|
||||
def token_overlap_jaccard(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Jaccard similarity of the two token *sets* (order/frequency ignored)."""
    unique_a = set(tokens_a)
    unique_b = set(tokens_b)
    union = unique_a | unique_b
    if not union:
        # Both empty: defined as 0.0 rather than raising on division by zero.
        return 0.0
    return len(unique_a & unique_b) / len(union)
|
||||
|
||||
|
||||
def ngram_jaccard(text_a: str, text_b: str, n: int = 3) -> float:
    """Jaccard similarity of the texts' character n-gram sets (default 3-grams)."""
    grams_a = _char_ngrams(text_a, n)
    grams_b = _char_ngrams(text_b, n)
    union = grams_a | grams_b
    # Both texts too short for any n-gram → defined as 0.0.
    return len(grams_a & grams_b) / len(union) if union else 0.0
|
||||
|
||||
|
||||
def lcs_ratio(tokens_a: list[str], tokens_b: list[str]) -> float:
    """Longest common subsequence length divided by max(len_a, len_b)."""
    len_a, len_b = len(tokens_a), len(tokens_b)
    if len_a == 0 or len_b == 0:
        return 0.0

    # Single rolling row; `diag` carries the value of row[j-1] from the
    # previous outer iteration (the diagonal cell of the DP table).
    row = [0] * (len_b + 1)
    for tok_a in tokens_a:
        diag = 0
        for j in range(1, len_b + 1):
            saved = row[j]
            if tok_a == tokens_b[j - 1]:
                row[j] = diag + 1
            elif row[j - 1] > row[j]:
                row[j] = row[j - 1]
            diag = saved
    return row[len_b] / max(len_a, len_b)
|
||||
|
||||
|
||||
async def embedding_cosine(text_a: str, text_b: str, embedding_url: str | None = None) -> float:
    """Cosine similarity of the two texts via the bge-m3 embedding service.

    Degrades gracefully: any failure (network error, bad payload, missing
    embeddings) yields 0.0 so the lexical metrics still produce a report.
    """
    base_url = embedding_url if embedding_url else "http://embedding-service:8087"
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(
                f"{base_url}/embed",
                json={"texts": [text_a, text_b]},
            )
            response.raise_for_status()
            vectors = response.json().get("embeddings", [])
            if len(vectors) < 2:
                return 0.0
            return _cosine(vectors[0], vectors[1])
    except Exception:
        # Deliberately broad: the similarity check must never crash on an
        # unavailable auxiliary service.
        logger.warning("Embedding service unreachable, skipping cosine check")
        return 0.0
|
||||
|
||||
|
||||
def _cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Decision engine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class SimilarityReport:
    """Aggregate result of the five too-close metrics for one text pair.

    Float metrics are rounded to 4 decimals by check_similarity().
    """

    max_exact_run: int        # longest identical token run
    token_overlap: float      # Jaccard of token sets
    ngram_jaccard: float      # Jaccard of character 3-grams
    embedding_cosine: float   # semantic cosine (0.0 if service unreachable)
    lcs_ratio: float          # LCS length / max(len_a, len_b)
    status: str  # PASS, WARN, FAIL
    details: dict  # per-metric status
|
||||
|
||||
|
||||
def _classify(value: float | int, metric: str) -> str:
    """Map a metric value to PASS / WARN / FAIL against THRESHOLDS[metric].

    Thresholds are inclusive: value >= fail → FAIL, else value >= warn → WARN.
    """
    limits = THRESHOLDS[metric]
    if value >= limits["fail"]:
        return "FAIL"
    return "WARN" if value >= limits["warn"] else "PASS"
|
||||
|
||||
|
||||
async def check_similarity(
    source_text: str,
    candidate_text: str,
    embedding_url: str | None = None,
) -> SimilarityReport:
    """Run all 5 metrics and return an aggregate report.

    Decision policy (matches the module docstring):
      * any metric FAIL, or more than two WARNs  -> FAIL
      * exactly two WARNs                        -> WARN
      * at most one WARN                         -> PASS

    Args:
        source_text: Protected source text.
        candidate_text: Candidate text to vet against the source.
        embedding_url: Optional override for the embedding-service base URL.
    """
    tok_src = _tokenize(source_text)
    tok_cand = _tokenize(candidate_text)

    metrics: dict[str, float | int] = {
        "max_exact_run": max_exact_run(tok_src, tok_cand),
        "token_overlap": token_overlap_jaccard(tok_src, tok_cand),
        "ngram_jaccard": ngram_jaccard(source_text, candidate_text),
        "embedding_cosine": await embedding_cosine(source_text, candidate_text, embedding_url),
        "lcs_ratio": lcs_ratio(tok_src, tok_cand),
    }
    details = {name: _classify(value, name) for name, value in metrics.items()}

    verdicts = list(details.values())
    fail_count = verdicts.count("FAIL")
    warn_count = verdicts.count("WARN")

    # The original chain carried a redundant `elif warn_count == 1: PASS`
    # branch duplicating the final else; collapsed without behaviour change.
    if fail_count > 0 or warn_count > 2:
        status = "FAIL"
    elif warn_count == 2:
        status = "WARN"
    else:
        status = "PASS"

    return SimilarityReport(
        max_exact_run=metrics["max_exact_run"],
        token_overlap=round(metrics["token_overlap"], 4),
        ngram_jaccard=round(metrics["ngram_jaccard"], 4),
        embedding_cosine=round(metrics["embedding_cosine"], 4),
        lcs_ratio=round(metrics["lcs_ratio"], 4),
        status=status,
        details=details,
    )
|
||||
118
backend-compliance/compliance/tests/test_similarity_detector.py
Normal file
118
backend-compliance/compliance/tests/test_similarity_detector.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Tests for the Too-Close Similarity Detector."""
|
||||
|
||||
import pytest
|
||||
from compliance.services.similarity_detector import (
|
||||
max_exact_run,
|
||||
token_overlap_jaccard,
|
||||
ngram_jaccard,
|
||||
lcs_ratio,
|
||||
check_similarity,
|
||||
_tokenize,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenize:
    """Unit tests for the _tokenize() word splitter."""

    def test_basic(self):
        """Lower-cases and splits on non-word characters, keeping digits."""
        tokens = _tokenize("Hello World 123")
        assert tokens == ["hello", "world", "123"]

    def test_german_umlauts(self):
        """Unicode word characters (ß, ü) stay inside a single token."""
        tokens = _tokenize("Schutzmaßnahmen für Daten")
        assert len(tokens) == 3

    def test_empty(self):
        assert _tokenize("") == []
|
||||
|
||||
|
||||
class TestMaxExactRun:
    """Unit tests for max_exact_run() (longest identical token run)."""

    def test_identical(self):
        """A text compared against itself matches over its full length."""
        tokens = _tokenize("the quick brown fox jumps over the lazy dog")
        assert max_exact_run(tokens, tokens) == len(tokens)

    def test_partial_match(self):
        a = _tokenize("the quick brown fox")
        b = _tokenize("a quick brown cat")
        assert max_exact_run(a, b) == 2  # "quick brown"

    def test_no_match(self):
        a = _tokenize("hello world")
        b = _tokenize("foo bar")
        assert max_exact_run(a, b) == 0

    def test_empty(self):
        """Either side empty yields 0, not an error."""
        assert max_exact_run([], []) == 0
        assert max_exact_run(["a"], []) == 0
|
||||
|
||||
|
||||
class TestTokenOverlapJaccard:
    """Unit tests for token_overlap_jaccard() (set-based Jaccard)."""

    def test_identical(self):
        tokens = _tokenize("hello world")
        assert token_overlap_jaccard(tokens, tokens) == 1.0

    def test_no_overlap(self):
        a = _tokenize("hello world")
        b = _tokenize("foo bar")
        assert token_overlap_jaccard(a, b) == 0.0

    def test_partial(self):
        a = _tokenize("hello world foo")
        b = _tokenize("hello bar baz")
        # intersection: {hello}, union: {hello, world, foo, bar, baz}
        assert abs(token_overlap_jaccard(a, b) - 0.2) < 0.01
|
||||
|
||||
|
||||
class TestNgramJaccard:
    """Unit tests for ngram_jaccard() (character 3-gram Jaccard)."""

    def test_identical(self):
        assert ngram_jaccard("hello", "hello") == 1.0

    def test_different(self):
        """Disjoint 3-gram sets score 0.0."""
        assert ngram_jaccard("abc", "xyz") == 0.0

    def test_short(self):
        assert ngram_jaccard("ab", "cd") == 0.0  # too short for 3-grams
|
||||
|
||||
|
||||
class TestLcsRatio:
    """Unit tests for lcs_ratio() (LCS length / max token count)."""

    def test_identical(self):
        tokens = _tokenize("multi factor authentication required")
        assert lcs_ratio(tokens, tokens) == 1.0

    def test_partial(self):
        a = _tokenize("multi factor authentication")
        b = _tokenize("single factor verification")
        # LCS: "factor" (length 1), max(3,3) = 3, ratio = 1/3
        result = lcs_ratio(a, b)
        assert 0.3 < result < 0.4

    def test_empty(self):
        """Empty inputs yield 0.0, not a ZeroDivisionError."""
        assert lcs_ratio([], []) == 0.0
|
||||
|
||||
|
||||
class TestCheckSimilarity:
    """End-to-end tests of check_similarity().

    Every case passes an unreachable embedding URL (99999 is above the
    valid TCP port range), so embedding_cosine falls back to 0.0 and the
    verdict depends only on the lexical metrics — no network needed.
    """

    @pytest.mark.asyncio
    async def test_identical_texts_fail(self):
        """Identical texts max out the lexical metrics and must FAIL."""
        text = "Multi-factor authentication must be enforced for all administrative accounts."
        report = await check_similarity(text, text, embedding_url="http://localhost:99999")
        # Identical texts should have max overlap
        assert report.token_overlap == 1.0
        assert report.status == "FAIL"

    @pytest.mark.asyncio
    async def test_different_texts_pass(self):
        """Unrelated texts (even across languages) stay well under warn thresholds."""
        source = "Die Anwendung muss eine Zwei-Faktor-Authentisierung implementieren."
        candidate = "Network traffic should be encrypted using TLS 1.3 at minimum."
        report = await check_similarity(source, candidate, embedding_url="http://localhost:99999")
        assert report.token_overlap < 0.1
        assert report.status == "PASS"

    @pytest.mark.asyncio
    async def test_report_fields(self):
        """The report exposes all five metrics plus status and details."""
        report = await check_similarity("hello world", "foo bar", embedding_url="http://localhost:99999")
        assert hasattr(report, "max_exact_run")
        assert hasattr(report, "token_overlap")
        assert hasattr(report, "ngram_jaccard")
        assert hasattr(report, "embedding_cosine")
        assert hasattr(report, "lcs_ratio")
        assert hasattr(report, "status")
        assert hasattr(report, "details")
        assert report.status in ("PASS", "WARN", "FAIL")
|
||||
Reference in New Issue
Block a user