"""Tests for Multi-Layer Control Architecture routes (crosswalk_routes.py).

Covers:
- Pydantic model validation (Pattern, Obligation, Crosswalk, Migration)
- Pattern Library endpoints (list, get, get controls)
- Obligation extraction endpoint
- Crosswalk query + stats endpoints
- Migration pass endpoints (1-5) + status
- Decomposition pass endpoints (0a / 0b) + status
- Helper: _get_pattern_control_counts
"""

import json
from typing import Optional
from unittest.mock import MagicMock, patch, AsyncMock

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from compliance.api.crosswalk_routes import (
    PatternResponse,
    PatternListResponse,
    PatternDetailResponse,
    ObligationExtractRequest,
    ObligationExtractResponse,
    CrosswalkRow,
    CrosswalkQueryResponse,
    CrosswalkStatsResponse,
    MigrationRequest,
    MigrationResponse,
    MigrationStatusResponse,
    DecompositionStatusResponse,
    router,
    _get_pattern_control_counts,
)

# ---------------------------------------------------------------------------
# TestClient setup
# ---------------------------------------------------------------------------

_app = FastAPI()
_app.include_router(router, prefix="/api/compliance")
_client = TestClient(_app)


# ---------------------------------------------------------------------------
# MODEL TESTS
# ---------------------------------------------------------------------------


class TestPatternResponse:
    """Tests for PatternResponse model."""

    def test_basic_creation(self):
        resp = PatternResponse(
            id="CP-AUTH-001",
            name="password_policy",
            name_de="Passwortrichtlinie",
            domain="AUTH",
            category="authentication",
            description="Password policy requirements",
            objective_template="Ensure passwords meet complexity standards.",
            severity_default="high",
        )
        assert resp.id == "CP-AUTH-001"
        assert resp.domain == "AUTH"
        assert resp.severity_default == "high"

    def test_default_values(self):
        resp = PatternResponse(
            id="CP-AUTH-001",
            name="test",
            name_de="Test",
            domain="AUTH",
            category="auth",
            description="desc",
            objective_template="obj",
            severity_default="medium",
        )
        assert resp.implementation_effort_default == "m"
        assert resp.tags == []
        assert resp.composable_with == []
        assert resp.open_anchor_refs == []
        assert resp.controls_count == 0

    def test_full_model(self):
        resp = PatternResponse(
            id="CP-CRYP-001",
            name="encryption_at_rest",
            name_de="Verschluesselung ruhender Daten",
            domain="CRYP",
            category="encryption",
            description="Encrypt data at rest",
            objective_template="Ensure all stored data is encrypted.",
            severity_default="critical",
            implementation_effort_default="l",
            tags=["encryption", "storage"],
            composable_with=["CP-CRYP-002"],
            open_anchor_refs=[{"framework": "NIST", "ref": "SC-28"}],
            controls_count=42,
        )
        assert resp.controls_count == 42
        assert len(resp.tags) == 2
        assert resp.implementation_effort_default == "l"


class TestPatternDetailResponse:
    """Tests for PatternDetailResponse model (extends PatternResponse)."""

    def test_has_extended_fields(self):
        resp = PatternDetailResponse(
            id="CP-AUTH-001",
            name="mfa",
            name_de="MFA",
            domain="AUTH",
            category="authentication",
            description="Multi-factor authentication",
            objective_template="Require MFA.",
            severity_default="high",
            rationale_template="Passwords alone are insufficient.",
            requirements_template=["Require TOTP or hardware key"],
            test_procedure_template=["Test login without MFA"],
            evidence_template=["MFA config screenshot"],
            obligation_match_keywords=["authentifizierung", "mfa"],
        )
        assert resp.rationale_template == "Passwords alone are insufficient."
        assert len(resp.requirements_template) == 1
        assert len(resp.obligation_match_keywords) == 2

    def test_defaults(self):
        resp = PatternDetailResponse(
            id="CP-AUTH-001",
            name="test",
            name_de="Test",
            domain="AUTH",
            category="auth",
            description="desc",
            objective_template="obj",
            severity_default="medium",
        )
        assert resp.rationale_template == ""
        assert resp.requirements_template == []
        assert resp.test_procedure_template == []
        assert resp.evidence_template == []
        assert resp.obligation_match_keywords == []


