Files
breakpilot-compliance/backend-compliance/tests/test_extraction_routes.py
Benjamin Admin 3ed8300daf
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s
feat(extraction): POST /compliance/extract-requirements-from-rag
Sucht alle RAG-Kollektionen nach Prüfaspekten und legt automatisch
Anforderungen in der DB an. Kernfeatures:

- Durchsucht alle 6 RAG-Kollektionen parallel (bp_compliance_ce,
  bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz,
  bp_dsfa_corpus, bp_legal_templates)
- Erkennt BSI Prüfaspekte (O.Purp_6) im Artikel-Feld und per Regex
- Dedupliziert nach (regulation_code, article) — safe to call many times
- Auto-erstellt Regulations-Stubs für unbekannte regulation_codes
- dry_run=true zeigt was erstellt würde ohne DB-Schreibzugriff
- Optionale Filter: collections, regulation_codes, search_queries
- 18 Tests (alle bestanden)
- Frontend: "Aus RAG extrahieren" Button auf /sdk/requirements

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-05 15:11:10 +01:00

417 lines
17 KiB
Python

"""Tests for RAG-based Requirement Extraction endpoint.
POST /compliance/extract-requirements-from-rag
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from compliance.api.extraction_routes import router as extraction_router
from classroom_engine.database import get_db
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# Minimal FastAPI app that mounts only the extraction router under test.
app = FastAPI()
app.include_router(extraction_router)

# Single shared stand-in for the DB session handed to every request.
mock_db = MagicMock()


def override_get_db():
    """Yield the shared ``mock_db`` instead of a real database session.

    Written as a generator so it has the same shape as a yield-style
    FastAPI dependency.
    """
    yield mock_db


# Route every ``get_db`` dependency lookup to the mock session.
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)

# Fixed identifiers used by the mock regulation / requirement builders below.
REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"
# ---------------------------------------------------------------------------
# RAG result helpers
# ---------------------------------------------------------------------------
def make_rag_result(overrides=None):
    """Build a mock RAG search hit for the BSI check aspect ``O.Purp_6``.

    Args:
        overrides: Optional mapping of attribute name to value. Each entry
            is set on the mock after the defaults, so it replaces a default
            or adds a new attribute. ``None`` or an empty mapping keeps the
            defaults untouched.

    Returns:
        A ``MagicMock`` exposing ``text``, ``regulation_code``,
        ``regulation_name``, ``regulation_short``, ``category``, ``article``,
        ``paragraph``, ``source_url`` and ``score``.
    """
    r = MagicMock()
    r.text = "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen."
    r.regulation_code = "BSI-TR-03161-1"
    r.regulation_name = "BSI Technische Richtlinie 03161 Teil 1"
    r.regulation_short = "BSI-TR-03161-1"
    r.category = "bsi"
    r.article = "O.Purp_6"
    r.paragraph = ""
    r.source_url = "https://bsi.bund.de/tr03161"
    r.score = 0.92
    if overrides:
        for k, v in overrides.items():
            setattr(r, k, v)
    return r
def make_regulation(overrides=None):
    """Build a mock regulation object for ``BSI-TR-03161-1``.

    Args:
        overrides: Optional mapping of attribute name to value, applied on
            top of the defaults. ``None`` or an empty mapping keeps the
            defaults untouched.

    Returns:
        A ``MagicMock`` with ``id`` (``REG_ID``), ``code`` and ``name`` set.
    """
    reg = MagicMock()
    reg.id = REG_ID
    reg.code = "BSI-TR-03161-1"
    # Assigned after construction: MagicMock(name=...) would name the mock
    # itself rather than create a ``name`` attribute.
    reg.name = "BSI-TR-03161-1"
    if overrides:
        for k, v in overrides.items():
            setattr(reg, k, v)
    return reg
def make_requirement(overrides=None):
    """Build a mock requirement object for article ``O.Purp_6``.

    Args:
        overrides: Optional mapping of attribute name to value, applied on
            top of the defaults. ``None`` or an empty mapping keeps the
            defaults untouched.

    Returns:
        A ``MagicMock`` with ``id`` (``REQ_ID``), ``regulation_id``
        (``REG_ID``), ``article``, ``title`` and a nested ``regulation``
        mock whose ``code`` is ``BSI-TR-03161-1``.
    """
    req = MagicMock()
    req.id = REQ_ID
    req.regulation_id = REG_ID
    req.article = "O.Purp_6"
    req.title = "Zweckbenennung"
    # Nested mock standing in for the related regulation object.
    req.regulation = MagicMock()
    req.regulation.code = "BSI-TR-03161-1"
    if overrides:
        for k, v in overrides.items():
            setattr(req, k, v)
    return req
# ---------------------------------------------------------------------------
# Helper: RAG mock
# ---------------------------------------------------------------------------
def patch_rag(results_per_call=None):
    """Return a patcher that makes RAG search return the given results list.

    Args:
        results_per_call: List returned by every awaited ``search`` call,
            regardless of arguments. ``None`` means a single default hit
            from ``make_rag_result()``; an explicit empty list is honoured
            and yields no hits.

    Returns:
        An un-started ``unittest.mock.patch`` object targeting
        ``compliance.api.extraction_routes.get_rag_client``. Use it as a
        context manager; while active, ``get_rag_client()`` returns a mock
        client whose ``search`` is an ``AsyncMock``.
    """
    results = [make_rag_result()] if results_per_call is None else results_per_call
    mock_client = MagicMock()
    # AsyncMock is awaitable like a hand-written coroutine function, and it
    # also records awaits so tests can assert on how search was called.
    mock_client.search = AsyncMock(return_value=results)
    return patch(
        "compliance.api.extraction_routes.get_rag_client",
        return_value=mock_client,
    )
# ---------------------------------------------------------------------------
# Basic endpoint tests
# ---------------------------------------------------------------------------
class TestExtractRequirementsBasic:
    """Basic extraction endpoint tests.

    Each test patches the RAG client (and, where the DB layer is reached,
    both repository classes) inside ``compliance.api.extraction_routes`` and
    then POSTs to ``/compliance/extract-requirements-from-rag``.
    """

    def test_empty_rag_results(self):
        """When RAG returns nothing, response should report 0 created."""
        with patch_rag(results_per_call=[]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 0
        assert data["skipped_duplicates"] == 0
        assert data["dry_run"] is False

    def test_dry_run_does_not_write_db(self):
        """dry_run=true should not call RequirementRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["dry_run"] is True
        assert data["created"] == 1  # would-create count
        # DB create must NOT be called
        MockReqRepo.return_value.create.assert_not_called()

    def test_creates_requirement_new_regulation(self):
        """New regulation + new requirement should be created in DB."""
        rag_result = make_rag_result()
        new_reg = make_regulation()
        new_req = make_requirement()
        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            # Regulation doesn't exist yet → auto-create
            MockRegRepo.return_value.get_by_code.return_value = None
            MockRegRepo.return_value.create.return_value = new_reg
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = new_req
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 1
        assert data["skipped_duplicates"] == 0
        MockRegRepo.return_value.create.assert_called_once()
        MockReqRepo.return_value.create.assert_called_once()

    def test_skips_duplicate_requirement(self):
        """If article already exists for the regulation, skip it."""
        rag_result = make_rag_result()
        existing_req = make_requirement({"article": "O.Purp_6"})
        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = [existing_req]
            MockReqRepo.return_value.create.return_value = MagicMock()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["skipped_duplicates"] == 1
        assert data["created"] == 0
        MockReqRepo.return_value.create.assert_not_called()

    def test_result_items_contain_expected_fields(self):
        """Requirements list should contain correct fields."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        # Check the status first so an error response fails here with a clear
        # message instead of a KeyError on the payload below.
        assert response.status_code == 200
        data = response.json()
        assert len(data["requirements"]) == 1
        req = data["requirements"][0]
        assert req["regulation_code"] == "BSI-TR-03161-1"
        assert req["article"] == "O.Purp_6"
        assert req["action"] == "created"
        assert "title" in req
        assert "score" in req
# ---------------------------------------------------------------------------
# Collection and query filtering
# ---------------------------------------------------------------------------
class TestExtractionFilters:
    """Tests for collection/regulation filters.

    Covers the optional request fields ``collections``, ``regulation_codes``
    and ``search_queries`` and how they are echoed back in the response.
    """

    def test_custom_collections_passed(self):
        """Should only search specified collections."""
        # A single context manager is enough: the value yielded by the
        # patcher is the replacement mock, not another patcher to enter.
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"collections": ["bp_compliance_ce"], "dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["collections_searched"] == ["bp_compliance_ce"]

    def test_regulation_code_filter(self):
        """Results from other regulation_codes should be excluded."""
        bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"})
        gdpr_result = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.",
        })
        with patch_rag([bsi_result, gdpr_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"regulation_codes": ["BSI-TR-03161-1"]},
            )

        assert response.status_code == 200
        data = response.json()
        # Only BSI result should be in output. Comparing the set of codes
        # also fails on an empty list, which a bare loop would let pass.
        codes = {req["regulation_code"] for req in data["requirements"]}
        assert codes == {"BSI-TR-03161-1"}

    def test_custom_queries_passed(self):
        """custom search_queries should be used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"search_queries": ["custom query"], "dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert "custom query" in data["queries_used"]

    def test_default_queries_used_when_none(self):
        """When no queries given, DEFAULT_QUERIES are used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert len(data["queries_used"]) > 0
# ---------------------------------------------------------------------------
# Deduplication + article extraction
# ---------------------------------------------------------------------------
class TestArticleExtraction:
    """Tests for article normalization from RAG chunks.

    Exercises how the endpoint derives the article identifier from a hit
    (the ``article`` field or a pattern in ``text``) and how repeated
    identifiers within one run are handled.
    """

    def test_result_without_article_field_uses_text_pattern(self):
        """If article field is empty, extract BSI pattern from text."""
        r = make_rag_result({
            "article": "",
            "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein.",
        })
        with patch_rag([r]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        # Check the status first so an error response is reported as such.
        assert response.status_code == 200
        data = response.json()
        created_reqs = [x for x in data["requirements"] if x["action"] == "created"]
        assert len(created_reqs) == 1
        assert created_reqs[0]["article"] == "O.Auth_2"

    def test_result_with_no_article_at_all_is_skipped(self):
        """Results without any article identifier are skipped."""
        r = make_rag_result({
            "article": "",
            "text": "General text without any structured identifier in the document.",
        })
        with patch_rag([r]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["skipped_no_article"] >= 1
        assert data["created"] == 0

    def test_intra_batch_deduplication(self):
        """Two results with same regulation+article should only create one requirement."""
        r1 = make_rag_result({"article": "O.Purp_6"})
        r2 = make_rag_result({
            "article": "O.Purp_6",
            "text": "O.Purp_6 Additional text about the same requirement.",
        })
        with patch_rag([r1, r2]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        # Only one should be created despite two results with same article
        assert data["created"] == 1
# ---------------------------------------------------------------------------
# Regulation auto-creation
# ---------------------------------------------------------------------------
class TestRegulationAutoCreate:
    """Tests for auto-creation of regulation stubs.

    ``RegulationRepository.get_by_code`` is mocked to return either an
    existing regulation or ``None``; the tests then check whether
    ``RegulationRepository.create`` was invoked.
    """

    def test_existing_regulation_not_recreated(self):
        """If regulation already exists, create should NOT be called."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        MockRegRepo.return_value.create.assert_not_called()

    def test_unknown_regulation_is_auto_created(self):
        """Unknown regulation_code triggers RegulationRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            # None from get_by_code means the regulation is not known yet.
            MockRegRepo.return_value.get_by_code.return_value = None
            MockRegRepo.return_value.create.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        assert data["created"] == 1
        MockRegRepo.return_value.create.assert_called_once()

    def test_multiple_regulations_in_one_run(self):
        """Two results from different regulations should each get their regulation processed."""
        r1 = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"})
        r2 = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.",
        })
        with patch_rag([r1, r2]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            # The same mock regulation is returned for every code looked up.
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        # Both requirements should be created
        assert data["created"] == 2
# ---------------------------------------------------------------------------
# Response structure
# ---------------------------------------------------------------------------
class TestResponseStructure:
    """Verify the full response structure."""

    def test_all_fields_present(self):
        """Every expected top-level key is present, even with no RAG hits."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        assert response.status_code == 200
        data = response.json()
        for field in [
            "created", "skipped_duplicates", "skipped_no_article",
            "failed", "collections_searched", "queries_used",
            "requirements", "dry_run", "message",
        ]:
            assert field in data, f"Missing field: {field}"

    def test_message_contains_summary(self):
        """The message carries the ``Erstellt:`` summary marker."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})

        # Check the status first so an error response is reported as such.
        assert response.status_code == 200
        data = response.json()
        assert "Erstellt:" in data["message"]

    def test_dry_run_message_prefix(self):
        """A dry run marks its message with ``[DRY RUN]``."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )

        assert response.status_code == 200
        data = response.json()
        assert "[DRY RUN]" in data["message"]