"""Tests for Pipeline Adapter — Phase 7 of Multi-Layer Control Architecture. Validates: - PipelineChunk and PipelineResult dataclasses - PipelineAdapter.process_chunk() — full 3-stage flow - PipelineAdapter.process_batch() — batch processing - PipelineAdapter.write_crosswalk() — DB write logic (mocked) - MigrationPasses — all 5 passes (with mocked DB) - _extract_regulation_article helper - Edge cases: missing data, LLM failures, initialization """ import json from unittest.mock import AsyncMock, MagicMock, patch, call import pytest from compliance.services.pipeline_adapter import ( MigrationPasses, PipelineAdapter, PipelineChunk, PipelineResult, _extract_regulation_article, ) from compliance.services.obligation_extractor import ObligationMatch from compliance.services.pattern_matcher import ControlPattern, PatternMatchResult from compliance.services.control_composer import ComposedControl # ============================================================================= # Tests: PipelineChunk # ============================================================================= class TestPipelineChunk: def test_defaults(self): chunk = PipelineChunk(text="test") assert chunk.text == "test" assert chunk.collection == "" assert chunk.regulation_code == "" assert chunk.license_rule == 3 assert chunk.chunk_hash == "" def test_compute_hash(self): chunk = PipelineChunk(text="hello world") h = chunk.compute_hash() assert len(h) == 64 # SHA256 hex assert h == chunk.chunk_hash # cached def test_compute_hash_deterministic(self): chunk1 = PipelineChunk(text="same text") chunk2 = PipelineChunk(text="same text") assert chunk1.compute_hash() == chunk2.compute_hash() def test_compute_hash_idempotent(self): chunk = PipelineChunk(text="test") h1 = chunk.compute_hash() h2 = chunk.compute_hash() assert h1 == h2 # ============================================================================= # Tests: PipelineResult # ============================================================================= 
class TestPipelineResult:
    """Default state and dict serialisation of ``PipelineResult``."""

    def test_defaults(self):
        """A result built from a chunk alone has no control, error or crosswalk."""
        outcome = PipelineResult(chunk=PipelineChunk(text="test"))
        assert outcome.error is None
        assert outcome.crosswalk_written is False
        assert outcome.control is None

    def test_to_dict(self):
        """to_dict() exposes the chunk hash plus obligation, pattern and control data."""
        source_chunk = PipelineChunk(text="test")
        source_chunk.compute_hash()

        matched_obligation = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            method="exact_match",
            confidence=1.0,
        )
        matched_pattern = PatternMatchResult(
            pattern_id="CP-AUTH-001",
            method="keyword",
            confidence=0.85,
        )
        composed = ComposedControl(title="Test Control")

        serialized = PipelineResult(
            chunk=source_chunk,
            obligation=matched_obligation,
            pattern_result=matched_pattern,
            control=composed,
        ).to_dict()

        assert serialized["chunk_hash"] == source_chunk.chunk_hash
        assert serialized["obligation"]["obligation_id"] == "DSGVO-OBL-001"
        assert serialized["pattern"]["pattern_id"] == "CP-AUTH-001"
        assert serialized["control"]["title"] == "Test Control"
        assert serialized["error"] is None


# =============================================================================
# Tests: _extract_regulation_article
# =============================================================================


class TestExtractRegulationArticle:
    """Behaviour of the ``_extract_regulation_article`` helper."""

    def test_from_citation_json(self):
        """A JSON citation string yields the mapped regulation and the article."""
        raw_citation = json.dumps({
            "source": "eu_2016_679",
            "article": "Art. 30",
        })
        regulation, article = _extract_regulation_article(raw_citation, None)
        assert (regulation, article) == ("dsgvo", "Art. 30")

    def test_from_metadata(self):
        """Without a citation, the metadata keys supply regulation and article."""
        raw_metadata = json.dumps({
            "source_regulation": "eu_2024_1689",
            "source_article": "Art. 6",
        })
        regulation, article = _extract_regulation_article(None, raw_metadata)
        assert (regulation, article) == ("ai_act", "Art. 6")

    def test_citation_takes_priority(self):
        """When both inputs are present, the citation values win."""
        raw_citation = json.dumps({"source": "dsgvo", "article": "Art. 30"})
        raw_metadata = json.dumps({"source_regulation": "nis2", "source_article": "Art. 21"})
        regulation, article = _extract_regulation_article(raw_citation, raw_metadata)
        assert (regulation, article) == ("dsgvo", "Art. 30")

    def test_empty_inputs(self):
        """Two None inputs give a (None, None) result."""
        regulation, article = _extract_regulation_article(None, None)
        assert article is None
        assert regulation is None

    def test_invalid_json(self):
        """Unparseable strings are tolerated and give (None, None)."""
        regulation, article = _extract_regulation_article("not json", "also not json")
        assert article is None
        assert regulation is None

    def test_citation_as_dict(self):
        """A citation that is already a dict is accepted as-is."""
        citation_dict = {"source": "bdsg", "article": "§ 38"}
        regulation, article = _extract_regulation_article(citation_dict, None)
        assert (regulation, article) == ("bdsg", "§ 38")

    def test_source_article_key(self):
        """The ``source_article`` key inside a citation is honoured too."""
        raw_citation = json.dumps({"source": "dsgvo", "source_article": "Art. 32"})
        regulation, article = _extract_regulation_article(raw_citation, None)
        assert (regulation, article) == ("dsgvo", "Art. 32")

    def test_unknown_source(self):
        """An unrecognised source gives no regulation but keeps the article."""
        raw_citation = json.dumps({"source": "unknown_law", "article": "Art. 1"})
        regulation, article = _extract_regulation_article(raw_citation, None)
        assert article == "Art. 1"
        assert regulation is None  # _normalize_regulation returns None


