Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
Implements the full Multi-Layer Control Architecture for migrating ~25,000 Rich Controls into atomic, deduplicated Master Controls with full traceability.

Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance

New services:
- ObligationExtractor: 3-tier extraction (exact → embedding → LLM)
- PatternMatcher: 2-tier matching (keyword + embedding + domain-bonus)
- ControlComposer: Pattern + Obligation → Master Control
- PipelineAdapter: Pipeline integration + Migration Passes 1-5
- DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls
- CrosswalkRoutes: 15 API endpoints under /v1/canonical/

New DB schema:
- Migration 060: obligation_extractions, control_patterns, crosswalk_matrix
- Migration 061: obligation_candidates, parent_control_uuid tracking

Pattern Library: 50 YAML patterns (30 core + 20 IT-security)
Go SDK: Pattern loader with YAML validation and indexing
Documentation: MkDocs updated with full architecture overview

500 Python tests passing across all components.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1132 lines
41 KiB
Python
1132 lines
41 KiB
Python
"""Tests for Multi-Layer Control Architecture routes (crosswalk_routes.py).

Covers:
- Pydantic model validation (Pattern, Obligation, Crosswalk, Migration)
- Pattern Library endpoints (list, get, get controls)
- Obligation extraction endpoint
- Crosswalk query + stats endpoints
- Migration pass endpoints (1-5) + status
- Helper: _get_pattern_control_counts
"""
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from compliance.api.crosswalk_routes import (
|
|
PatternResponse,
|
|
PatternListResponse,
|
|
PatternDetailResponse,
|
|
ObligationExtractRequest,
|
|
ObligationExtractResponse,
|
|
CrosswalkRow,
|
|
CrosswalkQueryResponse,
|
|
CrosswalkStatsResponse,
|
|
MigrationRequest,
|
|
MigrationResponse,
|
|
MigrationStatusResponse,
|
|
DecompositionStatusResponse,
|
|
router,
|
|
_get_pattern_control_counts,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TestClient setup
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Throwaway application that mounts only the crosswalk router. The prefix
# matches the "/api/compliance" segment used by every request URL below.
_app = FastAPI()
_app.include_router(router, prefix="/api/compliance")

# Single client instance shared by all endpoint tests in this module.
_client = TestClient(_app)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MODEL TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPatternResponse:
    """Construction and default-value checks for the ``PatternResponse`` model."""

    def test_basic_creation(self):
        """Values handed to the constructor are readable back as attributes."""
        fields = {
            "id": "CP-AUTH-001",
            "name": "password_policy",
            "name_de": "Passwortrichtlinie",
            "domain": "AUTH",
            "category": "authentication",
            "description": "Password policy requirements",
            "objective_template": "Ensure passwords meet complexity standards.",
            "severity_default": "high",
        }

        pattern = PatternResponse(**fields)

        assert pattern.id == "CP-AUTH-001"
        assert pattern.domain == "AUTH"
        assert pattern.severity_default == "high"

    def test_default_values(self):
        """Fields omitted from the constructor take their declared defaults."""
        pattern = PatternResponse(
            id="CP-AUTH-001",
            name="test",
            name_de="Test",
            domain="AUTH",
            category="auth",
            description="desc",
            objective_template="obj",
            severity_default="medium",
        )

        assert pattern.implementation_effort_default == "m"
        # All three list-valued fields are expected to default to an empty list.
        for list_field in ("tags", "composable_with", "open_anchor_refs"):
            assert getattr(pattern, list_field) == []
        assert pattern.controls_count == 0

    def test_full_model(self):
        """Optional fields supplied explicitly override the defaults."""
        pattern = PatternResponse(
            id="CP-CRYP-001",
            name="encryption_at_rest",
            name_de="Verschluesselung ruhender Daten",
            domain="CRYP",
            category="encryption",
            description="Encrypt data at rest",
            objective_template="Ensure all stored data is encrypted.",
            severity_default="critical",
            implementation_effort_default="l",
            tags=["encryption", "storage"],
            composable_with=["CP-CRYP-002"],
            open_anchor_refs=[{"framework": "NIST", "ref": "SC-28"}],
            controls_count=42,
        )

        assert pattern.controls_count == 42
        assert len(pattern.tags) == 2
        assert pattern.implementation_effort_default == "l"
|
|
|
|
|
|
class TestPatternDetailResponse:
    """Checks for ``PatternDetailResponse`` (the extended pattern model)."""

    def test_has_extended_fields(self):
        """The detail-only template and keyword fields accept explicit values."""
        base_fields = {
            "id": "CP-AUTH-001",
            "name": "mfa",
            "name_de": "MFA",
            "domain": "AUTH",
            "category": "authentication",
            "description": "Multi-factor authentication",
            "objective_template": "Require MFA.",
            "severity_default": "high",
        }
        detail_fields = {
            "rationale_template": "Passwords alone are insufficient.",
            "requirements_template": ["Require TOTP or hardware key"],
            "test_procedure_template": ["Test login without MFA"],
            "evidence_template": ["MFA config screenshot"],
            "obligation_match_keywords": ["authentifizierung", "mfa"],
        }

        detail = PatternDetailResponse(**base_fields, **detail_fields)

        assert detail.rationale_template == "Passwords alone are insufficient."
        assert len(detail.requirements_template) == 1
        assert len(detail.obligation_match_keywords) == 2

    def test_defaults(self):
        """Detail-only fields default to an empty string / empty lists."""
        detail = PatternDetailResponse(
            id="CP-AUTH-001",
            name="test",
            name_de="Test",
            domain="AUTH",
            category="auth",
            description="desc",
            objective_template="obj",
            severity_default="medium",
        )

        assert detail.rationale_template == ""
        for list_field in (
            "requirements_template",
            "test_procedure_template",
            "evidence_template",
            "obligation_match_keywords",
        ):
            assert getattr(detail, list_field) == []
|
|
|
|
|
|
class TestObligationModels:
    """Checks for the ``ObligationExtractRequest`` / ``ObligationExtractResponse`` models."""

    def test_request_minimal(self):
        """A request built from text alone leaves the locator fields as ``None``."""
        text = "Ein Verantwortlicher muss..."

        request = ObligationExtractRequest(text=text)

        assert request.text == text
        for optional_field in ("regulation_code", "article", "paragraph"):
            assert getattr(request, optional_field) is None

    def test_request_full(self):
        """All optional locator fields can be supplied together."""
        request = ObligationExtractRequest(
            text="Art. 32 DSGVO",
            regulation_code="eu_2016_679",
            article="Art. 32",
            paragraph="Abs. 1",
        )

        assert request.regulation_code == "eu_2016_679"
        assert request.paragraph == "Abs. 1"

    def test_response_defaults(self):
        """A response built with no arguments reports "nothing extracted"."""
        response = ObligationExtractResponse()

        assert response.obligation_id is None
        assert response.method == "none"
        assert response.confidence == 0.0
        assert response.pattern_id is None
        assert response.pattern_confidence == 0.0

    def test_response_full(self):
        """Explicitly supplied obligation and pattern values are kept as given."""
        response = ObligationExtractResponse(
            obligation_id="DSGVO-OBL-001",
            obligation_title="Verzeichnis der Verarbeitungstaetigkeiten",
            obligation_text="Der Verantwortliche muss...",
            method="exact_match",
            confidence=1.0,
            regulation_id="dsgvo",
            pattern_id="CP-GOV-001",
            pattern_confidence=0.85,
        )

        assert response.obligation_id == "DSGVO-OBL-001"
        assert response.method == "exact_match"
        assert response.confidence == 1.0
        assert response.pattern_confidence == 0.85
|
|
|
|
|
|
class TestCrosswalkModels:
    """Checks for the ``CrosswalkRow`` and ``CrosswalkQueryResponse`` models."""

    def test_row_defaults(self):
        """An argument-less row has an empty code, no links and source "auto"."""
        empty_row = CrosswalkRow()

        assert empty_row.regulation_code == ""
        assert empty_row.article is None
        assert empty_row.obligation_id is None
        assert empty_row.confidence == 0.0
        assert empty_row.source == "auto"

    def test_row_full(self):
        """A fully specified row keeps the values it was given."""
        row_fields = {
            "regulation_code": "eu_2016_679",
            "article": "Art. 32",
            "obligation_id": "DSGVO-OBL-002",
            "pattern_id": "CP-CRYP-001",
            "master_control_id": "CRYP-001",
            "confidence": 0.95,
            "source": "manual",
        }

        full_row = CrosswalkRow(**row_fields)

        assert full_row.regulation_code == "eu_2016_679"
        assert full_row.confidence == 0.95

    def test_query_response(self):
        """``total`` is stored independently of how many rows are attached."""
        page = [
            CrosswalkRow(regulation_code="eu_2016_679"),
            CrosswalkRow(regulation_code="eu_2022_2554"),
        ]

        response = CrosswalkQueryResponse(rows=page, total=100)

        assert len(response.rows) == 2
        assert response.total == 100
|
|
|
|
|
|
class TestCrosswalkStatsResponse:
    """Checks for the ``CrosswalkStatsResponse`` model."""

    def test_defaults(self):
        """Every counter defaults to zero and the coverage map to ``{}``."""
        stats = CrosswalkStatsResponse()

        for counter in (
            "total_rows",
            "regulations_covered",
            "obligations_linked",
            "patterns_used",
            "controls_linked",
        ):
            assert getattr(stats, counter) == 0
        assert stats.coverage_by_regulation == {}

    def test_full(self):
        """Explicit counters and a per-regulation coverage map are kept as given."""
        coverage = {"eu_2016_679": 150, "eu_2022_2554": 80}

        stats = CrosswalkStatsResponse(
            total_rows=500,
            regulations_covered=9,
            obligations_linked=200,
            patterns_used=45,
            controls_linked=350,
            coverage_by_regulation=coverage,
        )

        assert stats.total_rows == 500
        assert stats.coverage_by_regulation["eu_2016_679"] == 150
|
|
|
|
|
|
class TestMigrationModels:
    """Checks for ``MigrationRequest``, ``MigrationResponse`` and ``MigrationStatusResponse``."""

    def test_request_default(self):
        """``limit`` defaults to 0 when the request is built without arguments."""
        assert MigrationRequest().limit == 0

    def test_request_with_limit(self):
        """An explicit ``limit`` is stored unchanged."""
        request = MigrationRequest(limit=100)

        assert request.limit == 100

    def test_response_defaults(self):
        """A default response reports "completed" with an empty stats dict."""
        response = MigrationResponse()

        assert response.status == "completed"
        assert response.stats == {}

    def test_status_defaults(self):
        """A default status has zero controls and 0.0 for each coverage percentage."""
        status = MigrationStatusResponse()

        assert status.total_controls == 0
        for pct_field in (
            "coverage_obligation_pct",
            "coverage_pattern_pct",
            "coverage_full_pct",
        ):
            assert getattr(status, pct_field) == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HELPER TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetPatternControlCounts:
    """Tests for the ``_get_pattern_control_counts`` helper.

    ``SessionLocal`` is patched in every test, so no real database is used;
    the mocked session's ``execute`` decides what the helper sees. Each test
    also checks that the session is closed exactly once.
    """

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_counts(self, mock_session_class):
        """``(pattern_id, count)`` rows are returned as a ``{pattern_id: count}`` dict."""
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("CP-AUTH-001", 15),
            ("CP-CRYP-001", 8),
        ]
        mock_db.execute.return_value = mock_result

        counts = _get_pattern_control_counts()

        assert counts == {"CP-AUTH-001": 15, "CP-CRYP-001": 8}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_on_error(self, mock_session_class):
        """A failing ``execute`` yields ``{}`` and the session is still closed."""
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_db.execute.side_effect = Exception("DB down")

        counts = _get_pattern_control_counts()

        assert counts == {}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_when_no_patterns(self, mock_session_class):
        """An empty result set yields ``{}`` and the session is closed."""
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_db.execute.return_value = mock_result

        counts = _get_pattern_control_counts()

        assert counts == {}
        # Same cleanup check as the two sibling tests; it was missing here,
        # so a leaked session on the empty-result path went undetected.
        mock_db.close.assert_called_once()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PATTERN LIBRARY ENDPOINT TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakePattern:
