"""Tests for RAG-based Requirement Extraction endpoint.

POST /compliance/extract-requirements-from-rag
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from fastapi import FastAPI
from fastapi.testclient import TestClient

from compliance.api.extraction_routes import router as extraction_router
from classroom_engine.database import get_db


# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------

# Minimal app that mounts only the router under test.
app = FastAPI()
app.include_router(extraction_router)

# One session stand-in shared by every request in this module; calls on it
# accumulate across tests, so tests patch the repository classes and assert
# on those mocks instead of on mock_db.
mock_db = MagicMock()


def override_get_db():
    """Replace the ``get_db`` dependency by yielding the shared ``mock_db``."""
    yield mock_db


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)

REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"


# ---------------------------------------------------------------------------
# RAG result helpers
# ---------------------------------------------------------------------------

def make_rag_result(overrides=None):
    """Build a mock RAG search hit for BSI-TR-03161-1, article O.Purp_6.

    Args:
        overrides: Optional mapping of attribute name -> value that is set on
            the mock after the defaults (e.g. ``{"article": ""}``).

    Returns:
        A ``MagicMock`` carrying text, regulation metadata, article, source URL
        and score attributes.
    """
    r = MagicMock()
    r.text = "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen."
    r.regulation_code = "BSI-TR-03161-1"
    r.regulation_name = "BSI Technische Richtlinie 03161 Teil 1"
    r.regulation_short = "BSI-TR-03161-1"
    r.category = "bsi"
    r.article = "O.Purp_6"
    r.paragraph = ""
    r.source_url = "https://bsi.bund.de/tr03161"
    r.score = 0.92
    if overrides:
        for k, v in overrides.items():
            setattr(r, k, v)
    return r


def make_regulation(overrides=None):
    """Build a mock regulation row (id ``REG_ID``, code BSI-TR-03161-1).

    Args:
        overrides: Optional mapping of attribute name -> value applied last.

    Returns:
        A ``MagicMock`` with ``id``, ``code`` and ``name`` attributes.
    """
    reg = MagicMock()
    reg.id = REG_ID
    reg.code = "BSI-TR-03161-1"
    reg.name = "BSI-TR-03161-1"
    if overrides:
        for k, v in overrides.items():
            setattr(reg, k, v)
    return reg


def make_requirement(overrides=None):
    """Build a mock requirement row (id ``REQ_ID``, article O.Purp_6).

    Args:
        overrides: Optional mapping of attribute name -> value applied last.

    Returns:
        A ``MagicMock`` with ``id``, ``regulation_id``, ``article``, ``title``
        and a nested ``regulation.code``.
    """
    req = MagicMock()
    req.id = REQ_ID
    req.regulation_id = REG_ID
    req.article = "O.Purp_6"
    req.title = "Zweckbenennung"
    req.regulation = MagicMock()
    req.regulation.code = "BSI-TR-03161-1"
    if overrides:
        for k, v in overrides.items():
            setattr(req, k, v)
    return req


# ---------------------------------------------------------------------------
# Helper: RAG mock
# ---------------------------------------------------------------------------

def patch_rag(results_per_call=None):
    """Return a patcher that makes RAG search return the given results list.

    The same list is returned for *every* ``search`` call, whatever query or
    collection arguments are passed.

    Args:
        results_per_call: Results returned by each search call. ``None`` means
            a single default ``make_rag_result()``; pass ``[]`` for no hits.

    Returns:
        An unstarted ``unittest.mock`` patcher for
        ``compliance.api.extraction_routes.get_rag_client``; use it in a
        ``with`` statement.
    """
    results = [make_rag_result()] if results_per_call is None else results_per_call

    mock_client = MagicMock()
    # AsyncMock makes ``await client.search(...)`` resolve to ``results``.
    mock_client.search = AsyncMock(return_value=results)

    return patch(
        "compliance.api.extraction_routes.get_rag_client",
        return_value=mock_client,
    )


# ---------------------------------------------------------------------------
# Basic endpoint tests
# ---------------------------------------------------------------------------

class TestExtractRequirementsBasic:
    """Basic extraction endpoint tests."""

    def test_empty_rag_results(self):
        """When RAG returns nothing, response should report 0 created."""
        with patch_rag(results_per_call=[]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 0
        assert data["skipped_duplicates"] == 0
        assert data["dry_run"] is False

    def test_dry_run_does_not_write_db(self):
        """dry_run=true should not call RequirementRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []

            response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})

        assert response.status_code == 200
        data = response.json()
        assert data["dry_run"] is True
        assert data["created"] == 1  # would-create count
        # DB create must NOT be called
        MockReqRepo.return_value.create.assert_not_called()

    def test_creates_requirement_new_regulation(self):
        """New regulation + new requirement should be created in DB."""
        rag_result = make_rag_result()
        new_reg = make_regulation()
        new_req = make_requirement()

        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            # Regulation doesn't exist yet -> auto-create
            MockRegRepo.return_value.get_by_code.return_value = None
            MockRegRepo.return_value.create.return_value = new_reg
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = new_req

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 1
        assert data["skipped_duplicates"] == 0
        MockRegRepo.return_value.create.assert_called_once()
        MockReqRepo.return_value.create.assert_called_once()

    def test_skips_duplicate_requirement(self):
        """If article already exists for the regulation, skip it."""
        rag_result = make_rag_result()
        existing_req = make_requirement({"article": "O.Purp_6"})

        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = [existing_req]
            MockReqRepo.return_value.create.return_value = MagicMock()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["skipped_duplicates"] == 1
        assert data["created"] == 0
        MockReqRepo.return_value.create.assert_not_called()

    def test_result_items_contain_expected_fields(self):
        """Requirements list should contain correct fields."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        data = response.json()
        assert len(data["requirements"]) == 1
        req = data["requirements"][0]
        assert req["regulation_code"] == "BSI-TR-03161-1"
        assert req["article"] == "O.Purp_6"
        assert req["action"] == "created"
        assert "title" in req
        assert "score" in req


# ---------------------------------------------------------------------------
# Collection and query filtering
# ---------------------------------------------------------------------------

class TestExtractionFilters:
    """Tests for collection/regulation filters."""

    def test_custom_collections_passed(self):
        """Should only search specified collections."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"collections": ["bp_compliance_ce"], "dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["collections_searched"] == ["bp_compliance_ce"]

    def test_regulation_code_filter(self):
        """Results from other regulation_codes should be excluded."""
        bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"})
        gdpr_result = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.",
        })

        with patch_rag([bsi_result, gdpr_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"regulation_codes": ["BSI-TR-03161-1"]},
            )

        assert response.status_code == 200
        data = response.json()
        # Without this guard the loop below passes vacuously on an empty list.
        assert data["requirements"], "expected the BSI result to be extracted"
        # Only the BSI hit survives the filter; the GDPR hit must not be created.
        assert data["created"] == 1
        for req in data["requirements"]:
            assert req["regulation_code"] == "BSI-TR-03161-1"

    def test_custom_queries_passed(self):
        """custom search_queries should be used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"search_queries": ["custom query"], "dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert "custom query" in data["queries_used"]

    def test_default_queries_used_when_none(self):
        """When no queries given, DEFAULT_QUERIES are used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert len(data["queries_used"]) > 0