# =============================================================================
# Tests: PipelineAdapter — process_chunk
# =============================================================================


def _patch_async(target, attribute, **mock_kwargs):
    """Build a ``patch.object`` patcher that replaces ``attribute`` with an AsyncMock.

    Extra keyword arguments (``return_value``, ``side_effect``, ...) are passed
    through ``patch.object``, which forwards them to the ``AsyncMock`` constructor.
    """
    return patch.object(target, attribute, new_callable=AsyncMock, **mock_kwargs)


class TestPipelineAdapterProcessChunk:
    """Tests for the full 3-stage chunk processing."""

    @pytest.mark.asyncio
    async def test_process_chunk_full_flow(self):
        """Process a chunk through all 3 stages."""
        adapter = PipelineAdapter()
        # Mark the adapter as ready so process_chunk does not initialize it.
        adapter._initialized = True

        stub_obligation = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            obligation_title="Verarbeitungsverzeichnis",
            obligation_text="Fuehrung eines Verzeichnisses",
            method="exact_match",
            confidence=1.0,
            regulation_id="dsgvo",
        )
        stub_pattern = PatternMatchResult(
            pattern_id="CP-COMP-001",
            method="keyword",
            confidence=0.85,
        )
        stub_control = ComposedControl(
            title="Test Control",
            objective="Test objective",
            pattern_id="CP-COMP-001",
        )

        extractor_init = _patch_async(adapter._extractor, "initialize")
        matcher_init = _patch_async(adapter._matcher, "initialize")
        extract_stage = _patch_async(
            adapter._extractor, "extract", return_value=stub_obligation
        )
        match_stage = _patch_async(
            adapter._matcher, "match", return_value=stub_pattern
        )
        compose_stage = _patch_async(
            adapter._composer, "compose", return_value=stub_control
        )

        with extractor_init, matcher_init, extract_stage, match_stage, compose_stage:
            incoming = PipelineChunk(
                text="Art. 30 DSGVO Verarbeitungsverzeichnis",
                regulation_code="eu_2016_679",
                article="Art. 30",
                license_rule=1,
            )
            outcome = await adapter.process_chunk(incoming)

            assert outcome.obligation.obligation_id == "DSGVO-OBL-001"
            assert outcome.pattern_result.pattern_id == "CP-COMP-001"
            assert outcome.control.title == "Test Control"
            assert outcome.error is None
            assert outcome.chunk.chunk_hash != ""

    @pytest.mark.asyncio
    async def test_process_chunk_error_handling(self):
        """Errors during processing should be captured, not raised."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        failing_extract = _patch_async(
            adapter._extractor, "extract", side_effect=Exception("LLM timeout")
        )
        with failing_extract:
            outcome = await adapter.process_chunk(PipelineChunk(text="test text"))

            assert outcome.error == "LLM timeout"
            assert outcome.control is None

    @pytest.mark.asyncio
    async def test_process_chunk_uses_obligation_text_for_pattern(self):
        """Pattern matcher should receive obligation text, not raw chunk."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        stub_obligation = ObligationMatch(
            obligation_text="Specific obligation text",
            regulation_id="dsgvo",
        )

        extract_stage = _patch_async(
            adapter._extractor, "extract", return_value=stub_obligation
        )
        match_stage = _patch_async(
            adapter._matcher, "match", return_value=PatternMatchResult()
        )
        compose_stage = _patch_async(
            adapter._composer, "compose", return_value=ComposedControl()
        )

        with extract_stage, match_stage as mock_match, compose_stage:
            await adapter.process_chunk(PipelineChunk(text="raw chunk text"))

            # Pattern matcher should receive the obligation text
            mock_match.assert_called_once()
            matcher_kwargs = mock_match.call_args.kwargs
            assert matcher_kwargs["obligation_text"] == "Specific obligation text"

    @pytest.mark.asyncio
    async def test_process_chunk_fallback_to_chunk_text(self):
        """When obligation has no text, use chunk text for pattern matching."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        textless_obligation = ObligationMatch()  # No text

        extract_stage = _patch_async(
            adapter._extractor, "extract", return_value=textless_obligation
        )
        match_stage = _patch_async(
            adapter._matcher, "match", return_value=PatternMatchResult()
        )
        compose_stage = _patch_async(
            adapter._composer, "compose", return_value=ComposedControl()
        )

        with extract_stage, match_stage as mock_match, compose_stage:
            await adapter.process_chunk(PipelineChunk(text="fallback chunk text"))

            matcher_kwargs = mock_match.call_args.kwargs
            assert "fallback chunk text" in matcher_kwargs["obligation_text"]


# =============================================================================
# Tests: PipelineAdapter — process_batch
# =============================================================================


class TestPipelineAdapterBatch:
    """Batch processing through ``PipelineAdapter.process_batch``."""

    @pytest.mark.asyncio
    async def test_process_batch(self):
        """Two input chunks produce two results."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        canned_result = PipelineResult(chunk=PipelineChunk(text="x"))
        with _patch_async(adapter, "process_chunk", return_value=canned_result):
            batch = [PipelineChunk(text=label) for label in ("a", "b")]
            outcomes = await adapter.process_batch(batch)

            assert len(outcomes) == 2

    @pytest.mark.asyncio
    async def test_process_batch_empty(self):
        """An empty batch yields an empty result list."""
        adapter = PipelineAdapter()
        adapter._initialized = True

        outcomes = await adapter.process_batch([])

        assert outcomes == []


