All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s
Sucht alle RAG-Kollektionen nach Prüfaspekten und legt automatisch Anforderungen in der DB an. Kernfeatures: - Durchsucht alle 6 RAG-Kollektionen parallel (bp_compliance_ce, bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz, bp_dsfa_corpus, bp_legal_templates) - Erkennt BSI-Prüfaspekte (z. B. O.Purp_6) im Artikel-Feld und per Regex - Dedupliziert nach (regulation_code, article) — kann beliebig oft aufgerufen werden - Erstellt automatisch Regulations-Stubs für unbekannte regulation_codes - dry_run=true zeigt, was erstellt würde, ohne in die DB zu schreiben - Optionale Filter: collections, regulation_codes, search_queries - 18 Tests (alle bestanden) - Frontend: "Aus RAG extrahieren"-Button auf /sdk/requirements Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
417 lines
17 KiB
Python
417 lines
17 KiB
Python
"""Tests for RAG-based Requirement Extraction endpoint.
|
|
|
|
POST /compliance/extract-requirements-from-rag
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from compliance.api.extraction_routes import router as extraction_router
|
|
from classroom_engine.database import get_db
|
|
|
|
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------

# Minimal application that mounts only the router under test.
app = FastAPI()
app.include_router(extraction_router)

# Single MagicMock handed to every request in place of a real DB session.
mock_db = MagicMock()


def override_get_db():
    """Dependency override: yield the shared ``mock_db`` instead of a real session."""
    yield mock_db


# Route the ``get_db`` dependency to the override above.
app.dependency_overrides.update({get_db: override_get_db})
client = TestClient(app)

# Fixed identifiers used by the model factories.
REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"
|
|
|
|
# ---------------------------------------------------------------------------
# RAG result helpers
# ---------------------------------------------------------------------------


def make_rag_result(overrides=None):
    """Return a MagicMock standing in for one RAG search hit.

    By default the hit is a BSI-TR-03161-1 chunk for article ``O.Purp_6``
    with score 0.92. Every ``name -> value`` pair in *overrides* replaces
    (or adds) the attribute of that name.
    """
    attrs = {
        "text": "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen.",
        "regulation_code": "BSI-TR-03161-1",
        "regulation_name": "BSI Technische Richtlinie 03161 Teil 1",
        "regulation_short": "BSI-TR-03161-1",
        "category": "bsi",
        "article": "O.Purp_6",
        "paragraph": "",
        "source_url": "https://bsi.bund.de/tr03161",
        "score": 0.92,
    }
    if overrides:
        attrs.update(overrides.items())

    hit = MagicMock()
    for attr_name, attr_value in attrs.items():
        setattr(hit, attr_name, attr_value)
    return hit
|
|
|
|
|
|
def make_regulation(overrides=None):
    """Return a MagicMock regulation row.

    Defaults: ``id`` is ``REG_ID``; ``code`` and ``name`` are both
    ``"BSI-TR-03161-1"``. Entries in *overrides* replace these by name.
    """
    values = {"id": REG_ID, "code": "BSI-TR-03161-1", "name": "BSI-TR-03161-1"}
    if overrides:
        values.update(overrides.items())

    regulation = MagicMock()
    for attr_name, attr_value in values.items():
        setattr(regulation, attr_name, attr_value)
    return regulation
|
|
|
|
|
|
def make_requirement(overrides=None):
    """Return a MagicMock requirement row for article ``O.Purp_6``.

    Defaults: ``id`` is ``REQ_ID``, ``regulation_id`` is ``REG_ID``, the
    title is ``"Zweckbenennung"``, and a nested ``regulation`` mock exposes
    ``code == "BSI-TR-03161-1"``. Entries in *overrides* replace these by name.
    """
    values = {
        "id": REQ_ID,
        "regulation_id": REG_ID,
        "article": "O.Purp_6",
        "title": "Zweckbenennung",
        "regulation": MagicMock(code="BSI-TR-03161-1"),
    }
    if overrides:
        values.update(overrides.items())

    requirement = MagicMock()
    for attr_name, attr_value in values.items():
        setattr(requirement, attr_name, attr_value)
    return requirement
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Helper: RAG mock
# ---------------------------------------------------------------------------


def patch_rag(results_per_call=None):
    """Return a patcher that replaces ``get_rag_client`` with a fake RAG client.

    Every awaited ``search(...)`` call on the fake client returns the same
    *results_per_call* list (default: a single ``make_rag_result()`` hit),
    regardless of the query or collection requested.

    ``search`` is an ``AsyncMock``, so the arguments of every call are
    recorded. Entering the patcher (``with patch_rag(...) as get_client:``)
    yields the patched ``get_rag_client`` mock. Inspect
    ``get_client.return_value.search.await_args_list`` to see what the
    endpoint actually searched.
    """
    results = [make_rag_result()] if results_per_call is None else results_per_call

    fake_client = MagicMock()
    # AsyncMock awaits to ``results`` exactly like the former hand-written
    # ``async def`` stub, but also records the call arguments.
    fake_client.search = AsyncMock(return_value=results)
    return patch(
        "compliance.api.extraction_routes.get_rag_client",
        return_value=fake_client,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Basic endpoint tests
# ---------------------------------------------------------------------------


class TestExtractRequirementsBasic:
    """Basic extraction endpoint tests."""

    def test_empty_rag_results(self):
        """When RAG returns nothing, response should report 0 created."""
        with patch_rag(results_per_call=[]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})
            assert response.status_code == 200
            data = response.json()
            assert data["created"] == 0
            assert data["skipped_duplicates"] == 0
            assert data["dry_run"] is False

    def test_dry_run_does_not_write_db(self):
        """dry_run=true should not call RequirementRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []

            response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})

            assert response.status_code == 200
            data = response.json()
            assert data["dry_run"] is True
            assert data["created"] == 1  # would-create count
            # DB create must NOT be called
            MockReqRepo.return_value.create.assert_not_called()

    def test_creates_requirement_new_regulation(self):
        """New regulation + new requirement should be created in DB."""
        rag_result = make_rag_result()
        new_reg = make_regulation()
        new_req = make_requirement()

        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            # Regulation doesn't exist yet → auto-create
            MockRegRepo.return_value.get_by_code.return_value = None
            MockRegRepo.return_value.create.return_value = new_reg
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = new_req

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            data = response.json()
            assert data["created"] == 1
            assert data["skipped_duplicates"] == 0
            MockRegRepo.return_value.create.assert_called_once()
            MockReqRepo.return_value.create.assert_called_once()

    def test_skips_duplicate_requirement(self):
        """If article already exists for the regulation, skip it."""
        rag_result = make_rag_result()
        existing_req = make_requirement({"article": "O.Purp_6"})

        with patch_rag([rag_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = [existing_req]
            MockReqRepo.return_value.create.return_value = MagicMock()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            data = response.json()
            assert data["skipped_duplicates"] == 1
            assert data["created"] == 0
            MockReqRepo.return_value.create.assert_not_called()

    def test_result_items_contain_expected_fields(self):
        """Requirements list should contain correct fields."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            # Check the status first so a server error fails with a clear
            # message instead of a KeyError on the error payload.
            assert response.status_code == 200
            data = response.json()
            assert len(data["requirements"]) == 1
            req = data["requirements"][0]
            assert req["regulation_code"] == "BSI-TR-03161-1"
            assert req["article"] == "O.Purp_6"
            assert req["action"] == "created"
            assert "title" in req
            assert "score" in req
|
|
|
|
|
# ---------------------------------------------------------------------------
# Collection and query filtering
# ---------------------------------------------------------------------------


class TestExtractionFilters:
    """Tests for collection/regulation filters."""

    def test_custom_collections_passed(self):
        """Should only search specified collections."""
        # A single ``with`` is enough. The former inner ``with mock_patch:``
        # re-entered the yielded MagicMock as a context manager, which did nothing.
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"collections": ["bp_compliance_ce"], "dry_run": True},
            )
            assert response.status_code == 200
            data = response.json()
            assert data["collections_searched"] == ["bp_compliance_ce"]

    def test_regulation_code_filter(self):
        """Results from other regulation_codes should be excluded."""
        bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"})
        gdpr_result = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.",
        })

        with patch_rag([bsi_result, gdpr_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"regulation_codes": ["BSI-TR-03161-1"]},
            )

            assert response.status_code == 200
            data = response.json()
            codes = [req["regulation_code"] for req in data["requirements"]]
            # Without this guard, an empty list would make the check below pass vacuously.
            assert codes, "expected the BSI requirement in the output"
            # Only BSI result should be in output
            assert all(code == "BSI-TR-03161-1" for code in codes)
            # The GDPR hit must not be created as a side effect either.
            assert data["created"] == 1

    def test_custom_queries_passed(self):
        """custom search_queries should be used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"search_queries": ["custom query"], "dry_run": True},
            )
            assert response.status_code == 200
            data = response.json()
            assert "custom query" in data["queries_used"]

    def test_default_queries_used_when_none(self):
        """When no queries given, DEFAULT_QUERIES are used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )
            assert response.status_code == 200
            data = response.json()
            assert len(data["queries_used"]) > 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Deduplication + article extraction
# ---------------------------------------------------------------------------


class TestArticleExtraction:
    """Tests for article normalization from RAG chunks."""

    def test_result_without_article_field_uses_text_pattern(self):
        """If article field is empty, extract BSI pattern from text."""
        r = make_rag_result({"article": "", "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein."})

        with patch_rag([r]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            data = response.json()
            created_reqs = [x for x in data["requirements"] if x["action"] == "created"]
            assert len(created_reqs) == 1
            assert created_reqs[0]["article"] == "O.Auth_2"

    def test_result_with_no_article_at_all_is_skipped(self):
        """Results without any article identifier are skipped."""
        r = make_rag_result({
            "article": "",
            "text": "General text without any structured identifier in the document.",
        })

        # Patch both repositories, as the other tests do, so the outcome does
        # not depend on real repositories running against the shared mock_db
        # and so we can assert that nothing is created.
        with patch_rag([r]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            data = response.json()
            assert data["skipped_no_article"] >= 1
            assert data["created"] == 0
            MockReqRepo.return_value.create.assert_not_called()

    def test_intra_batch_deduplication(self):
        """Two results with same regulation+article should only create one requirement."""
        r1 = make_rag_result({"article": "O.Purp_6"})
        r2 = make_rag_result({"article": "O.Purp_6", "text": "O.Purp_6 Additional text about the same requirement."})

        with patch_rag([r1, r2]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            data = response.json()
            # Only one should be created despite two results with same article
            assert data["created"] == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Regulation auto-creation
# ---------------------------------------------------------------------------


class TestRegulationAutoCreate:
    """Regulation stubs are created only when the code lookup finds nothing."""

    def test_existing_regulation_not_recreated(self):
        """A regulation returned by get_by_code must not be created again."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as reg_repo_cls, \
                patch("compliance.api.extraction_routes.RequirementRepository") as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            reg_repo.create.assert_not_called()

    def test_unknown_regulation_is_auto_created(self):
        """get_by_code returning None must trigger exactly one regulation create."""
        with patch_rag([make_rag_result()]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as reg_repo_cls, \
                patch("compliance.api.extraction_routes.RequirementRepository") as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = None
            reg_repo.create.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            payload = response.json()
            assert payload["created"] == 1
            reg_repo.create.assert_called_once()

    def test_multiple_regulations_in_one_run(self):
        """Hits for two different regulation codes each yield a created requirement."""
        bsi_hit = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"})
        gdpr_hit = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.",
        })

        with patch_rag([bsi_hit, gdpr_hit]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as reg_repo_cls, \
                patch("compliance.api.extraction_routes.RequirementRepository") as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()

            response = client.post("/compliance/extract-requirements-from-rag", json={})

            assert response.status_code == 200
            payload = response.json()
            # One requirement per regulation hit.
            assert payload["created"] == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Response structure
# ---------------------------------------------------------------------------


class TestResponseStructure:
    """Verify the full response structure."""

    def test_all_fields_present(self):
        """Even an empty RAG run returns every summary field."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})
            assert response.status_code == 200
            data = response.json()
            for field in [
                "created", "skipped_duplicates", "skipped_no_article",
                "failed", "collections_searched", "queries_used",
                "requirements", "dry_run", "message",
            ]:
                assert field in data, f"Missing field: {field}"

    def test_message_contains_summary(self):
        """The message text contains the "Erstellt:" summary label."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={})
            # Status check first, so an error payload fails clearly instead of via KeyError.
            assert response.status_code == 200
            data = response.json()
            assert "Erstellt:" in data["message"]

    def test_dry_run_message_prefix(self):
        """A dry run is flagged with "[DRY RUN]" in the message."""
        with patch_rag([]):
            response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})
            assert response.status_code == 200
            data = response.json()
            assert "[DRY RUN]" in data["message"]
|