|
|
"""Minimal pattern stub for list_patterns / get_pattern tests."""
|
|
|
|
def __init__(self, id, name="test", name_de="Test", domain="AUTH",
|
|
category="auth", description="desc",
|
|
objective_template="obj", severity_default="medium",
|
|
implementation_effort_default="m", tags=None,
|
|
composable_with=None, open_anchor_refs=None,
|
|
rationale_template="", requirements_template=None,
|
|
test_procedure_template=None, evidence_template=None,
|
|
obligation_match_keywords=None):
|
|
self.id = id
|
|
self.name = name
|
|
self.name_de = name_de
|
|
self.domain = domain
|
|
self.category = category
|
|
self.description = description
|
|
self.objective_template = objective_template
|
|
self.severity_default = severity_default
|
|
self.implementation_effort_default = implementation_effort_default
|
|
self.tags = tags or []
|
|
self.composable_with = composable_with or []
|
|
self.open_anchor_refs = open_anchor_refs or []
|
|
self.rationale_template = rationale_template
|
|
self.requirements_template = requirements_template or []
|
|
self.test_procedure_template = test_procedure_template or []
|
|
self.evidence_template = evidence_template or []
|
|
self.obligation_match_keywords = obligation_match_keywords or []
|
|
|
|
|
|
class TestListPatternsEndpoint:
    """Tests for GET /patterns.

    ``_get_pattern_control_counts`` is patched in every test, and
    ``PatternMatcher`` is replaced by a mock whose ``_patterns`` attribute
    holds the fake patterns the endpoint should list.
    """

    @staticmethod
    def _list_patterns(fake_patterns, query=""):
        """GET the pattern list while ``PatternMatcher`` serves ``fake_patterns``.

        Args:
            fake_patterns: Objects exposed as ``_patterns`` on the mocked
                matcher instance.
            query: Optional query string (including the leading ``?``)
                appended to the endpoint URL.

        Returns:
            The ``TestClient`` response of the request.
        """
        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance

            # The request has to run while the patch is active.
            return _client.get("/api/compliance/v1/canonical/patterns" + query)

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_list_all_patterns(self, mock_counts):
        """Without filters all patterns are listed; missing counts show as 0."""
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH", tags=["auth"]),
            _FakePattern("CP-CRYP-001", domain="CRYP", tags=["encryption"]),
        ]
        # Only the first pattern has a count, so the second must report 0.
        mock_counts.return_value = {"CP-AUTH-001": 5}

        resp = self._list_patterns(fake_patterns)

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 2
        assert len(data["patterns"]) == 2
        assert data["patterns"][0]["id"] == "CP-AUTH-001"
        assert data["patterns"][0]["controls_count"] == 5
        assert data["patterns"][1]["controls_count"] == 0

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_domain(self, mock_counts):
        """``?domain=auth`` (lower case) selects the pattern whose domain is ``AUTH``."""
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH"),
            _FakePattern("CP-CRYP-001", domain="CRYP"),
        ]
        mock_counts.return_value = {}

        resp = self._list_patterns(fake_patterns, "?domain=auth")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_category(self, mock_counts):
        """``?category=encryption`` keeps only the pattern with that category."""
        fake_patterns = [
            _FakePattern("CP-AUTH-001", category="authentication"),
            _FakePattern("CP-CRYP-001", category="encryption"),
        ]
        mock_counts.return_value = {}

        resp = self._list_patterns(fake_patterns, "?category=encryption")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-CRYP-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_tag(self, mock_counts):
        """``?tag=password`` keeps only the pattern whose tags include it."""
        fake_patterns = [
            _FakePattern("CP-AUTH-001", tags=["auth", "password"]),
            _FakePattern("CP-CRYP-001", tags=["encryption"]),
        ]
        mock_counts.return_value = {}

        resp = self._list_patterns(fake_patterns, "?tag=password")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"