# =============================================================================
# Tests: PipelineAdapter — write_crosswalk
# =============================================================================


class TestWriteCrosswalk:
    """DB write behaviour of ``PipelineAdapter.write_crosswalk`` against a mocked DB."""

    def test_write_crosswalk_success(self):
        """write_crosswalk should execute 3 DB statements."""
        db = MagicMock()
        db.execute = MagicMock()
        db.commit = MagicMock()
        adapter = PipelineAdapter(db=db)

        source_chunk = PipelineChunk(
            text="test",
            regulation_code="eu_2016_679",
            article="Art. 30",
            collection="bp_compliance_ce",
        )
        source_chunk.compute_hash()

        matched_obligation = ObligationMatch(
            obligation_id="DSGVO-OBL-001",
            method="exact_match",
            confidence=1.0,
        )
        matched_pattern = PatternMatchResult(
            pattern_id="CP-COMP-001",
            confidence=0.85,
        )
        composed = ComposedControl(
            control_id="COMP-001",
            pattern_id="CP-COMP-001",
            obligation_ids=["DSGVO-OBL-001"],
        )
        pipeline_result = PipelineResult(
            chunk=source_chunk,
            obligation=matched_obligation,
            pattern_result=matched_pattern,
            control=composed,
        )

        written = adapter.write_crosswalk(pipeline_result, "uuid-123")

        assert written is True
        assert db.execute.call_count == 3  # insert + insert + update
        db.commit.assert_called_once()

    def test_write_crosswalk_no_db(self):
        """Without a DB handle the write is reported as not performed."""
        adapter = PipelineAdapter(db=None)
        pipeline_result = PipelineResult(
            chunk=PipelineChunk(text="test"),
            control=ComposedControl(),
        )
        assert adapter.write_crosswalk(pipeline_result, "uuid") is False

    def test_write_crosswalk_no_control(self):
        """A result without a control is not written."""
        adapter = PipelineAdapter(db=MagicMock())
        pipeline_result = PipelineResult(
            chunk=PipelineChunk(text="test"),
            control=None,
        )
        assert adapter.write_crosswalk(pipeline_result, "uuid") is False

    def test_write_crosswalk_db_error(self):
        """A failing execute() returns False and triggers exactly one rollback."""
        db = MagicMock()
        db.execute = MagicMock(side_effect=Exception("DB error"))
        db.rollback = MagicMock()
        adapter = PipelineAdapter(db=db)

        source_chunk = PipelineChunk(text="test")
        source_chunk.compute_hash()
        pipeline_result = PipelineResult(
            chunk=source_chunk,
            obligation=ObligationMatch(),
            pattern_result=PatternMatchResult(),
            control=ComposedControl(control_id="X-001"),
        )

        assert adapter.write_crosswalk(pipeline_result, "uuid") is False
        db.rollback.assert_called_once()


