Files
breakpilot-compliance/backend-compliance/tests/test_crosswalk_routes.py
Benjamin Admin 825e070ed9
Some checks failed
CI/CD / go-lint (push) Has been skipped
CI/CD / python-lint (push) Has been skipped
CI/CD / nodejs-lint (push) Has been skipped
CI/CD / test-go-ai-compliance (push) Failing after 47s
CI/CD / test-python-backend-compliance (push) Successful in 33s
CI/CD / test-python-document-crawler (push) Successful in 24s
CI/CD / test-python-dsms-gateway (push) Successful in 18s
CI/CD / validate-canonical-controls (push) Successful in 11s
CI/CD / Deploy (push) Has been skipped
feat(multi-layer): complete Multi-Layer Control Architecture (Phases 1-8 + Pass 0)
Implements the full Multi-Layer Control Architecture for migrating ~25,000
Rich Controls into atomic, deduplicated Master Controls with full traceability.

Architecture: Legal Source → Obligation → Control Pattern → Master Control → Customer Instance

New services:
- ObligationExtractor: 3-tier extraction (exact → embedding → LLM)
- PatternMatcher: 2-tier matching (keyword tier, embedding tier) plus a domain bonus
- ControlComposer: Pattern + Obligation → Master Control
- PipelineAdapter: Pipeline integration + Migration Passes 1-5
- DecompositionPass: Pass 0a/0b — Rich Control → atomic Controls
- CrosswalkRoutes: 15 API endpoints under /v1/canonical/

New DB schema:
- Migration 060: obligation_extractions, control_patterns, crosswalk_matrix
- Migration 061: obligation_candidates, parent_control_uuid tracking

Pattern Library: 50 YAML patterns (30 core + 20 IT-security)
Go SDK: Pattern loader with YAML validation and indexing
Documentation: MkDocs updated with full architecture overview

500 Python tests passing across all components.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 09:00:37 +01:00

1132 lines
41 KiB
Python

"""Tests for Multi-Layer Control Architecture routes (crosswalk_routes.py).
Covers:
- Pydantic model validation (Pattern, Obligation, Crosswalk, Migration)
- Pattern Library endpoints (list, get, get controls)
- Obligation extraction endpoint
- Crosswalk query + stats endpoints
- Migration pass endpoints (1-5) + status
- Helper: _get_pattern_control_counts
"""
import json
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from typing import Optional
from fastapi import FastAPI
from fastapi.testclient import TestClient
from compliance.api.crosswalk_routes import (
PatternResponse,
PatternListResponse,
PatternDetailResponse,
ObligationExtractRequest,
ObligationExtractResponse,
CrosswalkRow,
CrosswalkQueryResponse,
CrosswalkStatsResponse,
MigrationRequest,
MigrationResponse,
MigrationStatusResponse,
DecompositionStatusResponse,
router,
_get_pattern_control_counts,
)
# ---------------------------------------------------------------------------
# TestClient setup
# ---------------------------------------------------------------------------
# A bare FastAPI application that mounts only the crosswalk router, under the
# same "/api/compliance" prefix that every endpoint URL below starts with.
_app = FastAPI()
_app.include_router(router=router, prefix="/api/compliance")
# In-process HTTP client shared by all endpoint tests in this module.
_client = TestClient(app=_app)
# ---------------------------------------------------------------------------
# MODEL TESTS
# ---------------------------------------------------------------------------
class TestPatternResponse:
    """Construction and default values of the PatternResponse model."""

    def test_basic_creation(self):
        pattern = PatternResponse(
            id="CP-AUTH-001",
            name="password_policy",
            name_de="Passwortrichtlinie",
            domain="AUTH",
            category="authentication",
            description="Password policy requirements",
            objective_template="Ensure passwords meet complexity standards.",
            severity_default="high",
        )
        observed = (pattern.id, pattern.domain, pattern.severity_default)
        assert observed == ("CP-AUTH-001", "AUTH", "high")

    def test_default_values(self):
        pattern = PatternResponse(
            id="CP-AUTH-001", name="test", name_de="Test", domain="AUTH",
            category="auth", description="desc", objective_template="obj",
            severity_default="medium",
        )
        # Every optional field falls back to its model default.
        assert pattern.implementation_effort_default == "m"
        assert pattern.controls_count == 0
        for list_value in (pattern.tags, pattern.composable_with, pattern.open_anchor_refs):
            assert list_value == []

    def test_full_model(self):
        all_fields = {
            "id": "CP-CRYP-001",
            "name": "encryption_at_rest",
            "name_de": "Verschluesselung ruhender Daten",
            "domain": "CRYP",
            "category": "encryption",
            "description": "Encrypt data at rest",
            "objective_template": "Ensure all stored data is encrypted.",
            "severity_default": "critical",
            "implementation_effort_default": "l",
            "tags": ["encryption", "storage"],
            "composable_with": ["CP-CRYP-002"],
            "open_anchor_refs": [{"framework": "NIST", "ref": "SC-28"}],
            "controls_count": 42,
        }
        pattern = PatternResponse(**all_fields)
        assert pattern.controls_count == 42
        assert len(pattern.tags) == 2
        assert pattern.implementation_effort_default == "l"