|
|
|
|
|
|
class TestGetPatternEndpoint:
    """Tests for GET /patterns/{pattern_id}."""

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_existing_pattern(self, mock_counts):
        """A pattern returned by the matcher is served with detail fields and its count."""
        mock_counts.return_value = {"CP-AUTH-001": 10}
        stub_pattern = _FakePattern(
            "CP-AUTH-001",
            name="password_policy",
            rationale_template="Weak passwords are risky.",
            requirements_template=["Min 12 chars"],
            obligation_match_keywords=["passwort"],
        )

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            matcher = MagicMock()
            matcher.get_pattern.return_value = stub_pattern
            MockPM.return_value = matcher

            response = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001")
            assert response.status_code == 200

            body = response.json()
            assert body["id"] == "CP-AUTH-001"
            assert body["rationale_template"] == "Weak passwords are risky."
            assert body["obligation_match_keywords"] == ["passwort"]
            assert body["controls_count"] == 10

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_nonexistent_pattern(self, mock_counts):
        """When the matcher returns ``None`` the endpoint answers 404 "not found"."""
        mock_counts.return_value = {}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            matcher = MagicMock()
            matcher.get_pattern.return_value = None
            MockPM.return_value = matcher

            response = _client.get("/api/compliance/v1/canonical/patterns/CP-FAKE-999")
            assert response.status_code == 404
            assert "not found" in response.json()["detail"].lower()