# =============================================================================
# Tests: PipelineAdapter — stats and initialization
# =============================================================================


class TestPipelineAdapterInit:
    """Initialization state reporting and lazy initialization of the adapter."""

    def test_stats_before_init(self):
        """A freshly constructed adapter reports itself as not initialized."""
        reported = PipelineAdapter().stats()
        assert reported["initialized"] is False

    @pytest.mark.asyncio
    async def test_auto_initialize(self):
        """process_chunk() on an uninitialized adapter calls initialize() once."""
        adapter = PipelineAdapter()

        async def mark_initialized():
            # Stand-in for the real initialize(): only flips the ready flag.
            adapter._initialized = True

        with _patch_async(adapter, "initialize", side_effect=mark_initialized) as mock_init:
            extract_stage = _patch_async(
                adapter._extractor, "extract", return_value=ObligationMatch()
            )
            match_stage = _patch_async(
                adapter._matcher, "match", return_value=PatternMatchResult()
            )
            compose_stage = _patch_async(
                adapter._composer, "compose", return_value=ComposedControl()
            )
            with extract_stage, match_stage, compose_stage:
                await adapter.process_chunk(PipelineChunk(text="test"))

            mock_init.assert_called_once()


# =============================================================================
# Tests: MigrationPasses — Pass 1 (Obligation Linkage)
# =============================================================================


class TestPass1ObligationLinkage:
    """Pass 1 of ``MigrationPasses``: linking controls to obligations."""

    @pytest.mark.asyncio
    async def test_pass1_links_controls(self):
        """Pass 1 should link controls with matching articles to obligations."""
        # Simulate 2 controls: one with citation, one without
        control_rows = (
            (
                "uuid-1",
                "COMP-001",
                json.dumps({"source": "eu_2016_679", "article": "Art. 30"}),
                json.dumps({"source_regulation": "eu_2016_679"}),
            ),
            (
                "uuid-2",
                "SEC-001",
                None,  # No citation
                None,  # No metadata
            ),
        )

        db = MagicMock()
        db.execute.return_value.fetchall.return_value = list(control_rows)

        migration = MigrationPasses(db=db)
        await migration.initialize()

        # Reset mock after initialize queries, then serve a fresh copy of the rows.
        db.execute.reset_mock()
        db.execute.return_value.fetchall.return_value = list(control_rows)

        stats = await migration.run_pass1_obligation_linkage()

        assert stats["total"] == 2
        assert stats["no_citation"] >= 1

    @pytest.mark.asyncio
    async def test_pass1_with_limit(self):
        """Pass 1 should respect limit parameter."""
        db = MagicMock()
        db.execute.return_value.fetchall.return_value = []

        migration = MigrationPasses(db=db)
        migration._initialized = True
        migration._extractor._load_obligations()

        stats = await migration.run_pass1_obligation_linkage(limit=10)
        assert stats["total"] == 0

        # Check that LIMIT was in the SQL text clause
        statement = db.execute.call_args.args[0]  # first positional arg is the text() object
        assert "LIMIT" in statement.text


# =============================================================================
# Tests: MigrationPasses — Pass 2 (Pattern Classification)
# =============================================================================


class TestPass2PatternClassification:
    """Pass 2 of ``MigrationPasses``: keyword-based pattern classification."""

    @pytest.mark.asyncio
    async def test_pass2_classifies_controls(self):
        """Pass 2 should match controls to patterns via keywords."""
        auth_control = (
            "uuid-1",
            "AUTH-001",
            "Passwortrichtlinie und Authentifizierung",
            "Sicherstellen dass Anmeldedaten credential geschuetzt sind",
        )

        db = MagicMock()
        db.execute.return_value.fetchall.return_value = [auth_control]

        migration = MigrationPasses(db=db)
        await migration.initialize()

        # Forget the calls made during initialize and serve a fresh row list.
        db.execute.reset_mock()
        db.execute.return_value.fetchall.return_value = [auth_control]

        stats = await migration.run_pass2_pattern_classification()

        assert stats["total"] == 1
        # Should classify because "passwort", "authentifizierung", "anmeldedaten" are keywords
        assert stats["classified"] == 1

    @pytest.mark.asyncio
    async def test_pass2_no_match(self):
        """Controls without keyword matches should be counted as no_match."""
        unrelated_control = (
            "uuid-1",
            "MISC-001",
            "Completely unrelated title",
            "No keywords match here at all",
        )

        db = MagicMock()
        db.execute.return_value.fetchall.return_value = [unrelated_control]

        migration = MigrationPasses(db=db)
        await migration.initialize()

        # Forget the calls made during initialize and serve a fresh row list.
        db.execute.reset_mock()
        db.execute.return_value.fetchall.return_value = [unrelated_control]

        stats = await migration.run_pass2_pattern_classification()

        assert stats["no_match"] == 1