class TestObligationModels:
    """Tests for ObligationExtractRequest/Response models."""

    def test_request_minimal(self):
        req = ObligationExtractRequest(text="Ein Verantwortlicher muss...")
        assert req.text == "Ein Verantwortlicher muss..."
        assert req.regulation_code is None
        assert req.article is None
        assert req.paragraph is None

    def test_request_full(self):
        req = ObligationExtractRequest(
            text="Art. 32 DSGVO",
            regulation_code="eu_2016_679",
            article="Art. 32",
            paragraph="Abs. 1",
        )
        assert req.regulation_code == "eu_2016_679"
        assert req.paragraph == "Abs. 1"

    def test_response_defaults(self):
        resp = ObligationExtractResponse()
        assert resp.obligation_id is None
        assert resp.method == "none"
        assert resp.confidence == 0.0
        assert resp.pattern_id is None
        assert resp.pattern_confidence == 0.0

    def test_response_full(self):
        resp = ObligationExtractResponse(
            obligation_id="DSGVO-OBL-001",
            obligation_title="Verzeichnis der Verarbeitungstaetigkeiten",
            obligation_text="Der Verantwortliche muss...",
            method="exact_match",
            confidence=1.0,
            regulation_id="dsgvo",
            pattern_id="CP-GOV-001",
            pattern_confidence=0.85,
        )
        assert resp.obligation_id == "DSGVO-OBL-001"
        assert resp.method == "exact_match"
        assert resp.confidence == 1.0
        assert resp.pattern_confidence == 0.85


class TestCrosswalkModels:
    """Tests for CrosswalkRow and CrosswalkQueryResponse models."""

    def test_row_defaults(self):
        row = CrosswalkRow()
        assert row.regulation_code == ""
        assert row.article is None
        assert row.obligation_id is None
        assert row.confidence == 0.0
        assert row.source == "auto"

    def test_row_full(self):
        row = CrosswalkRow(
            regulation_code="eu_2016_679",
            article="Art. 32",
            obligation_id="DSGVO-OBL-002",
            pattern_id="CP-CRYP-001",
            master_control_id="CRYP-001",
            confidence=0.95,
            source="manual",
        )
        assert row.regulation_code == "eu_2016_679"
        assert row.confidence == 0.95

    def test_query_response(self):
        resp = CrosswalkQueryResponse(
            rows=[
                CrosswalkRow(regulation_code="eu_2016_679"),
                CrosswalkRow(regulation_code="eu_2022_2554"),
            ],
            total=100,
        )
        assert len(resp.rows) == 2
        assert resp.total == 100


class TestCrosswalkStatsResponse:
    """Tests for CrosswalkStatsResponse model."""

    def test_defaults(self):
        resp = CrosswalkStatsResponse()
        assert resp.total_rows == 0
        assert resp.regulations_covered == 0
        assert resp.obligations_linked == 0
        assert resp.patterns_used == 0
        assert resp.controls_linked == 0
        assert resp.coverage_by_regulation == {}

    def test_full(self):
        resp = CrosswalkStatsResponse(
            total_rows=500,
            regulations_covered=9,
            obligations_linked=200,
            patterns_used=45,
            controls_linked=350,
            coverage_by_regulation={"eu_2016_679": 150, "eu_2022_2554": 80},
        )
        assert resp.total_rows == 500
        assert resp.coverage_by_regulation["eu_2016_679"] == 150


class TestMigrationModels:
    """Tests for MigrationRequest/Response/Status models."""

    def test_request_default(self):
        req = MigrationRequest()
        assert req.limit == 0

    def test_request_with_limit(self):
        req = MigrationRequest(limit=100)
        assert req.limit == 100

    def test_response_defaults(self):
        resp = MigrationResponse()
        assert resp.status == "completed"
        assert resp.stats == {}

    def test_status_defaults(self):
        s = MigrationStatusResponse()
        assert s.total_controls == 0
        assert s.coverage_obligation_pct == 0.0
        assert s.coverage_pattern_pct == 0.0
        assert s.coverage_full_pct == 0.0