class TestPatternDetailResponse:
    """PatternDetailResponse: the template/keyword fields added on top of PatternResponse."""

    def test_has_extended_fields(self):
        detail = PatternDetailResponse(
            id="CP-AUTH-001",
            name="mfa",
            name_de="MFA",
            domain="AUTH",
            category="authentication",
            description="Multi-factor authentication",
            objective_template="Require MFA.",
            severity_default="high",
            rationale_template="Passwords alone are insufficient.",
            requirements_template=["Require TOTP or hardware key"],
            test_procedure_template=["Test login without MFA"],
            evidence_template=["MFA config screenshot"],
            obligation_match_keywords=["authentifizierung", "mfa"],
        )
        assert detail.rationale_template == "Passwords alone are insufficient."
        assert len(detail.requirements_template) == 1
        assert len(detail.obligation_match_keywords) == 2

    def test_defaults(self):
        required_only = dict(
            id="CP-AUTH-001", name="test", name_de="Test", domain="AUTH",
            category="auth", description="desc", objective_template="obj",
            severity_default="medium",
        )
        detail = PatternDetailResponse(**required_only)
        assert detail.rationale_template == ""
        # All list-valued extension fields start out empty.
        extension_lists = (
            detail.requirements_template,
            detail.test_procedure_template,
            detail.evidence_template,
            detail.obligation_match_keywords,
        )
        assert all(value == [] for value in extension_lists)
class TestObligationModels:
    """Construction and defaults of ObligationExtractRequest / ObligationExtractResponse."""

    def test_request_minimal(self):
        source_text = "Ein Verantwortlicher muss..."
        request = ObligationExtractRequest(text=source_text)
        assert request.text == source_text
        # Only `text` is required; the location hints stay unset.
        for optional in (request.regulation_code, request.article, request.paragraph):
            assert optional is None

    def test_request_full(self):
        request = ObligationExtractRequest(
            text="Art. 32 DSGVO",
            regulation_code="eu_2016_679",
            article="Art. 32",
            paragraph="Abs. 1",
        )
        assert (request.regulation_code, request.paragraph) == ("eu_2016_679", "Abs. 1")

    def test_response_defaults(self):
        response = ObligationExtractResponse()
        assert response.obligation_id is None
        assert response.pattern_id is None
        assert response.method == "none"
        assert response.confidence == 0.0
        assert response.pattern_confidence == 0.0

    def test_response_full(self):
        payload = {
            "obligation_id": "DSGVO-OBL-001",
            "obligation_title": "Verzeichnis der Verarbeitungstaetigkeiten",
            "obligation_text": "Der Verantwortliche muss...",
            "method": "exact_match",
            "confidence": 1.0,
            "regulation_id": "dsgvo",
            "pattern_id": "CP-GOV-001",
            "pattern_confidence": 0.85,
        }
        response = ObligationExtractResponse(**payload)
        assert response.obligation_id == "DSGVO-OBL-001"
        assert response.method == "exact_match"
        assert response.confidence == 1.0
        assert response.pattern_confidence == 0.85
class TestCrosswalkModels:
    """Construction and defaults of CrosswalkRow and CrosswalkQueryResponse."""

    def test_row_defaults(self):
        empty_row = CrosswalkRow()
        assert empty_row.regulation_code == ""
        assert empty_row.article is None
        assert empty_row.obligation_id is None
        assert empty_row.confidence == 0.0
        # Rows are machine-generated unless stated otherwise.
        assert empty_row.source == "auto"

    def test_row_full(self):
        row = CrosswalkRow(
            regulation_code="eu_2016_679",
            article="Art. 32",
            obligation_id="DSGVO-OBL-002",
            pattern_id="CP-CRYP-001",
            master_control_id="CRYP-001",
            confidence=0.95,
            source="manual",
        )
        assert (row.regulation_code, row.confidence) == ("eu_2016_679", 0.95)

    def test_query_response(self):
        regulation_codes = ["eu_2016_679", "eu_2022_2554"]
        page = CrosswalkQueryResponse(
            rows=[CrosswalkRow(regulation_code=code) for code in regulation_codes],
            total=100,
        )
        # `total` is independent of the number of rows in the current page.
        assert len(page.rows) == 2
        assert page.total == 100