# =============================================================================
# Tests: MigrationPasses — Pass 3 (Quality Triage)
# =============================================================================


class TestPass3QualityTriage:
    """Pass 3 of ``MigrationPasses``: quality triage updates."""

    def test_pass3_executes_4_updates(self):
        """Pass 3 should execute exactly 4 UPDATE statements."""
        db = MagicMock()
        db.execute.return_value = MagicMock(rowcount=10)

        stats = MigrationPasses(db=db).run_pass3_quality_triage()

        assert db.execute.call_count == 4
        db.commit.assert_called_once()
        for bucket in ("review", "needs_obligation", "needs_pattern", "legacy_unlinked"):
            assert bucket in stats


# =============================================================================
# Tests: MigrationPasses — Pass 4 (Crosswalk Backfill)
# =============================================================================


class TestPass4CrosswalkBackfill:
    """Pass 4 of ``MigrationPasses``: crosswalk backfill."""

    def test_pass4_inserts_crosswalk_rows(self):
        """The reported inserted-row count mirrors the DB result's rowcount."""
        db = MagicMock()
        db.execute.return_value = MagicMock(rowcount=42)

        stats = MigrationPasses(db=db).run_pass4_crosswalk_backfill()

        assert stats["rows_inserted"] == 42
        db.commit.assert_called_once()


# =============================================================================
# Tests: MigrationPasses — Pass 5 (Deduplication)
# =============================================================================


class TestPass5Deduplication:
    """Pass 5 of ``MigrationPasses``: deprecating duplicate controls."""

    def test_pass5_no_duplicates(self):
        """With no duplicate groups, nothing is found and nothing is deprecated."""
        db = MagicMock()
        db.execute.return_value.fetchall.return_value = []

        stats = MigrationPasses(db=db).run_pass5_deduplication()

        assert stats["controls_deprecated"] == 0
        assert stats["groups_found"] == 0

    def test_pass5_deprecates_duplicates(self):
        """Pass 5 should keep first (highest confidence) and deprecate rest."""
        duplicate_group = (
            "CP-AUTH-001",  # pattern_id
            "DSGVO-OBL-001",  # obligation_id
            ["uuid-1", "uuid-2", "uuid-3"],  # ids (ordered by confidence)
            3,  # count
        )

        # First call: groups query returns one group with 3 controls
        groups_query = MagicMock()
        groups_query.fetchall.return_value = [duplicate_group]

        # Subsequent calls: UPDATE queries
        update_query = MagicMock(rowcount=1)

        db = MagicMock()
        db.execute.side_effect = [groups_query, update_query, update_query]

        stats = MigrationPasses(db=db).run_pass5_deduplication()

        assert stats["groups_found"] == 1
        assert stats["controls_deprecated"] == 2  # uuid-2, uuid-3
        db.commit.assert_called_once()


# =============================================================================
# Tests: MigrationPasses — migration_status
# =============================================================================


class TestMigrationStatus:
    """Counts and coverage percentages reported by ``migration_status``."""

    def test_migration_status(self):
        """Raw counts are passed through and coverage is expressed as percentages."""
        db = MagicMock()
        db.execute.return_value.fetchone.return_value = (
            4800,  # total
            2880,  # has_obligation (60%)
            3360,  # has_pattern (70%)
            2400,  # fully_linked (50%)
            300,  # deprecated
        )

        status = MigrationPasses(db=db).migration_status()

        expected = {
            "total_controls": 4800,
            "has_obligation": 2880,
            "has_pattern": 3360,
            "fully_linked": 2400,
            "deprecated": 300,
            "coverage_obligation_pct": 60.0,
            "coverage_pattern_pct": 70.0,
            "coverage_full_pct": 50.0,
        }
        for field, value in expected.items():
            assert status[field] == value

    def test_migration_status_empty_db(self):
        """An all-zero row yields zero controls and 0.0 obligation coverage."""
        db = MagicMock()
        db.execute.return_value.fetchone.return_value = (0, 0, 0, 0, 0)

        status = MigrationPasses(db=db).migration_status()

        assert status["total_controls"] == 0
        assert status["coverage_obligation_pct"] == 0.0