diff --git a/backend-compliance/compliance/services/control_generator.py b/backend-compliance/compliance/services/control_generator.py index ac58a62..2b43ae8 100644 --- a/backend-compliance/compliance/services/control_generator.py +++ b/backend-compliance/compliance/services/control_generator.py @@ -806,7 +806,9 @@ class ControlGeneratorPipeline: or payload.get("source_code", "")) # Filter by regulation_code if configured - if config.regulation_filter and reg_code: + if config.regulation_filter: + if not reg_code: + continue # Skip chunks without regulation code code_lower = reg_code.lower() if not any(code_lower.startswith(f.lower()) for f in config.regulation_filter): continue @@ -852,10 +854,16 @@ class ControlGeneratorPipeline: collection, collection_total, collection_new, ) - logger.info( - "RAG scroll complete: %d total unique seen, %d new unprocessed to process", - len(seen_hashes), len(all_results), - ) + if config.regulation_filter: + logger.info( + "RAG scroll complete: %d total unique seen, %d passed regulation_filter %s", + len(seen_hashes), len(all_results), config.regulation_filter, + ) + else: + logger.info( + "RAG scroll complete: %d total unique seen, %d new unprocessed to process", + len(seen_hashes), len(all_results), + ) return all_results def _get_processed_hashes(self, hashes: list[str]) -> set[str]: diff --git a/backend-compliance/tests/test_control_generator.py b/backend-compliance/tests/test_control_generator.py index c47a698..c8b2447 100644 --- a/backend-compliance/tests/test_control_generator.py +++ b/backend-compliance/tests/test_control_generator.py @@ -1022,6 +1022,54 @@ class TestRegulationFilter: assert "nist_sp_800_218" in codes assert "amlr" not in codes + @pytest.mark.asyncio + async def test_scan_rag_filters_out_empty_regulation_code(self): + """Chunks without regulation_code must be skipped when filter is active.""" + mock_db = MagicMock() + mock_db.execute.return_value = MagicMock() + mock_db.execute.return_value.__iter__ = MagicMock(return_value=iter([])) + + qdrant_points = { + "result": { + "points": [ + {"id": "1", "payload": { + "chunk_text": "OWASP ASVS requirement for input validation " * 5, + "regulation_code": "owasp_asvs", + }}, + {"id": "2", "payload": { + "chunk_text": "Some template without regulation code at all " * 5, + # No regulation_id, regulation_code, source_id, or source_code + }}, + {"id": "3", "payload": { + "chunk_text": "Another chunk with empty regulation code value " * 5, + "regulation_code": "", + }}, + ], + "next_page_offset": None, + } + } + + with patch("compliance.services.control_generator.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = qdrant_points + mock_client.post.return_value = mock_resp + + pipeline = ControlGeneratorPipeline(db=mock_db, rag_client=MagicMock()) + config = GeneratorConfig( + collections=["bp_compliance_ce"], + regulation_filter=["owasp_"], + ) + results = await pipeline._scan_rag(config) + + # Only the owasp chunk should pass — empty reg_code chunks are filtered out + assert len(results) == 1 + assert results[0].regulation_code == "owasp_asvs" + @pytest.mark.asyncio async def test_scan_rag_no_filter_returns_all(self): """Verify _scan_rag returns all chunks when no regulation_filter."""