|
|
|
|
|
|
class TestGetPatternControlsEndpoint:
    """Tests for GET /patterns/{pattern_id}/controls."""

    @staticmethod
    def _stub_session(mock_session_class, rows, total):
        """Install a mocked DB session and return it.

        The first ``execute`` call returns a result whose ``fetchall`` gives
        ``rows``; the second returns one whose ``fetchone`` gives ``(total,)``.
        """
        page_result = MagicMock()
        page_result.fetchall.return_value = rows
        count_result = MagicMock()
        count_result.fetchone.return_value = (total,)

        session = MagicMock()
        session.execute.side_effect = [page_result, count_result]
        mock_session_class.return_value = session
        return session

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_controls(self, mock_session_class):
        """Rows are mapped to controls; a NULL obligation column becomes ``[]``."""
        rows = [
            (
                "uuid-1", "AUTH-001", "MFA", "Require MFA", "high", "draft",
                "authentication", '["DSGVO-OBL-001"]',
            ),
            (
                "uuid-2", "AUTH-002", "SSO", "Implement SSO", "medium", "draft",
                "authentication", None,
            ),
        ]
        session = self._stub_session(mock_session_class, rows, total=2)

        response = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")

        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 2
        controls = body["controls"]
        assert len(controls) == 2
        assert controls[0]["control_id"] == "AUTH-001"
        assert controls[0]["obligation_ids"] == ["DSGVO-OBL-001"]
        assert controls[1]["obligation_ids"] == []
        session.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_pagination(self, mock_session_class):
        """``limit`` and ``offset`` query parameters are accepted; an empty page is returned."""
        self._stub_session(mock_session_class, rows=[], total=0)

        response = _client.get(
            "/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls?limit=10&offset=20"
        )

        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 0
        assert body["controls"] == []

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_json_parse_error_in_obligation_ids(self, mock_session_class):
        """An obligation column that is not valid JSON is reported as ``[]``."""
        rows = [
            ("uuid-1", "AUTH-001", "MFA", "obj", "high", "draft", "auth", "not-valid-json"),
        ]
        self._stub_session(mock_session_class, rows, total=1)

        response = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")

        assert response.status_code == 200
        body = response.json()
        assert body["controls"][0]["obligation_ids"] == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OBLIGATION EXTRACTION ENDPOINT TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExtractObligationEndpoint:
    """Tests for POST /obligations/extract."""

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_with_pattern_match(self, MockExtractor, MockMatcher):
        """Obligation and keyword-tier pattern match both appear in the response."""
        # Extractor stub: awaiting ``extract`` yields this obligation.
        obligation = MagicMock()
        obligation.obligation_id = "DSGVO-OBL-001"
        obligation.obligation_title = "VVT"
        obligation.obligation_text = "Der Verantwortliche muss..."
        obligation.method = "exact_match"
        obligation.confidence = 1.0
        obligation.regulation_id = "dsgvo"
        extractor = AsyncMock()
        extractor.extract.return_value = obligation
        MockExtractor.return_value = extractor

        # Matcher stub: the keyword tier reports a hit.
        keyword_hit = MagicMock()
        keyword_hit.pattern_id = "CP-GOV-001"
        keyword_hit.confidence = 0.85
        matcher = MagicMock()
        matcher._tier1_keyword.return_value = keyword_hit
        MockMatcher.return_value = matcher

        payload = {"text": "Art. 30 DSGVO Verzeichnis", "regulation_code": "eu_2016_679"}
        response = _client.post(
            "/api/compliance/v1/canonical/obligations/extract",
            json=payload,
        )

        assert response.status_code == 200
        body = response.json()
        assert body["obligation_id"] == "DSGVO-OBL-001"
        assert body["method"] == "exact_match"
        assert body["pattern_id"] == "CP-GOV-001"
        assert body["pattern_confidence"] == 0.85

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_no_pattern_match(self, MockExtractor, MockMatcher):
        """With no keyword hit the response has no pattern and confidence 0.0."""
        obligation = MagicMock()
        obligation.obligation_id = None
        obligation.obligation_title = None
        obligation.obligation_text = "Some text"
        obligation.method = "llm_extracted"
        obligation.confidence = 0.6
        obligation.regulation_id = None
        extractor = AsyncMock()
        extractor.extract.return_value = obligation
        MockExtractor.return_value = extractor

        # Matcher stub: the keyword tier finds nothing.
        matcher = MagicMock()
        matcher._tier1_keyword.return_value = None
        MockMatcher.return_value = matcher

        response = _client.post(
            "/api/compliance/v1/canonical/obligations/extract",
            json={"text": "Some random regulation text"},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["obligation_id"] is None
        assert body["pattern_id"] is None
        assert body["pattern_confidence"] == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CROSSWALK ENDPOINT TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQueryCrosswalkEndpoint:
    """Tests for GET /crosswalk."""

    @staticmethod
    def _stub_session(mock_session_class, rows, total):
        """Install a mocked DB session and return it.

        The first ``execute`` call returns a result whose ``fetchall`` gives
        ``rows``; the second returns one whose ``fetchone`` gives ``(total,)``.
        """
        rows_result = MagicMock()
        rows_result.fetchall.return_value = rows
        count_result = MagicMock()
        count_result.fetchone.return_value = (total,)

        session = MagicMock()
        session.execute.side_effect = [rows_result, count_result]
        mock_session_class.return_value = session
        return session

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_no_filters(self, mock_session_class):
        """An unfiltered query returns the mapped row and closes the session."""
        rows = [
            ("eu_2016_679", "Art. 32", "DSGVO-OBL-002", "CP-CRYP-001", "CRYP-001", 0.95, "auto"),
        ]
        session = self._stub_session(mock_session_class, rows, total=1)

        response = _client.get("/api/compliance/v1/canonical/crosswalk")

        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 1
        first_row = body["rows"][0]
        assert first_row["regulation_code"] == "eu_2016_679"
        assert first_row["confidence"] == 0.95
        session.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_regulation_filter(self, mock_session_class):
        """A ``regulation_code`` filter adds the matching clause to the SQL."""
        session = self._stub_session(mock_session_class, rows=[], total=0)

        response = _client.get(
            "/api/compliance/v1/canonical/crosswalk?regulation_code=eu_2016_679"
        )

        assert response.status_code == 200
        assert response.json()["total"] == 0

        # The statement is the first positional argument of the first execute call.
        statement = session.execute.call_args_list[0][0][0]
        assert "regulation_code = :reg" in statement.text

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_all_filters(self, mock_session_class):
        """All four filters are forwarded as bound parameters."""
        session = self._stub_session(mock_session_class, rows=[], total=0)

        url = (
            "/api/compliance/v1/canonical/crosswalk"
            "?regulation_code=eu_2016_679"
            "&article=Art.%2032"
            "&obligation_id=DSGVO-OBL-002"
            "&pattern_id=CP-CRYP-001"
        )
        response = _client.get(url)

        assert response.status_code == 200

        # Bound parameters are the second positional argument of the first execute call.
        bound = session.execute.call_args_list[0][0][1]
        expected = {
            "reg": "eu_2016_679",
            "art": "Art. 32",
            "obl": "DSGVO-OBL-002",
            "pat": "CP-CRYP-001",
        }
        for key, value in expected.items():
            assert bound[key] == value

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_null_values_handled(self, mock_session_class):
        """NULL confidence and source columns come back as 0.0 and "auto"."""
        rows = [
            ("eu_2016_679", None, None, None, None, None, None),
        ]
        self._stub_session(mock_session_class, rows, total=1)

        response = _client.get("/api/compliance/v1/canonical/crosswalk")

        assert response.status_code == 200
        only_row = response.json()["rows"][0]
        assert only_row["confidence"] == 0.0
        assert only_row["source"] == "auto"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_pagination(self, mock_session_class):
        """``limit`` and ``offset`` are passed through as bound parameters."""
        session = self._stub_session(mock_session_class, rows=[], total=500)

        response = _client.get("/api/compliance/v1/canonical/crosswalk?limit=50&offset=100")

        assert response.status_code == 200
        assert response.json()["total"] == 500

        bound = session.execute.call_args_list[0][0][1]
        assert bound["limit"] == 50
        assert bound["offset"] == 100
|
|
|
|
|
|
class TestCrosswalkStatsEndpoint:
    """Tests for GET /crosswalk/stats."""

    @staticmethod
    def _stub_session(mock_session_class, totals, per_regulation):
        """Install a mocked DB session and return it.

        The first ``execute`` call returns a result whose ``fetchone`` gives
        ``totals``; the second returns one whose ``fetchall`` gives
        ``per_regulation``.
        """
        totals_result = MagicMock()
        totals_result.fetchone.return_value = totals
        regulation_result = MagicMock()
        regulation_result.fetchall.return_value = per_regulation

        session = MagicMock()
        session.execute.side_effect = [totals_result, regulation_result]
        mock_session_class.return_value = session
        return session

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_stats(self, mock_session_class):
        """The five totals and the per-regulation counts are reported in the response."""
        session = self._stub_session(
            mock_session_class,
            totals=(500, 9, 200, 45, 350),
            per_regulation=[
                ("eu_2016_679", 150),
                ("eu_2022_2554", 80),
            ],
        )

        response = _client.get("/api/compliance/v1/canonical/crosswalk/stats")

        assert response.status_code == 200
        body = response.json()
        assert body["total_rows"] == 500
        assert body["regulations_covered"] == 9
        assert body["obligations_linked"] == 200
        assert body["patterns_used"] == 45
        assert body["controls_linked"] == 350
        assert body["coverage_by_regulation"]["eu_2016_679"] == 150
        session.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_empty_stats(self, mock_session_class):
        """All-zero totals and no regulation rows give zero rows and an empty map."""
        self._stub_session(
            mock_session_class,
            totals=(0, 0, 0, 0, 0),
            per_regulation=[],
        )

        response = _client.get("/api/compliance/v1/canonical/crosswalk/stats")

        assert response.status_code == 200
        body = response.json()
        assert body["total_rows"] == 0
        assert body["coverage_by_regulation"] == {}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MIGRATION ENDPOINT TESTS
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMigratePass1Endpoint:
    """Tests for POST /migrate/link-obligations."""

    _URL = "/api/compliance/v1/canonical/migrate/link-obligations"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_success(self, MockMigration, mock_session_class):
        """Stats returned by pass 1 are relayed with status "completed"."""
        session = MagicMock()
        mock_session_class.return_value = session

        # AsyncMock: the pass-1 method is awaited by the code under test.
        passes = AsyncMock()
        passes.run_pass1_obligation_linkage.return_value = {
            "processed": 100, "linked": 60, "skipped": 40,
        }
        MockMigration.return_value = passes

        response = _client.post(self._URL, json={"limit": 100})

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "completed"
        assert body["stats"]["linked"] == 60
        session.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_failure(self, MockMigration, mock_session_class):
        """An exception in pass 1 gives HTTP 500 with its message; the session is closed."""
        session = MagicMock()
        mock_session_class.return_value = session

        passes = AsyncMock()
        passes.run_pass1_obligation_linkage.side_effect = RuntimeError("DB connection lost")
        MockMigration.return_value = passes

        response = _client.post(self._URL, json={})

        assert response.status_code == 500
        assert "DB connection lost" in response.json()["detail"]
        session.close.assert_called_once()
|
|
|
|
|
|
class TestMigratePass2Endpoint:
    """Tests for POST /migrate/classify-patterns."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass2_success(self, MockMigration, mock_session_class):
        """Stats returned by pass 2 are relayed in the response body."""
        session = MagicMock()
        mock_session_class.return_value = session

        # AsyncMock: the pass-2 method is awaited by the code under test.
        passes = AsyncMock()
        passes.run_pass2_pattern_classification.return_value = {
            "processed": 200, "classified": 140, "candidates": 30, "unmatched": 30,
        }
        MockMigration.return_value = passes

        response = _client.post(
            "/api/compliance/v1/canonical/migrate/classify-patterns",
            json={"limit": 200},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["stats"]["classified"] == 140
|
|
|
|
|
|
class TestMigratePass3Endpoint:
    """Tests for POST /migrate/triage."""

    _URL = "/api/compliance/v1/canonical/migrate/triage"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_success(self, MockMigration, mock_session_class):
        """Triage bucket counts returned by pass 3 are relayed in the response."""
        session = MagicMock()
        mock_session_class.return_value = session

        # Plain MagicMock: the pass-3 method is mocked as a synchronous call.
        passes = MagicMock()
        passes.run_pass3_quality_triage.return_value = {
            "review": 60, "needs_obligation": 20, "needs_pattern": 15, "legacy_unlinked": 5,
        }
        MockMigration.return_value = passes

        response = _client.post(self._URL)

        assert response.status_code == 200
        body = response.json()
        assert body["stats"]["review"] == 60

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_failure(self, MockMigration, mock_session_class):
        """An exception raised by pass 3 results in HTTP 500."""
        session = MagicMock()
        mock_session_class.return_value = session

        passes = MagicMock()
        passes.run_pass3_quality_triage.side_effect = Exception("triage error")
        MockMigration.return_value = passes

        response = _client.post(self._URL)

        assert response.status_code == 500
|
|
|
|
|
|
class TestMigratePass4Endpoint:
    """POST /migrate/backfill-crosswalk (migration Pass 4)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass4_success(self, MockMigration, mock_session_class):
        """The row count reported by the stubbed pass is returned under ``stats``."""
        mock_session_class.return_value = MagicMock()

        # Plain MagicMock: this pass is stubbed as a synchronous call.
        passes = MagicMock()
        passes.run_pass4_crosswalk_backfill.return_value = {"rows_created": 250}
        MockMigration.return_value = passes

        endpoint = "/api/compliance/v1/canonical/migrate/backfill-crosswalk"
        resp = _client.post(endpoint)

        assert resp.status_code == 200
        body = resp.json()
        assert body["stats"]["rows_created"] == 250
|
|
|
|
|
|
class TestMigratePass5Endpoint:
    """POST /migrate/deduplicate (migration Pass 5)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass5_success(self, MockMigration, mock_session_class):
        """Deduplication stats from the stubbed pass are returned under ``stats``."""
        mock_session_class.return_value = MagicMock()

        dedup_stats = {
            "groups_found": 80,
            "deprecated": 120,
            "kept": 80,
        }
        # Plain MagicMock: this pass is stubbed as a synchronous call.
        passes = MagicMock()
        passes.run_pass5_deduplication.return_value = dedup_stats
        MockMigration.return_value = passes

        endpoint = "/api/compliance/v1/canonical/migrate/deduplicate"
        resp = _client.post(endpoint)

        assert resp.status_code == 200
        body = resp.json()
        assert body["stats"]["deprecated"] == 120
|
|
|
|
|
|
class TestMigrationStatusEndpoint:
    """GET /migrate/status."""

    # Shared by both tests so the route path is spelled once.
    _STATUS_URL = "/api/compliance/v1/canonical/migrate/status"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_success(self, MockMigration, mock_session_class):
        """Status fields from the stubbed service are returned at the top level."""
        mock_session_class.return_value = MagicMock()

        status_snapshot = {
            "total_controls": 4800,
            "has_obligation": 2880,
            "has_pattern": 3360,
            "fully_linked": 2400,
            "deprecated": 1200,
            "coverage_obligation_pct": 60.0,
            "coverage_pattern_pct": 70.0,
            "coverage_full_pct": 50.0,
        }
        passes = MagicMock()
        passes.migration_status.return_value = status_snapshot
        MockMigration.return_value = passes

        resp = _client.get(self._STATUS_URL)

        assert resp.status_code == 200
        body = resp.json()
        assert body["total_controls"] == 4800
        assert body["coverage_full_pct"] == 50.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_failure(self, MockMigration, mock_session_class):
        """An exception raised while computing the status is reported as HTTP 500."""
        mock_session_class.return_value = MagicMock()

        passes = MagicMock()
        passes.migration_status.side_effect = Exception("DB error")
        MockMigration.return_value = passes

        resp = _client.get(self._STATUS_URL)

        assert resp.status_code == 500
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DECOMPOSITION ENDPOINT TESTS (Pass 0a / 0b)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDecompositionStatusModel:
    """Construction checks for the DecompositionStatusResponse model."""

    def test_defaults(self):
        """A model built without arguments has zero for the checked fields."""
        status = DecompositionStatusResponse()

        assert status.rich_controls == 0
        assert status.decomposition_pct == 0.0
        assert status.composition_pct == 0.0

    def test_full(self):
        """Explicitly supplied values are readable back from the model."""
        field_values = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }

        status = DecompositionStatusResponse(**field_values)

        assert status.rich_controls == 5000
        assert status.atomic_controls == 1800
|
|
|
|
|
|
class TestMigrateDecomposeEndpoint:
    """POST /migrate/decompose (Pass 0a)."""

    # Shared by both tests so the route path is spelled once.
    _DECOMPOSE_URL = "/api/compliance/v1/canonical/migrate/decompose"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_success(self, MockDecomp, mock_session_class):
        """A successful Pass 0a reports ``completed``, echoes stats and closes the session."""
        db_session = MagicMock()
        mock_session_class.return_value = db_session

        pass_stats = {
            "controls_processed": 50,
            "obligations_extracted": 180,
            "obligations_validated": 160,
            "obligations_rejected": 20,
            "controls_skipped_empty": 5,
            "errors": 0,
        }
        # AsyncMock so the stubbed pass method can be awaited.
        decomposer = AsyncMock()
        decomposer.run_pass0a.return_value = pass_stats
        MockDecomp.return_value = decomposer

        resp = _client.post(self._DECOMPOSE_URL, json={"limit": 50})

        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "completed"
        assert body["stats"]["obligations_extracted"] == 180
        db_session.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_failure(self, MockDecomp, mock_session_class):
        """A RuntimeError from Pass 0a becomes HTTP 500 carrying the error message."""
        mock_session_class.return_value = MagicMock()

        decomposer = AsyncMock()
        decomposer.run_pass0a.side_effect = RuntimeError("LLM timeout")
        MockDecomp.return_value = decomposer

        resp = _client.post(self._DECOMPOSE_URL, json={})

        assert resp.status_code == 500
        assert "LLM timeout" in resp.json()["detail"]
|
|
|
|
|
|
class TestMigrateComposeAtomicEndpoint:
    """POST /migrate/compose-atomic (Pass 0b)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0b_success(self, MockDecomp, mock_session_class):
        """Stats produced by the stubbed Pass 0b are returned under ``stats``."""
        mock_session_class.return_value = MagicMock()

        pass_stats = {
            "candidates_processed": 160,
            "controls_created": 155,
            "llm_failures": 5,
            "errors": 0,
        }
        # AsyncMock so the stubbed pass method can be awaited.
        decomposer = AsyncMock()
        decomposer.run_pass0b.return_value = pass_stats
        MockDecomp.return_value = decomposer

        endpoint = "/api/compliance/v1/canonical/migrate/compose-atomic"
        resp = _client.post(endpoint, json={"limit": 200})

        assert resp.status_code == 200
        body = resp.json()
        assert body["stats"]["controls_created"] == 155
|
|
|
|
|
|
class TestDecompositionStatusEndpoint:
    """GET /migrate/decomposition-status."""

    # Shared by both tests so the route path is spelled once.
    _STATUS_URL = "/api/compliance/v1/canonical/migrate/decomposition-status"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_success(self, MockDecomp, mock_session_class):
        """Status fields from the stubbed service are returned at the top level."""
        mock_session_class.return_value = MagicMock()

        status_snapshot = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }
        # Plain MagicMock: the status lookup is stubbed as a synchronous call.
        decomposer = MagicMock()
        decomposer.decomposition_status.return_value = status_snapshot
        MockDecomp.return_value = decomposer

        resp = _client.get(self._STATUS_URL)

        assert resp.status_code == 200
        body = resp.json()
        assert body["rich_controls"] == 5000
        assert body["atomic_controls"] == 1800
        assert body["decomposition_pct"] == 20.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_failure(self, MockDecomp, mock_session_class):
        """An exception raised while computing the status is reported as HTTP 500."""
        mock_session_class.return_value = MagicMock()

        decomposer = MagicMock()
        decomposer.decomposition_status.side_effect = Exception("DB error")
        MockDecomp.return_value = decomposer

        resp = _client.get(self._STATUS_URL)

        assert resp.status_code == 500
|