class TestCrosswalkStatsResponse:
    """Construction and defaults of CrosswalkStatsResponse."""

    # Integer counters that all default to zero.
    _COUNTERS = (
        "total_rows",
        "regulations_covered",
        "obligations_linked",
        "patterns_used",
        "controls_linked",
    )

    def test_defaults(self):
        stats = CrosswalkStatsResponse()
        for counter in self._COUNTERS:
            assert getattr(stats, counter) == 0
        assert stats.coverage_by_regulation == {}

    def test_full(self):
        per_regulation = {"eu_2016_679": 150, "eu_2022_2554": 80}
        stats = CrosswalkStatsResponse(
            total_rows=500,
            regulations_covered=9,
            obligations_linked=200,
            patterns_used=45,
            controls_linked=350,
            coverage_by_regulation=per_regulation,
        )
        assert stats.total_rows == 500
        assert stats.coverage_by_regulation["eu_2016_679"] == 150
class TestMigrationModels:
    """Defaults and construction of the MigrationRequest/Response/Status models."""

    def test_request_default(self):
        assert MigrationRequest().limit == 0

    def test_request_with_limit(self):
        assert MigrationRequest(limit=100).limit == 100

    def test_response_defaults(self):
        outcome = MigrationResponse()
        assert outcome.status == "completed"
        assert outcome.stats == {}

    def test_status_defaults(self):
        status = MigrationStatusResponse()
        assert status.total_controls == 0
        for pct_field in ("coverage_obligation_pct", "coverage_pattern_pct", "coverage_full_pct"):
            assert getattr(status, pct_field) == 0.0
# ---------------------------------------------------------------------------
# HELPER TESTS
# ---------------------------------------------------------------------------
class TestGetPatternControlCounts:
    """Tests for the _get_pattern_control_counts helper.

    ``SessionLocal`` is patched, so no database is touched. Each test wires
    the mocked session's ``execute`` either to return canned ``fetchall``
    rows or to raise, then checks the returned mapping and that the session
    was closed.
    """

    @staticmethod
    def _wire_session(mock_session_class, rows=None, error=None):
        """Return a mocked session whose ``execute`` yields *rows* or raises *error*."""
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        if error is not None:
            mock_db.execute.side_effect = error
        else:
            mock_result = MagicMock()
            mock_result.fetchall.return_value = rows if rows is not None else []
            mock_db.execute.return_value = mock_result
        return mock_db

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_counts(self, mock_session_class):
        mock_db = self._wire_session(
            mock_session_class,
            rows=[("CP-AUTH-001", 15), ("CP-CRYP-001", 8)],
        )
        counts = _get_pattern_control_counts()
        assert counts == {"CP-AUTH-001": 15, "CP-CRYP-001": 8}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_on_error(self, mock_session_class):
        mock_db = self._wire_session(mock_session_class, error=Exception("DB down"))
        counts = _get_pattern_control_counts()
        # A failing query degrades to "no counts" instead of propagating.
        assert counts == {}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_when_no_patterns(self, mock_session_class):
        mock_db = self._wire_session(mock_session_class, rows=[])
        counts = _get_pattern_control_counts()
        assert counts == {}
        # The session must be released on the empty-result path as well.
        mock_db.close.assert_called_once()