# ---------------------------------------------------------------------------
# HELPER TESTS
# ---------------------------------------------------------------------------


class TestGetPatternControlCounts:
    """Tests for _get_pattern_control_counts helper."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_counts(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("CP-AUTH-001", 15),
            ("CP-CRYP-001", 8),
        ]
        mock_db.execute.return_value = mock_result

        counts = _get_pattern_control_counts()

        assert counts == {"CP-AUTH-001": 15, "CP-CRYP-001": 8}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_on_error(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_db.execute.side_effect = Exception("DB down")

        counts = _get_pattern_control_counts()

        assert counts == {}
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_empty_when_no_patterns(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_db.execute.return_value = mock_result

        counts = _get_pattern_control_counts()

        assert counts == {}
        # The session must be released on the empty-result path too,
        # matching the success and error cases above.
        mock_db.close.assert_called_once()


# ---------------------------------------------------------------------------
# PATTERN LIBRARY ENDPOINT TESTS
# ---------------------------------------------------------------------------


class _FakePattern:
    """Minimal pattern stub for list_patterns / get_pattern tests."""

    def __init__(self, id, name="test", name_de="Test", domain="AUTH",
                 category="auth", description="desc", objective_template="obj",
                 severity_default="medium", implementation_effort_default="m",
                 tags=None, composable_with=None, open_anchor_refs=None,
                 rationale_template="", requirements_template=None,
                 test_procedure_template=None, evidence_template=None,
                 obligation_match_keywords=None):
        self.id = id
        self.name = name
        self.name_de = name_de
        self.domain = domain
        self.category = category
        self.description = description
        self.objective_template = objective_template
        self.severity_default = severity_default
        self.implementation_effort_default = implementation_effort_default
        # List-valued attributes default to fresh empty lists (never shared).
        self.tags = tags or []
        self.composable_with = composable_with or []
        self.open_anchor_refs = open_anchor_refs or []
        self.rationale_template = rationale_template
        self.requirements_template = requirements_template or []
        self.test_procedure_template = test_procedure_template or []
        self.evidence_template = evidence_template or []
        self.obligation_match_keywords = obligation_match_keywords or []


class TestListPatternsEndpoint:
    """Tests for GET /patterns."""

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_list_all_patterns(self, mock_counts):
        """All patterns are listed; controls_count falls back to 0 when unknown."""
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH", tags=["auth"]),
            _FakePattern("CP-CRYP-001", domain="CRYP", tags=["encryption"]),
        ]
        mock_counts.return_value = {"CP-AUTH-001": 5}

        # Patching at the source module is the effective target; the other
        # list/filter tests rely on the same single patch.
        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance
            resp = _client.get("/api/compliance/v1/canonical/patterns")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 2
        assert len(data["patterns"]) == 2
        assert data["patterns"][0]["id"] == "CP-AUTH-001"
        assert data["patterns"][0]["controls_count"] == 5
        assert data["patterns"][1]["controls_count"] == 0

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_domain(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", domain="AUTH"),
            _FakePattern("CP-CRYP-001", domain="CRYP"),
        ]
        mock_counts.return_value = {}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance
            # Lower-case query value: the filter is expected to match "AUTH".
            resp = _client.get("/api/compliance/v1/canonical/patterns?domain=auth")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_category(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", category="authentication"),
            _FakePattern("CP-CRYP-001", category="encryption"),
        ]
        mock_counts.return_value = {}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance
            resp = _client.get("/api/compliance/v1/canonical/patterns?category=encryption")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-CRYP-001"

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_filter_by_tag(self, mock_counts):
        fake_patterns = [
            _FakePattern("CP-AUTH-001", tags=["auth", "password"]),
            _FakePattern("CP-CRYP-001", tags=["encryption"]),
        ]
        mock_counts.return_value = {}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance._patterns = fake_patterns
            MockPM.return_value = instance
            resp = _client.get("/api/compliance/v1/canonical/patterns?tag=password")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["patterns"][0]["id"] == "CP-AUTH-001"


class TestGetPatternEndpoint:
    """Tests for GET /patterns/{pattern_id}."""

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_existing_pattern(self, mock_counts):
        fake = _FakePattern(
            "CP-AUTH-001",
            name="password_policy",
            rationale_template="Weak passwords are risky.",
            requirements_template=["Min 12 chars"],
            obligation_match_keywords=["passwort"],
        )
        mock_counts.return_value = {"CP-AUTH-001": 10}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance.get_pattern.return_value = fake
            MockPM.return_value = instance
            resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001")

        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == "CP-AUTH-001"
        assert data["rationale_template"] == "Weak passwords are risky."
        assert data["obligation_match_keywords"] == ["passwort"]
        assert data["controls_count"] == 10

    @patch("compliance.api.crosswalk_routes._get_pattern_control_counts")
    def test_get_nonexistent_pattern(self, mock_counts):
        mock_counts.return_value = {}

        with patch("compliance.services.pattern_matcher.PatternMatcher") as MockPM:
            instance = MagicMock()
            instance.get_pattern.return_value = None
            MockPM.return_value = instance
            resp = _client.get("/api/compliance/v1/canonical/patterns/CP-FAKE-999")

        assert resp.status_code == 404
        assert "not found" in resp.json()["detail"].lower()


class TestGetPatternControlsEndpoint:
    """Tests for GET /patterns/{pattern_id}/controls."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_controls(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db

        # Main query result
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("uuid-1", "AUTH-001", "MFA", "Require MFA", "high", "draft",
             "authentication", '["DSGVO-OBL-001"]'),
            ("uuid-2", "AUTH-002", "SSO", "Implement SSO", "medium", "draft",
             "authentication", None),
        ]
        # Count query result
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (2,)
        # execute() is called twice: first the page query, then the COUNT query.
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 2
        assert len(data["controls"]) == 2
        assert data["controls"][0]["control_id"] == "AUTH-001"
        assert data["controls"][0]["obligation_ids"] == ["DSGVO-OBL-001"]
        assert data["controls"][1]["obligation_ids"] == []
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_pagination(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (0,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get(
            "/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls?limit=10&offset=20"
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 0
        assert data["controls"] == []

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_json_parse_error_in_obligation_ids(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("uuid-1", "AUTH-001", "MFA", "obj", "high", "draft", "auth",
             "not-valid-json"),
        ]
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (1,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/patterns/CP-AUTH-001/controls")

        assert resp.status_code == 200
        data = resp.json()
        # Malformed JSON in the obligation_ids column degrades to an empty list.
        assert data["controls"][0]["obligation_ids"] == []


# ---------------------------------------------------------------------------
# OBLIGATION EXTRACTION ENDPOINT TESTS
# ---------------------------------------------------------------------------


class TestExtractObligationEndpoint:
    """Tests for POST /obligations/extract."""

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_with_pattern_match(self, MockExtractor, MockMatcher):
        # Mock extractor
        mock_ext = AsyncMock()
        MockExtractor.return_value = mock_ext
        mock_obligation = MagicMock()
        mock_obligation.obligation_id = "DSGVO-OBL-001"
        mock_obligation.obligation_title = "VVT"
        mock_obligation.obligation_text = "Der Verantwortliche muss..."
        mock_obligation.method = "exact_match"
        mock_obligation.confidence = 1.0
        mock_obligation.regulation_id = "dsgvo"
        mock_ext.extract.return_value = mock_obligation

        # Mock matcher
        mock_pm = MagicMock()
        MockMatcher.return_value = mock_pm
        mock_pattern_result = MagicMock()
        mock_pattern_result.pattern_id = "CP-GOV-001"
        mock_pattern_result.confidence = 0.85
        mock_pm._tier1_keyword.return_value = mock_pattern_result

        resp = _client.post(
            "/api/compliance/v1/canonical/obligations/extract",
            json={"text": "Art. 30 DSGVO Verzeichnis", "regulation_code": "eu_2016_679"},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["obligation_id"] == "DSGVO-OBL-001"
        assert data["method"] == "exact_match"
        assert data["pattern_id"] == "CP-GOV-001"
        assert data["pattern_confidence"] == 0.85

    @patch("compliance.services.pattern_matcher.PatternMatcher")
    @patch("compliance.services.obligation_extractor.ObligationExtractor")
    def test_extract_no_pattern_match(self, MockExtractor, MockMatcher):
        mock_ext = AsyncMock()
        MockExtractor.return_value = mock_ext
        mock_obligation = MagicMock()
        mock_obligation.obligation_id = None
        mock_obligation.obligation_title = None
        mock_obligation.obligation_text = "Some text"
        mock_obligation.method = "llm_extracted"
        mock_obligation.confidence = 0.6
        mock_obligation.regulation_id = None
        mock_ext.extract.return_value = mock_obligation

        mock_pm = MagicMock()
        MockMatcher.return_value = mock_pm
        mock_pm._tier1_keyword.return_value = None

        resp = _client.post(
            "/api/compliance/v1/canonical/obligations/extract",
            json={"text": "Some random regulation text"},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["obligation_id"] is None
        assert data["pattern_id"] is None
        assert data["pattern_confidence"] == 0.0


# ---------------------------------------------------------------------------
# CROSSWALK ENDPOINT TESTS
# ---------------------------------------------------------------------------


class TestQueryCrosswalkEndpoint:
    """Tests for GET /crosswalk."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_no_filters(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("eu_2016_679", "Art. 32", "DSGVO-OBL-002", "CP-CRYP-001",
             "CRYP-001", 0.95, "auto"),
        ]
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (1,)
        # execute() is called twice: first the row query, then the COUNT query.
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["rows"][0]["regulation_code"] == "eu_2016_679"
        assert data["rows"][0]["confidence"] == 0.95
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_regulation_filter(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (0,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk?regulation_code=eu_2016_679")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 0
        # Verify SQL contained regulation filter
        call_args = mock_db.execute.call_args_list[0]
        sql_text = call_args[0][0].text
        assert "regulation_code = :reg" in sql_text

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_with_all_filters(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (0,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get(
            "/api/compliance/v1/canonical/crosswalk"
            "?regulation_code=eu_2016_679"
            "&article=Art.%2032"
            "&obligation_id=DSGVO-OBL-002"
            "&pattern_id=CP-CRYP-001"
        )

        assert resp.status_code == 200
        # Verify params were passed
        call_args = mock_db.execute.call_args_list[0]
        params = call_args[0][1]
        assert params["reg"] == "eu_2016_679"
        assert params["art"] == "Art. 32"
        assert params["obl"] == "DSGVO-OBL-002"
        assert params["pat"] == "CP-CRYP-001"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_null_values_handled(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = [
            ("eu_2016_679", None, None, None, None, None, None),
        ]
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (1,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk")

        assert resp.status_code == 200
        data = resp.json()
        row = data["rows"][0]
        # NULL confidence/source columns fall back to the model defaults.
        assert row["confidence"] == 0.0
        assert row["source"] == "auto"

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_query_pagination(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_result = MagicMock()
        mock_result.fetchall.return_value = []
        mock_count = MagicMock()
        mock_count.fetchone.return_value = (500,)
        mock_db.execute.side_effect = [mock_result, mock_count]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk?limit=50&offset=100")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 500
        call_args = mock_db.execute.call_args_list[0]
        params = call_args[0][1]
        assert params["limit"] == 50
        assert params["offset"] == 100


class TestCrosswalkStatsEndpoint:
    """Tests for GET /crosswalk/stats."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_returns_stats(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db

        # Main stats query
        mock_main = MagicMock()
        mock_main.fetchone.return_value = (500, 9, 200, 45, 350)
        # Coverage by regulation query
        mock_reg = MagicMock()
        mock_reg.fetchall.return_value = [
            ("eu_2016_679", 150),
            ("eu_2022_2554", 80),
        ]
        mock_db.execute.side_effect = [mock_main, mock_reg]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk/stats")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total_rows"] == 500
        assert data["regulations_covered"] == 9
        assert data["obligations_linked"] == 200
        assert data["patterns_used"] == 45
        assert data["controls_linked"] == 350
        assert data["coverage_by_regulation"]["eu_2016_679"] == 150
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    def test_empty_stats(self, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_main = MagicMock()
        mock_main.fetchone.return_value = (0, 0, 0, 0, 0)
        mock_reg = MagicMock()
        mock_reg.fetchall.return_value = []
        mock_db.execute.side_effect = [mock_main, mock_reg]

        resp = _client.get("/api/compliance/v1/canonical/crosswalk/stats")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total_rows"] == 0
        assert data["coverage_by_regulation"] == {}


# ---------------------------------------------------------------------------
# MIGRATION ENDPOINT TESTS
# ---------------------------------------------------------------------------


class TestMigratePass1Endpoint:
    """Tests for POST /migrate/link-obligations."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = AsyncMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass1_obligation_linkage.return_value = {
            "processed": 100,
            "linked": 60,
            "skipped": 40,
        }

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/link-obligations",
            json={"limit": 100},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "completed"
        assert data["stats"]["linked"] == 60
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass1_failure(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = AsyncMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass1_obligation_linkage.side_effect = RuntimeError("DB connection lost")

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/link-obligations",
            json={},
        )

        assert resp.status_code == 500
        assert "DB connection lost" in resp.json()["detail"]
        mock_db.close.assert_called_once()


class TestMigratePass2Endpoint:
    """Tests for POST /migrate/classify-patterns."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass2_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = AsyncMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass2_pattern_classification.return_value = {
            "processed": 200,
            "classified": 140,
            "candidates": 30,
            "unmatched": 30,
        }

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/classify-patterns",
            json={"limit": 200},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["classified"] == 140


class TestMigratePass3Endpoint:
    """Tests for POST /migrate/triage."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass3_quality_triage.return_value = {
            "review": 60,
            "needs_obligation": 20,
            "needs_pattern": 15,
            "legacy_unlinked": 5,
        }

        resp = _client.post("/api/compliance/v1/canonical/migrate/triage")

        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["review"] == 60

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass3_failure(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass3_quality_triage.side_effect = Exception("triage error")

        resp = _client.post("/api/compliance/v1/canonical/migrate/triage")

        assert resp.status_code == 500


class TestMigratePass4Endpoint:
    """Tests for POST /migrate/backfill-crosswalk."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass4_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass4_crosswalk_backfill.return_value = {
            "rows_created": 250,
        }

        resp = _client.post("/api/compliance/v1/canonical/migrate/backfill-crosswalk")

        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["rows_created"] == 250


class TestMigratePass5Endpoint:
    """Tests for POST /migrate/deduplicate."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_pass5_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.run_pass5_deduplication.return_value = {
            "groups_found": 80,
            "deprecated": 120,
            "kept": 80,
        }

        resp = _client.post("/api/compliance/v1/canonical/migrate/deduplicate")

        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["deprecated"] == 120


class TestMigrationStatusEndpoint:
    """Tests for GET /migrate/status."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_success(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.migration_status.return_value = {
            "total_controls": 4800,
            "has_obligation": 2880,
            "has_pattern": 3360,
            "fully_linked": 2400,
            "deprecated": 1200,
            "coverage_obligation_pct": 60.0,
            "coverage_pattern_pct": 70.0,
            "coverage_full_pct": 50.0,
        }

        resp = _client.get("/api/compliance/v1/canonical/migrate/status")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total_controls"] == 4800
        assert data["coverage_full_pct"] == 50.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.pipeline_adapter.MigrationPasses")
    def test_status_failure(self, MockMigration, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_mig = MagicMock()
        MockMigration.return_value = mock_mig
        mock_mig.migration_status.side_effect = Exception("DB error")

        resp = _client.get("/api/compliance/v1/canonical/migrate/status")

        assert resp.status_code == 500


# ---------------------------------------------------------------------------
# DECOMPOSITION ENDPOINT TESTS (Pass 0a / 0b)
# ---------------------------------------------------------------------------


class TestDecompositionStatusModel:
    """Tests for DecompositionStatusResponse model."""

    def test_defaults(self):
        s = DecompositionStatusResponse()
        assert s.rich_controls == 0
        assert s.decomposition_pct == 0.0
        assert s.composition_pct == 0.0

    def test_full(self):
        s = DecompositionStatusResponse(
            rich_controls=5000,
            decomposed_controls=1000,
            total_candidates=3000,
            validated=2500,
            rejected=200,
            composed=2000,
            atomic_controls=1800,
            decomposition_pct=20.0,
            composition_pct=80.0,
        )
        assert s.rich_controls == 5000
        assert s.atomic_controls == 1800


class TestMigrateDecomposeEndpoint:
    """Tests for POST /migrate/decompose (Pass 0a)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_success(self, MockDecomp, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_decomp = AsyncMock()
        MockDecomp.return_value = mock_decomp
        mock_decomp.run_pass0a.return_value = {
            "controls_processed": 50,
            "obligations_extracted": 180,
            "obligations_validated": 160,
            "obligations_rejected": 20,
            "controls_skipped_empty": 5,
            "errors": 0,
        }

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/decompose",
            json={"limit": 50},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "completed"
        assert data["stats"]["obligations_extracted"] == 180
        mock_db.close.assert_called_once()

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0a_failure(self, MockDecomp, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_decomp = AsyncMock()
        MockDecomp.return_value = mock_decomp
        mock_decomp.run_pass0a.side_effect = RuntimeError("LLM timeout")

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/decompose",
            json={},
        )

        assert resp.status_code == 500
        assert "LLM timeout" in resp.json()["detail"]


class TestMigrateComposeAtomicEndpoint:
    """Tests for POST /migrate/compose-atomic (Pass 0b)."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_pass0b_success(self, MockDecomp, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_decomp = AsyncMock()
        MockDecomp.return_value = mock_decomp
        mock_decomp.run_pass0b.return_value = {
            "candidates_processed": 160,
            "controls_created": 155,
            "llm_failures": 5,
            "errors": 0,
        }

        resp = _client.post(
            "/api/compliance/v1/canonical/migrate/compose-atomic",
            json={"limit": 200},
        )

        assert resp.status_code == 200
        data = resp.json()
        assert data["stats"]["controls_created"] == 155


class TestDecompositionStatusEndpoint:
    """Tests for GET /migrate/decomposition-status."""

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_success(self, MockDecomp, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_decomp = MagicMock()
        MockDecomp.return_value = mock_decomp
        mock_decomp.decomposition_status.return_value = {
            "rich_controls": 5000,
            "decomposed_controls": 1000,
            "total_candidates": 3000,
            "validated": 2500,
            "rejected": 200,
            "composed": 2000,
            "atomic_controls": 1800,
            "decomposition_pct": 20.0,
            "composition_pct": 80.0,
        }

        resp = _client.get("/api/compliance/v1/canonical/migrate/decomposition-status")

        assert resp.status_code == 200
        data = resp.json()
        assert data["rich_controls"] == 5000
        assert data["atomic_controls"] == 1800
        assert data["decomposition_pct"] == 20.0

    @patch("compliance.api.crosswalk_routes.SessionLocal")
    @patch("compliance.services.decomposition_pass.DecompositionPass")
    def test_status_failure(self, MockDecomp, mock_session_class):
        mock_db = MagicMock()
        mock_session_class.return_value = mock_db
        mock_decomp = MagicMock()
        MockDecomp.return_value = mock_decomp
        mock_decomp.decomposition_status.side_effect = Exception("DB error")

        resp = _client.get("/api/compliance/v1/canonical/migrate/decomposition-status")

        assert resp.status_code == 500