# ---------------------------------------------------------------------------
# Deduplication + article extraction
# ---------------------------------------------------------------------------


class TestArticleExtraction:
    """Tests for article normalization from RAG chunks."""

    def test_result_without_article_field_uses_text_pattern(self):
        """If article field is empty, extract BSI pattern from text."""
        r = make_rag_result({"article": "", "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein."})

        with patch_rag([r]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        created_reqs = [x for x in data["requirements"] if x["action"] == "created"]
        assert len(created_reqs) == 1
        assert created_reqs[0]["article"] == "O.Auth_2"

    def test_result_with_no_article_at_all_is_skipped(self):
        """Results without any article identifier are skipped."""
        r = make_rag_result({
            "article": "",
            "text": "General text without any structured identifier in the document.",
        })

        # Patch the repositories so the test never falls through to the shared
        # mock_db, and so we can assert that nothing was persisted.
        with patch_rag([r]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["skipped_no_article"] >= 1
        assert data["created"] == 0
        MockReqRepo.return_value.create.assert_not_called()

    def test_intra_batch_deduplication(self):
        """Two results with same regulation+article should only create one requirement."""
        r1 = make_rag_result({"article": "O.Purp_6"})
        r2 = make_rag_result({"article": "O.Purp_6", "text": "O.Purp_6 Additional text about the same requirement."})

        with patch_rag([r1, r2]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        # Only one should be created despite two results with same article
        assert data["created"] == 1
        # The reported count must match what was actually persisted.
        MockReqRepo.return_value.create.assert_called_once()


# ---------------------------------------------------------------------------
# Regulation auto-creation
# ---------------------------------------------------------------------------

class TestRegulationAutoCreate:
    """Tests for auto-creation of regulation stubs."""

    def test_existing_regulation_not_recreated(self):
        """If regulation already exists, create should NOT be called."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        MockRegRepo.return_value.create.assert_not_called()

    def test_unknown_regulation_is_auto_created(self):
        """Unknown regulation_code triggers RegulationRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = None
            MockRegRepo.return_value.create.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 1
        MockRegRepo.return_value.create.assert_called_once()

    def test_multiple_regulations_in_one_run(self):
        """Two results from different regulations should each get their regulation processed."""
        r1 = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"})
        r2 = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.",
        })

        with patch_rag([r1, r2]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        # Both requirements should be created
        assert data["created"] == 2
        assert MockReqRepo.return_value.create.call_count == 2


# ---------------------------------------------------------------------------
# Response structure
# ---------------------------------------------------------------------------

class TestResponseStructure:
    """Verify the full response structure."""

    def test_all_fields_present(self):
        """Every documented response key is present even with no RAG hits."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        for field in [
            "created", "skipped_duplicates", "skipped_no_article", "failed",
            "collections_searched", "queries_used", "requirements", "dry_run", "message",
        ]:
            assert field in data, f"Missing field: {field}"

    def test_message_contains_summary(self):
        """The human-readable message includes the created-count summary."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert "Erstellt:" in data["message"]

    def test_dry_run_message_prefix(self):
        """A dry run is flagged in the message text."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})

        assert response.status_code == 200
        data = response.json()
        assert "[DRY RUN]" in data["message"]