# ---------------------------------------------------------------------------
# PATTERN LIBRARY ENDPOINT TESTS
# ---------------------------------------------------------------------------
class _FakePattern:
"""Minimal pattern stub for list_patterns / get_pattern tests."""
def __init__(self, id, name="test", name_de="Test", domain="AUTH",
category="auth", description="desc",
objective_template="obj", severity_default="medium",
implementation_effort_default="m", tags=None,
composable_with=None, open_anchor_refs=None,
rationale_template="", requirements_template=None,
test_procedure_template=None, evidence_template=None,
obligation_match_keywords=None):
self.id = id
self.name = name
self.name_de = name_de
self.domain = domain
self.category = category
self.description = description
self.objective_template = objective_template
self.severity_default = severity_default
self.implementation_effort_default = implementation_effort_default
self.tags = tags or []
self.composable_with = composable_with or []
self.open_anchor_refs = open_anchor_refs or []
self.rationale_template = rationale_template
self.requirements_template = requirements_template or []
self.test_procedure_template = test_procedure_template or []
self.evidence_template = evidence_template or []
self.obligation_match_keywords = obligation_match_keywords or []
class TestListPatternsEndpoint:
    """Tests for GET /patterns.

    PatternMatcher is patched at ``compliance.services.pattern_matcher``. This
    assumes the endpoint resolves the class from there at request time, which
    all four tests rely on. The mocked matcher exposes the fake patterns via
    ``_patterns``. Per-pattern control counts come from a patched
    ``_get_pattern_control_counts``.
    """

    @staticmethod
    def _get_with_patterns(fake_patterns, url):
        """Issue GET *url* while the patched PatternMatcher yields *fake_patterns*."""
        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance
            return _client.get(url)

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_list_all_patterns(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH", tags=["auth"]),
            _FakePattern("CP-CRYP-001", domain="CRYP", tags=["encryption"]),
        ]
        mock_counts.return_value = {"CP-AUTH-001": 5}
        resp = self._get_with_patterns(
            fake_patterns, "/api/compliance/v1/canonical/patterns"
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 2
        assert len(data["patterns"]) == 2
        assert data["patterns"][0]["id"] == "CP-AUTH-001"
        assert data["patterns"][0]["controls_count"] == 5
        # Patterns missing from the counts mapping report zero controls.
        assert data["patterns"][1]["controls_count"] == 0

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_domain(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH"),
            _FakePattern("CP-CRYP-001", domain="CRYP"),
        ]
        mock_counts.return_value = {}
        # A lower-case query value is expected to match the upper-case domain.
        resp = self._get_with_patterns(
            fake_patterns, "/api/compliance/v1/canonical/patterns?domain=auth"
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_category(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", category="authentication"),
            _FakePattern("CP-CRYP-001", category="encryption"),
        ]
        mock_counts.return_value = {}
        resp = self._get_with_patterns(
            fake_patterns, "/api/compliance/v1/canonical/patterns?category=encryption"
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-CRYP-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_tag(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", tags=["auth", "password"]),
            _FakePattern("CP-CRYP-001", tags=["encryption"]),
        ]
        mock_counts.return_value = {}
        resp = self._get_with_patterns(
            fake_patterns, "/api/compliance/v1/canonical/patterns?tag=password"
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"
class TestGetPatternEndpoint:
    """Tests for GET /patterns/{pattern_id}."""

    @staticmethod
    def _fetch(url, found_pattern):
        """GET *url* with PatternMatcher.get_pattern patched to return *found_pattern*."""
        with patch("compliance.services.pattern_matcher.PatternMatcher") as matcher_cls:
            matcher = MagicMock()
            matcher.get_pattern.return_value = found_pattern
            matcher_cls.return_value = matcher
            return _client.get(url)

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_existing_pattern(self, mock_counts):
        stored = _FakePattern(
            "CP-AUTH-001",
            name="password_policy",
            rationale_template="Weak passwords are risky.",
            requirements_template=["Min 12 chars"],
            obligation_match_keywords=["passwort"],
        )
        mock_counts.return_value = {"CP-AUTH-001": 10}
        response = self._fetch("/api/compliance/v1/canonical/patterns/CP-AUTH-001", stored)
        assert response.status_code == 200
        body = response.json()
        assert body["id"] == "CP-AUTH-001"
        # Detail-only fields are included in the payload.
        assert body["rationale_template"] == "Weak passwords are risky."
        assert body["obligation_match_keywords"] == ["passwort"]
        assert body["controls_count"] == 10

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_nonexistent_pattern(self, mock_counts):
        mock_counts.return_value = {}
        response = self._fetch("/api/compliance/v1/canonical/patterns/CP-FAKE-999", None)
        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()
class TestGetPatternControlsEndpoint:
    """Tests for GET /patterns/{pattern_id}/controls.

    The endpoint is expected to execute two statements on the session: first
    the page query (rows via ``fetchall``), then the count query (``fetchone``).
    Every test also checks that the session is closed afterwards.
    """

    @staticmethod
    def _wire_session(mock_session_class, rows, total):
        """Return a mocked session answering the page query with *rows* and the count with *total*."""
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = rows
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (total,)
        mock_db.execute.side_effect = [mock_result, mock_count]
        return mock_db

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_controls(self, mock_session_class):
        mock_db = self._wire_session(
            mock_session_class,
            rows=[
                ("uuid-1", "AUTH-001", "MFA", "Require MFA", "high", "draft", "authentication", '["DSGVO-OBL-001"]'),
                ("uuid-2", "AUTH-002", "SSO", "Implement SSO", "medium", "draft", "authentication", None),
            ],
            total=2,
        )
        resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 2
        assert len(data["controls"]) == 2
        assert data["controls"][0]["control_id"] == "AUTH-001"
        # The JSON-encoded obligation id column is decoded; NULL becomes [].
        assert data["controls"][0]["obligation_ids"] == ["DSGVO-OBL-001"]
        assert data["controls"][1]["obligation_ids"] == []
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_pagination(self, mock_session_class):
        mock_db = self._wire_session(mock_session_class, rows=[], total=0)
        resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls?limit=10&offset=20")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 0
        assert data["controls"] == []
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_json_parse_error_in_obligation_ids(self, mock_session_class):
        mock_db = self._wire_session(
            mock_session_class,
            rows=[("uuid-1", "AUTH-001", "MFA", "obj", "high", "draft", "auth", "not-valid-json")],
            total=1,
        )
        resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")
        assert resp.status_code == 200
        data = resp.json()
        # Malformed JSON in the obligation column degrades to an empty list.
        assert data["controls"][0]["obligation_ids"] == []
        mock_db.close.assert_called_once()
# ---------------------------------------------------------------------------
# OBLIGATION EXTRACTION ENDPOINT TESTS
# ---------------------------------------------------------------------------
class TestExtractObligationEndpoint:
    """Tests for POST /obligations/extract.

    ObligationExtractor is replaced by an AsyncMock, because its ``extract``
    result is awaited. PatternMatcher is replaced by a MagicMock whose
    ``_tier1_keyword`` return value supplies the pattern part of the response.
    """

    _URL = "/api/compliance/v1/canonical/obligations/extract"

    @staticmethod
    def _wire(extractor_cls, matcher_cls, obligation_fields, keyword_hit):
        """Make the extractor return an obligation built from *obligation_fields*
        and the keyword tier return *keyword_hit*."""
        extractor = AsyncMock()
        extractor_cls.return_value = extractor
        obligation = MagicMock()
        for field_name, field_value in obligation_fields.items():
            setattr(obligation, field_name, field_value)
        extractor.extract.return_value = obligation

        matcher = MagicMock()
        matcher_cls.return_value = matcher
        matcher._tier1_keyword.return_value = keyword_hit

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_with_pattern_match(self, MockExtractor, MockMatcher):
        keyword_hit = MagicMock()
        keyword_hit.pattern_id = "CP-GOV-001"
        keyword_hit.confidence = 0.85
        self._wire(
            MockExtractor,
            MockMatcher,
            {
                "obligation_id": "DSGVO-OBL-001",
                "obligation_title": "VVT",
                "obligation_text": "Der Verantwortliche muss...",
                "method": "exact_match",
                "confidence": 1.0,
                "regulation_id": "dsgvo",
            },
            keyword_hit,
        )
        response = _client.post(
            self._URL,
            json={"text": "Art. 30 DSGVO Verzeichnis", "regulation_code": "eu_2016_679"},
        )
        assert response.status_code == 200
        body = response.json()
        assert body["obligation_id"] == "DSGVO-OBL-001"
        assert body["method"] == "exact_match"
        assert body["pattern_id"] == "CP-GOV-001"
        assert body["pattern_confidence"] == 0.85

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_no_pattern_match(self, MockExtractor, MockMatcher):
        self._wire(
            MockExtractor,
            MockMatcher,
            {
                "obligation_id": None,
                "obligation_title": None,
                "obligation_text": "Some text",
                "method": "llm_extracted",
                "confidence": 0.6,
                "regulation_id": None,
            },
            None,
        )
        response = _client.post(
            self._URL,
            json={"text": "Some random regulation text"},
        )
        assert response.status_code == 200
        body = response.json()
        assert body["obligation_id"] is None
        # With no keyword hit the pattern fields keep their defaults.
        assert body["pattern_id"] is None
        assert body["pattern_confidence"] == 0.0
# ---------------------------------------------------------------------------
# CROSSWALK ENDPOINT TESTS
# ---------------------------------------------------------------------------
class TestQueryCrosswalkEndpoint:
    """Tests for GET /crosswalk.

    Each request is expected to execute the page query first (``fetchall``)
    and the count query second (``fetchone``). The filter and pagination
    tests inspect the SQL text and bind parameters of the first call.
    """

    @staticmethod
    def _stub_session(session_cls, rows, total):
        """Return a mocked session answering the page query with *rows* and the count with *total*."""
        db = MagicMock()
        session_cls.return_value = db
        page = MagicMock()
        page.fetchall.return_value = rows
        counter = MagicMock()
        counter.fetchone.return_value = (total,)
        db.execute.side_effect = [page, counter]
        return db

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_no_filters(self, mock_session_class):
        db = self._stub_session(
            mock_session_class,
            [("eu_2016_679", "Art. 32", "DSGVO-OBL-002", "CP-CRYP-001", "CRYP-001", 0.95, "auto")],
            1,
        )
        response = _client.get("/api/compliance/v1/canonical/crosswalk")
        assert response.status_code == 200
        body = response.json()
        assert body["total"] == 1
        first_row = body["rows"][0]
        assert first_row["regulation_code"] == "eu_2016_679"
        assert first_row["confidence"] == 0.95
        db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_regulation_filter(self, mock_session_class):
        db = self._stub_session(mock_session_class, [], 0)
        response = _client.get("/api/compliance/v1/canonical/crosswalk?regulation_code=eu_2016_679")
        assert response.status_code == 200
        assert response.json()["total"] == 0
        # The page query must carry a bound regulation filter.
        page_statement = db.execute.call_args_list[0][0][0]
        assert "regulation_code = :reg" in page_statement.text

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_all_filters(self, mock_session_class):
        db = self._stub_session(mock_session_class, [], 0)
        response = _client.get(
            "/api/compliance/v1/canonical/crosswalk"
            "?regulation_code=eu_2016_679"
            "&article=Art.%2032"
            "&obligation_id=DSGVO-OBL-002"
            "&pattern_id=CP-CRYP-001"
        )
        assert response.status_code == 200
        bound = db.execute.call_args_list[0][0][1]
        expected_params = {
            "reg": "eu_2016_679",
            "art": "Art. 32",
            "obl": "DSGVO-OBL-002",
            "pat": "CP-CRYP-001",
        }
        for param_name, param_value in expected_params.items():
            assert bound[param_name] == param_value

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_null_values_handled(self, mock_session_class):
        self._stub_session(
            mock_session_class,
            [("eu_2016_679", None, None, None, None, None, None)],
            1,
        )
        response = _client.get("/api/compliance/v1/canonical/crosswalk")
        assert response.status_code == 200
        only_row = response.json()["rows"][0]
        # NULL confidence/source fall back to the model defaults.
        assert only_row["confidence"] == 0.0
        assert only_row["source"] == "auto"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_pagination(self, mock_session_class):
        db = self._stub_session(mock_session_class, [], 500)
        response = _client.get("/api/compliance/v1/canonical/crosswalk?limit=50&offset=100")
        assert response.status_code == 200
        assert response.json()["total"] == 500
        bound = db.execute.call_args_list[0][0][1]
        assert bound["limit"] == 50
        assert bound["offset"] == 100
class TestCrosswalkStatsEndpoint:
    """Tests for GET /crosswalk/stats.

    The endpoint is expected to run a summary query (one ``fetchone`` row
    holding five counters), then a per-regulation breakdown (``fetchall``).
    """

    @staticmethod
    def _stub_session(session_cls, totals, per_regulation):
        """Return a mocked session answering the summary with *totals* and the breakdown with *per_regulation*."""
        db = MagicMock()
        session_cls.return_value = db
        summary = MagicMock()
        summary.fetchone.return_value = totals
        breakdown = MagicMock()
        breakdown.fetchall.return_value = per_regulation
        db.execute.side_effect = [summary, breakdown]
        return db

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_stats(self, mock_session_class):
        db = self._stub_session(
            mock_session_class,
            (500, 9, 200, 45, 350),
            [("eu_2016_679", 150), ("eu_2022_2554", 80)],
        )
        response = _client.get("/api/compliance/v1/canonical/crosswalk/stats")
        assert response.status_code == 200
        body = response.json()
        expected_counters = {
            "total_rows": 500,
            "regulations_covered": 9,
            "obligations_linked": 200,
            "patterns_used": 45,
            "controls_linked": 350,
        }
        for counter, value in expected_counters.items():
            assert body[counter] == value
        assert body["coverage_by_regulation"]["eu_2016_679"] == 150
        db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_empty_stats(self, mock_session_class):
        self._stub_session(mock_session_class, (0, 0, 0, 0, 0), [])
        response = _client.get("/api/compliance/v1/canonical/crosswalk/stats")
        assert response.status_code == 200
        body = response.json()
        assert body["total_rows"] == 0
        assert body["coverage_by_regulation"] == {}
# ---------------------------------------------------------------------------
# MIGRATION ENDPOINT TESTS
# ---------------------------------------------------------------------------
class TestMigratePass1Endpoint:
    """Tests for POST /migrate/link-obligations (Pass 1).

    ``run_pass1_obligation_linkage`` is awaited by the endpoint, so the
    MigrationPasses instance is an AsyncMock.
    """

    _URL = "/api/compliance/v1/canonical/migrate/link-obligations"

    @staticmethod
    def _wire(migration_cls, session_cls):
        """Install mocked session and migration objects; return both."""
        db = MagicMock()
        session_cls.return_value = db
        migration = AsyncMock()
        migration_cls.return_value = migration
        return db, migration

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_success(self, MockMigration, mock_session_class):
        db, migration = self._wire(MockMigration, mock_session_class)
        migration.run_pass1_obligation_linkage.return_value = {
            "processed": 100, "linked": 60, "skipped": 40,
        }
        response = _client.post(self._URL, json={"limit": 100})
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "completed"
        assert body["stats"]["linked"] == 60
        db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_failure(self, MockMigration, mock_session_class):
        db, migration = self._wire(MockMigration, mock_session_class)
        migration.run_pass1_obligation_linkage.side_effect = RuntimeError("DB connection lost")
        response = _client.post(self._URL, json={})
        # The error message is surfaced in the 500 detail, and the session is still released.
        assert response.status_code == 500
        assert "DB connection lost" in response.json()["detail"]
        db.close.assert_called_once()
class TestMigratePass2Endpoint:
    """Tests for POST /migrate/classify-patterns (Pass 2).

    ``run_pass2_pattern_classification`` is awaited by the endpoint, so the
    MigrationPasses instance is an AsyncMock.
    """

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass2_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = AsyncMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass2_pattern_classification.return_value = {
            "processed": 200, "classified": 140, "candidates": 30, "unmatched": 30,
        }
        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/classify-patterns",
            json={"limit": 200},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["classified"] == 140
        # Same session-lifecycle guarantee as Pass 1: the session is closed after the run.
        mock_db.close.assert_called_once()
class TestMigratePass3Endpoint:
    """Tests for POST /migrate/triage (Pass 3).

    The service is a plain MagicMock here, because
    ``run_pass3_quality_triage`` is called without ``await``.
    """

    _URL = "/api/compliance/v1/canonical/migrate/triage"

    @staticmethod
    def _wire(migration_cls, session_cls):
        """Install mocked session and migration objects; return the migration mock."""
        session_cls.return_value = MagicMock()
        migration = MagicMock()
        migration_cls.return_value = migration
        return migration

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_success(self, MockMigration, mock_session_class):
        migration = self._wire(MockMigration, mock_session_class)
        migration.run_pass3_quality_triage.return_value = {
            "review": 60, "needs_obligation": 20, "needs_pattern": 15, "legacy_unlinked": 5,
        }
        response = _client.post(self._URL)
        assert response.status_code == 200
        assert response.json()["stats"]["review"] == 60

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_failure(self, MockMigration, mock_session_class):
        migration = self._wire(MockMigration, mock_session_class)
        migration.run_pass3_quality_triage.side_effect = Exception("triage error")
        assert _client.post(self._URL).status_code == 500
class TestMigratePass4Endpoint:
    """Tests for POST /migrate/backfill-crosswalk (Pass 4, called synchronously)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass4_success(self, MockMigration, mock_session_class):
        mock_session_class.return_value = MagicMock()
        migration = MagicMock()
        migration.run_pass4_crosswalk_backfill.return_value = {"rows_created": 250}
        MockMigration.return_value = migration

        response = _client.post("/api/compliance/v1/canonical/migrate/backfill-crosswalk")

        assert response.status_code == 200
        assert response.json()["stats"]["rows_created"] == 250
class TestMigratePass5Endpoint:
    """Tests for POST /migrate/deduplicate (Pass 5, called synchronously)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass5_success(self, MockMigration, mock_session_class):
        mock_session_class.return_value = MagicMock()
        migration = MagicMock()
        migration.run_pass5_deduplication.return_value = {
            "groups_found": 80, "deprecated": 120, "kept": 80,
        }
        MockMigration.return_value = migration

        response = _client.post("/api/compliance/v1/canonical/migrate/deduplicate")

        assert response.status_code == 200
        assert response.json()["stats"]["deprecated"] == 120
class TestMigrationStatusEndpoint:
    """Tests for GET /migrate/status (``migration_status`` is called synchronously)."""

    _URL = "/api/compliance/v1/canonical/migrate/status"

    @staticmethod
    def _wire(migration_cls, session_cls):
        """Install mocked session and migration objects; return the migration mock."""
        session_cls.return_value = MagicMock()
        migration = MagicMock()
        migration_cls.return_value = migration
        return migration

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_success(self, MockMigration, mock_session_class):
        migration = self._wire(MockMigration, mock_session_class)
        migration.migration_status.return_value = {
            "total_controls": 4800,
            "has_obligation": 2880,
            "has_pattern": 3360,
            "fully_linked": 2400,
            "deprecated": 1200,
            "coverage_obligation_pct": 60.0,
            "coverage_pattern_pct": 70.0,
            "coverage_full_pct": 50.0,
        }
        response = _client.get(self._URL)
        assert response.status_code == 200
        body = response.json()
        assert body["total_controls"] == 4800
        assert body["coverage_full_pct"] == 50.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_failure(self, MockMigration, mock_session_class):
        migration = self._wire(MockMigration, mock_session_class)
        migration.migration_status.side_effect = Exception("DB error")
        assert _client.get(self._URL).status_code == 500
# ---------------------------------------------------------------------------
# DECOMPOSITION ENDPOINT TESTS (Pass 0a / 0b)
# ---------------------------------------------------------------------------
class TestDecompositionStatusModel:
    """Field defaults and explicit population of DecompositionStatusResponse."""

    def test_defaults(self):
        empty = DecompositionStatusResponse()
        # Counters default to integer zero, percentages to float zero.
        assert (empty.rich_controls, empty.decomposition_pct, empty.composition_pct) == (0, 0.0, 0.0)

    def test_full(self):
        field_values = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }
        populated = DecompositionStatusResponse(**field_values)
        assert populated.rich_controls == 5000
        assert populated.atomic_controls == 1800
class TestMigrateDecomposeEndpoint:
    """POST /migrate/decompose (Pass 0a) against a stubbed DecompositionPass."""

    @staticmethod
    def _stub_decomposer(decomp_cls, session_cls):
        """Install a MagicMock DB session and an AsyncMock DecompositionPass.

        ``AsyncMock`` makes ``run_pass0a`` awaitable. Returns
        ``(db, decomposer)`` so tests can script the pass and inspect the
        session.
        """
        db = MagicMock()
        session_cls.return_value = db
        decomposer = AsyncMock()
        decomp_cls.return_value = decomposer
        return db, decomposer

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_success(self, MockDecomp, mock_session_class):
        db, decomposer = self._stub_decomposer(MockDecomp, mock_session_class)
        pass0a_stats = {
            "controls_processed": 50,
            "obligations_extracted": 180,
            "obligations_validated": 160,
            "obligations_rejected": 20,
            "controls_skipped_empty": 5,
            "errors": 0,
        }
        decomposer.run_pass0a.return_value = pass0a_stats

        response = _client.post(
            "/api/compliance/v1/canonical/migrate/decompose",
            json={"limit": 50},
        )

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "completed"
        assert body["stats"]["obligations_extracted"] == 180
        # The route must release its DB session exactly once on success.
        db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_failure(self, MockDecomp, mock_session_class):
        _, decomposer = self._stub_decomposer(MockDecomp, mock_session_class)
        decomposer.run_pass0a.side_effect = RuntimeError("LLM timeout")

        response = _client.post(
            "/api/compliance/v1/canonical/migrate/decompose",
            json={},
        )

        assert response.status_code == 500
        # The original exception text is expected in the HTTP error detail.
        assert "LLM timeout" in response.json()["detail"]
class TestMigrateComposeAtomicEndpoint:
    """POST /migrate/compose-atomic (Pass 0b) against a stubbed DecompositionPass."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0b_success(self, MockDecomp, mock_session_class):
        session = MagicMock()
        mock_session_class.return_value = session
        # AsyncMock so that run_pass0b is awaitable.
        composer = AsyncMock()
        MockDecomp.return_value = composer
        pass0b_stats = {
            "candidates_processed": 160,
            "controls_created": 155,
            "llm_failures": 5,
            "errors": 0,
        }
        composer.run_pass0b.return_value = pass0b_stats

        response = _client.post(
            "/api/compliance/v1/canonical/migrate/compose-atomic",
            json={"limit": 200},
        )

        assert response.status_code == 200
        assert response.json()["stats"]["controls_created"] == 155
class TestDecompositionStatusEndpoint:
    """GET /migrate/decomposition-status: payload passthrough and HTTP 500 on failure."""

    @staticmethod
    def _stub_status_source(decomp_cls, session_cls):
        """Wire patched SessionLocal / DecompositionPass to fresh MagicMocks.

        ``decomposition_status`` is scripted synchronously here (plain
        MagicMock). Returns the DecompositionPass instance stub.
        """
        session_cls.return_value = MagicMock()
        decomposer = MagicMock()
        decomp_cls.return_value = decomposer
        return decomposer

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_success(self, MockDecomp, mock_session_class):
        decomposer = self._stub_status_source(MockDecomp, mock_session_class)
        snapshot = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }
        decomposer.decomposition_status.return_value = snapshot

        response = _client.get("/api/compliance/v1/canonical/migrate/decomposition-status")

        assert response.status_code == 200
        body = response.json()
        assert body["rich_controls"] == 5000
        assert body["atomic_controls"] == 1800
        assert body["decomposition_pct"] == 20.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_failure(self, MockDecomp, mock_session_class):
        decomposer = self._stub_status_source(MockDecomp, mock_session_class)
        decomposer.decomposition_status.side_effect = Exception("DB error")

        response = _client.get("/api/compliance/v1/canonical/migrate/decomposition-status")

        assert response.status_code == 500