feat(extraction): POST /compliance/extract-requirements-from-rag
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s

Sucht alle RAG-Kollektionen nach Prüfaspekten und legt automatisch
Anforderungen in der DB an. Kernfeatures:

- Durchsucht alle 6 RAG-Kollektionen parallel (bp_compliance_ce,
  bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz,
  bp_dsfa_corpus, bp_legal_templates)
- Erkennt BSI Prüfaspekte (O.Purp_6) im Artikel-Feld und per Regex
- Dedupliziert nach (regulation_code, article) — safe to call many times
- Auto-erstellt Regulations-Stubs für unbekannte regulation_codes
- dry_run=true zeigt was erstellt würde ohne DB-Schreibzugriff
- Optionale Filter: collections, regulation_codes, search_queries
- 18 Tests (alle bestanden)
- Frontend: "Aus RAG extrahieren" Button auf /sdk/requirements

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-03-05 15:11:10 +01:00
parent f3ccfe5dcd
commit 3ed8300daf
4 changed files with 866 additions and 9 deletions

View File

@@ -404,6 +404,50 @@ export default function RequirementsPage() {
const [error, setError] = useState<string | null>(null)
const [showAddForm, setShowAddForm] = useState(false)
const [expandedId, setExpandedId] = useState<string | null>(null)
const [ragExtracting, setRagExtracting] = useState(false)
const [ragResult, setRagResult] = useState<{ created: number; skipped_duplicates: number; message: string } | null>(null)
/**
 * Triggers the backend RAG extraction and then refreshes the requirements list.
 *
 * The outcome of the extraction call is shown via `ragResult`. The follow-up
 * list reload is best-effort and isolated in its own try/catch: if it fails,
 * the (already successful) extraction result stays visible instead of being
 * replaced by a misleading "not reachable" message.
 */
const extractFromRAG = async () => {
  setRagExtracting(true)
  setRagResult(null)
  try {
    const res = await fetch('/api/sdk/v1/compliance/extract-requirements-from-rag', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ max_per_query: 20 }),
    })
    if (!res.ok) {
      setRagResult({ created: 0, skipped_duplicates: 0, message: 'RAG-Extraktion fehlgeschlagen' })
      return
    }
    const data = await res.json()
    setRagResult({ created: data.created, skipped_duplicates: data.skipped_duplicates, message: data.message })

    // Reload requirements list. Errors here must not overwrite the result above.
    try {
      const listRes = await fetch('/api/sdk/v1/compliance/requirements')
      if (listRes.ok) {
        const listData = await listRes.json()
        // The list endpoint may answer with `{ requirements: [...] }` or a bare array.
        const reqs = listData.requirements || listData
        if (Array.isArray(reqs) && reqs.length > 0) {
          const mapped = reqs.map((r: Record<string, unknown>) => ({
            id: (r.requirement_id || r.id) as string,
            regulation: (r.regulation_code || r.regulation || '') as string,
            article: (r.article || '') as string,
            title: (r.title || '') as string,
            description: (r.description || '') as string,
            criticality: ((r.criticality || r.priority || 'MEDIUM') as string).toUpperCase() as import('@/lib/sdk').RiskSeverity,
            applicableModules: [] as string[],
            status: 'NOT_STARTED' as import('@/lib/sdk').RequirementStatus,
            controls: [] as string[],
          }))
          dispatch({ type: 'SET_STATE', payload: { requirements: mapped } })
        }
      }
    } catch (reloadError) {
      console.warn('Reloading requirements after RAG extraction failed', reloadError)
    }
  } catch {
    // Network error or unparsable response of the extraction call itself.
    setRagResult({ created: 0, skipped_duplicates: 0, message: 'RAG-Extraktion nicht erreichbar' })
  } finally {
    setRagExtracting(false)
  }
}
// Fetch requirements from backend on mount
useEffect(() => {
@@ -626,17 +670,46 @@ export default function RequirementsPage() {
explanation={stepInfo.explanation}
tips={stepInfo.tips}
>
<button
onClick={() => setShowAddForm(true)}
className="flex items-center gap-2 px-4 py-2 bg-purple-600 text-white rounded-lg hover:bg-purple-700 transition-colors"
>
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 6v6m0 0v6m0-6h6m-6 0H6" />
</svg>
Anforderung hinzufuegen
</button>
<div className="flex items-center gap-2">
<button
onClick={extractFromRAG}
disabled={ragExtracting}
className="flex items-center gap-2 px-4 py-2 bg-indigo-600 text-white rounded-lg hover:bg-indigo-700 disabled:opacity-60 transition-colors"
>
{ragExtracting ? (
<svg className="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4" />
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8v8H4z" />
</svg>
) : (
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9.663 17h4.673M12 3v1m6.364 1.636l-.707.707M21 12h-1M4 12H3m3.343-5.657l-.707-.707m2.828 9.9a5 5 0 117.072 0l-.347.347a3.5 3.5 0 01-4.95 0l-.347-.347z" />
</svg>
)}
Aus RAG extrahieren
</button>
<button
onClick={() => setShowAddForm(true)}
className="flex items-center gap-2 px-4 py-2 bg-purple-600 text-white rounded-lg hover:bg-purple-700 transition-colors"
>
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 6v6m0 0v6m0-6h6m-6 0H6" />
</svg>
Anforderung hinzufuegen
</button>
</div>
</StepHeader>
{/* RAG Extraction Result Banner */}
{ragResult && (
<div className={`flex items-center justify-between p-3 rounded-lg border ${ragResult.created > 0 ? 'bg-green-50 border-green-200' : 'bg-blue-50 border-blue-200'}`}>
<span className="text-sm">
{ragResult.created > 0 ? '✅' : ''} {ragResult.message}
</span>
<button onClick={() => setRagResult(null)} className="text-gray-400 hover:text-gray-600 ml-4">&times;</button>
</div>
)}
{/* Add Form */}
{showAddForm && (
<AddRequirementForm

View File

@@ -25,6 +25,7 @@ from .dsfa_routes import router as dsfa_router
from .dsr_routes import router as dsr_router
from .email_template_routes import router as email_template_router
from .banner_routes import router as banner_router
from .extraction_routes import router as extraction_router
# Include sub-routers
router.include_router(audit_router)
@@ -51,6 +52,7 @@ router.include_router(dsfa_router)
router.include_router(dsr_router)
router.include_router(email_template_router)
router.include_router(banner_router)
router.include_router(extraction_router)
__all__ = [
"router",

View File

@@ -0,0 +1,366 @@
"""
FastAPI routes for RAG-based Requirement Extraction.
Endpoints:
- POST /compliance/extract-requirements-from-rag:
Searches ALL RAG collections (or a subset) for audit criteria / Prüfaspekte
and creates Requirement entries in the DB.
Design principles:
- Searches every relevant collection in parallel
- Deduplicates by (regulation_id, article) — never inserts twice
- Auto-creates Regulation stubs for unknown regulation_codes
- LLM-free (fast): titles are derived from the chunk text with regex heuristics; a `use_llm` option is not implemented in this module
- dry_run=true returns what would be created without touching the DB
"""
import logging
import re
import asyncio
from typing import Optional, List, Dict
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from classroom_engine.database import get_db
from ..db import RegulationRepository, RequirementRepository
from ..db.models import RegulationDB, RequirementDB, RegulationTypeEnum
from ..services.rag_client import get_rag_client, RAGSearchResult
logger = logging.getLogger(__name__)
router = APIRouter(tags=["extraction"])
# ---------------------------------------------------------------------------
# Collections that may contain Prüfaspekte (audit aspects) / audit criteria.
# Used as the search scope when a request does not specify `collections`.
# ---------------------------------------------------------------------------
ALL_COLLECTIONS = [
    "bp_compliance_ce",  # BSI-TR documents — primary Prüfaspekte source
    "bp_compliance_recht",  # Legal texts (GDPR, AI Act, ...)
    "bp_compliance_gesetze",  # German laws
    "bp_compliance_datenschutz",  # Data protection documents
    "bp_dsfa_corpus",  # DSFA corpus
    "bp_legal_templates",  # Legal templates
]
# Search queries (German and English) targeting audit criteria across different
# document types. Used when a request does not supply `search_queries`.
DEFAULT_QUERIES = [
    "Prüfaspekt Anforderung MUSS SOLL",
    "security requirement SHALL MUST",
    "compliance requirement audit criterion",
    "Sicherheitsanforderung technische Maßnahme",
    "data protection requirement obligation",
]
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class ExtractionRequest(BaseModel):
    """Request body of POST /compliance/extract-requirements-from-rag.

    Every field is optional; an empty JSON object runs the extraction with
    the module defaults.
    """

    collections: Optional[List[str]] = None  # None = ALL_COLLECTIONS
    search_queries: Optional[List[str]] = None  # None = DEFAULT_QUERIES
    regulation_codes: Optional[List[str]] = None  # None = all regulations
    max_per_query: int = 20  # top_k per search query (passed to the RAG client per collection)
    dry_run: bool = False  # if True: no DB writes
class ExtractedRequirement(BaseModel):
    """One processed (regulation, article) pair as reported in the response."""

    regulation_code: str
    article: str
    title: str
    requirement_text: str  # chunk text, truncated to 1000 characters by the endpoint
    source_url: str
    score: float
    # Values emitted by the endpoint: "created" | "would_create" (dry run) | "skipped_duplicate".
    # Results without an article are only counted (skipped_no_article), not listed.
    action: str
class ExtractionResponse(BaseModel):
    """Summary returned by the extraction endpoint."""

    created: int  # requirements created; in a dry run: number that would be created
    skipped_duplicates: int  # articles that already exist for their regulation
    skipped_no_article: int  # RAG results without a usable article identifier
    failed: int  # items whose regulation lookup/creation or requirement creation raised
    collections_searched: List[str]
    queries_used: List[str]
    requirements: List[ExtractedRequirement]
    dry_run: bool
    message: str  # human-readable summary (German)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
BSI_ASPECT_RE = re.compile(r"\b([A-Z]\.[A-Za-z]+_\d+)\b")
TITLE_SENTENCE_RE = re.compile(r"^([^.!?\n]{10,120})[.!?\n]")
def _derive_title(text: str, article: str) -> str:
"""Extract a short title from RAG chunk text."""
# Remove leading article reference if present
cleaned = re.sub(r"^" + re.escape(article) + r"[:\s]+", "", text.strip(), flags=re.IGNORECASE)
# Take first meaningful sentence
m = TITLE_SENTENCE_RE.match(cleaned)
if m:
return m.group(1).strip()[:200]
# Fallback: first 100 chars
return cleaned[:100].strip() or article
def _normalize_article(result: RAGSearchResult) -> Optional[str]:
    """Determine the article identifier for a RAG result.

    The result's own `article` field wins when it is non-blank. Otherwise the
    chunk text is scanned for a BSI audit-aspect identifier (e.g. "O.Purp_6").
    Returns None when neither source yields an identifier.
    """
    explicit = (result.article or "").strip()
    if explicit:
        return explicit

    # No article metadata: look for a BSI audit-aspect pattern in the chunk text.
    found = BSI_ASPECT_RE.search(result.text)
    return found.group(1) if found else None
async def _search_collection(
    collection: str,
    queries: List[str],
    max_per_query: int,
) -> List[RAGSearchResult]:
    """Run every query against a single collection and merge the hits.

    Queries are executed one after another. Hits whose first 120 text
    characters were already seen in this collection are dropped, so the same
    chunk returned by several queries appears only once.
    """
    client = get_rag_client()
    fingerprints: set[str] = set()
    merged: List[RAGSearchResult] = []
    for q in queries:
        for hit in await client.search(q, collection=collection, top_k=max_per_query):
            fingerprint = hit.text[:120]  # rough dedup key
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            merged.append(hit)
    return merged
def _get_or_create_regulation(
    db: Session,
    regulation_code: str,
    regulation_name: str,
) -> RegulationDB:
    """Return the regulation with the given code, creating a stub if missing.

    Args:
        db: Database session handed to the RegulationRepository.
        regulation_code: Code used for the lookup and for the new stub.
        regulation_name: Display name for a new stub; falls back to the code
            when empty.

    Returns:
        The existing regulation, or the object returned by the repository's
        `create` call for the new stub.
    """
    # Function-local import: the module only imports `datetime` itself.
    from datetime import timezone

    repo = RegulationRepository(db)
    reg = repo.get_by_code(regulation_code)
    if reg:
        return reg

    # Auto-create a stub so Requirements can reference it
    logger.info("Auto-creating regulation stub: %s", regulation_code)

    # Infer type from code prefix / well-known EU regulation codes
    if regulation_code.startswith("BSI"):
        reg_type = RegulationTypeEnum.BSI_STANDARD
    elif regulation_code in ("GDPR", "AI_ACT", "NIS2", "CRA"):
        reg_type = RegulationTypeEnum.EU_REGULATION
    else:
        reg_type = RegulationTypeEnum.INDUSTRY_STANDARD

    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # yields the same calendar date.
    created_on = datetime.now(timezone.utc).date()
    return repo.create(
        code=regulation_code,
        name=regulation_name or regulation_code,
        regulation_type=reg_type,
        description=f"Auto-created from RAG extraction ({created_on})",
    )
def _build_existing_articles(
    db: Session, regulation_id: str
) -> set[str]:
    """Collect the article identifiers already stored for one regulation."""
    stored = RequirementRepository(db).get_by_regulation(regulation_id)
    articles: set[str] = set()
    for requirement in stored:
        articles.add(requirement.article)
    return articles
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.post("/compliance/extract-requirements-from-rag", response_model=ExtractionResponse)
async def extract_requirements_from_rag(
    body: ExtractionRequest,
    db: Session = Depends(get_db),
):
    """
    Search all RAG collections for Prüfaspekte / audit criteria and create
    Requirement entries in the compliance DB.

    - Deduplicates by (regulation_code, article) — safe to call multiple times.
    - Auto-creates Regulation stubs for previously unknown regulation_codes.
    - Use `dry_run=true` to preview results without any DB writes. The preview
      only reads from the DB: articles that already exist for a known
      regulation are reported as "skipped_duplicate", everything else as
      "would_create"; no regulation stub is created.
    - Use `regulation_codes` to restrict to specific regulations (e.g. ["BSI-TR-03161-1"]).

    Results without a regulation_code are grouped under "UNKNOWN". Results
    without a usable article identifier are only counted, not listed.
    """
    collections = body.collections or ALL_COLLECTIONS
    queries = body.search_queries or DEFAULT_QUERIES

    # --- 1. Search all collections in parallel ---
    search_tasks = [
        _search_collection(col, queries, body.max_per_query)
        for col in collections
    ]
    collection_results = await asyncio.gather(*search_tasks, return_exceptions=True)

    # Flatten, skip failed collections. gather() may also hand back
    # BaseException instances (e.g. CancelledError), hence the broad check.
    all_results: List[RAGSearchResult] = []
    for col, res in zip(collections, collection_results):
        if isinstance(res, BaseException):
            logger.warning("Collection %s search failed: %s", col, res)
        else:
            all_results.extend(res)
    logger.info("RAG extraction: %d raw results from %d collections", len(all_results), len(collections))

    # --- 2. Filter by regulation_codes if requested ---
    if body.regulation_codes:
        all_results = [
            r for r in all_results
            if r.regulation_code in body.regulation_codes
        ]

    # --- 3. Group by regulation_code, counting results without an article ---
    by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    skipped_no_article_count = 0
    for r in all_results:
        article = _normalize_article(r)
        if not article:
            skipped_no_article_count += 1
            continue
        by_reg.setdefault(r.regulation_code or "UNKNOWN", []).append((article, r))

    # --- 4. Deduplicate within each regulation, keeping the best-scoring hit ---
    deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    for reg_code, items in by_reg.items():
        seen_articles: set[str] = set()
        deduped: List[tuple[str, RAGSearchResult]] = []
        for art, r in sorted(items, key=lambda x: x[1].score, reverse=True):
            if art not in seen_articles:
                seen_articles.add(art)
                deduped.append((art, r))
        deduped_by_reg[reg_code] = deduped
    unique_pairs = sum(len(items) for items in deduped_by_reg.values())
    logger.info("RAG extraction: %d unique (regulation, article) pairs", unique_pairs)

    # --- 5. Create requirements ---
    req_repo = RequirementRepository(db)
    created_count = 0
    skipped_dup_count = 0
    failed_count = 0
    result_items: List[ExtractedRequirement] = []

    for reg_code, items in deduped_by_reg.items():
        if not items:
            continue

        # Find or create regulation
        try:
            first_result = items[0][1]
            regulation_name = first_result.regulation_name or first_result.regulation_short or reg_code
            if body.dry_run:
                # Read-only lookup so the preview reflects what a real run
                # would skip; never create a stub in dry-run mode.
                known_reg = RegulationRepository(db).get_by_code(reg_code)
                if known_reg:
                    regulation_id = known_reg.id
                    existing_articles: set[str] = _build_existing_articles(db, regulation_id)
                else:
                    regulation_id = f"dry-run-{reg_code}"
                    existing_articles = set()
            else:
                reg = _get_or_create_regulation(db, reg_code, regulation_name)
                regulation_id = reg.id
                existing_articles = _build_existing_articles(db, regulation_id)
        except Exception as e:
            logger.error("Failed to get/create regulation %s: %s", reg_code, e)
            failed_count += len(items)
            continue

        for article, r in items:
            title = _derive_title(r.text, article)

            if article in existing_articles:
                skipped_dup_count += 1
                result_items.append(ExtractedRequirement(
                    regulation_code=reg_code,
                    article=article,
                    title=title,
                    requirement_text=r.text[:1000],
                    source_url=r.source_url,
                    score=r.score,
                    action="skipped_duplicate",
                ))
                continue

            if not body.dry_run:
                try:
                    req_repo.create(
                        regulation_id=regulation_id,
                        article=article,
                        title=title,
                        description=f"Extrahiert aus RAG-Korpus (Collection: {r.category or r.regulation_code}). Score: {r.score:.2f}",
                        requirement_text=r.text[:2000],
                        breakpilot_interpretation=None,
                        is_applicable=True,
                        priority=2,
                    )
                    existing_articles.add(article)  # prevent intra-batch duplication
                    created_count += 1
                except Exception as e:
                    logger.error("Failed to create requirement %s/%s: %s", reg_code, article, e)
                    failed_count += 1
                    continue
            else:
                created_count += 1  # dry_run: count as would-create

            result_items.append(ExtractedRequirement(
                regulation_code=reg_code,
                article=article,
                title=title,
                requirement_text=r.text[:1000],
                source_url=r.source_url,
                score=r.score,
                action="created" if not body.dry_run else "would_create",
            ))

    message = (
        f"{'[DRY RUN] ' if body.dry_run else ''}"
        f"Erstellt: {created_count}, Duplikate übersprungen: {skipped_dup_count}, "
        f"Ohne Artikel-ID übersprungen: {skipped_no_article_count}, Fehler: {failed_count}"
    )
    logger.info("RAG extraction complete: %s", message)

    return ExtractionResponse(
        created=created_count,
        skipped_duplicates=skipped_dup_count,
        skipped_no_article=skipped_no_article_count,
        failed=failed_count,
        collections_searched=collections,
        queries_used=queries,
        requirements=result_items,
        dry_run=body.dry_run,
        message=message,
    )

View File

@@ -0,0 +1,416 @@
"""Tests for RAG-based Requirement Extraction endpoint.
POST /compliance/extract-requirements-from-rag
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from compliance.api.extraction_routes import router as extraction_router
from classroom_engine.database import get_db
# ---------------------------------------------------------------------------
# App setup
# ---------------------------------------------------------------------------
# Minimal FastAPI app that mounts only the extraction router under test.
app = FastAPI()
app.include_router(extraction_router)
# Shared stand-in for the DB session injected into the endpoint.
mock_db = MagicMock()


def override_get_db():
    """Dependency override that yields the shared mock session instead of a real one."""
    yield mock_db


# Route the endpoint's get_db dependency to the mock session.
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
# Fixed identifiers used by the mock regulation / requirement builders in this module.
REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"
# ---------------------------------------------------------------------------
# RAG result helpers
# ---------------------------------------------------------------------------
def make_rag_result(overrides=None):
    """Build a MagicMock that looks like a RAG search hit for a BSI audit aspect.

    `overrides` is an optional mapping of attribute name to value that
    replaces or extends the defaults.
    """
    attributes = {
        "text": "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen.",
        "regulation_code": "BSI-TR-03161-1",
        "regulation_name": "BSI Technische Richtlinie 03161 Teil 1",
        "regulation_short": "BSI-TR-03161-1",
        "category": "bsi",
        "article": "O.Purp_6",
        "paragraph": "",
        "source_url": "https://bsi.bund.de/tr03161",
        "score": 0.92,
    }
    attributes.update(overrides or {})

    hit = MagicMock()
    for name, value in attributes.items():
        setattr(hit, name, value)
    return hit
def make_regulation(overrides=None):
    """Build a MagicMock regulation row; `overrides` replaces or adds attributes."""
    attributes = {
        "id": REG_ID,
        "code": "BSI-TR-03161-1",
        "name": "BSI-TR-03161-1",
    }
    attributes.update(overrides or {})

    regulation = MagicMock()
    for name, value in attributes.items():
        setattr(regulation, name, value)
    return regulation
def make_requirement(overrides=None):
    """Build a MagicMock requirement row linked to a mock regulation.

    `overrides` is an optional mapping of attribute name to value that
    replaces or extends the defaults.
    """
    parent_regulation = MagicMock()
    parent_regulation.code = "BSI-TR-03161-1"

    attributes = {
        "id": REQ_ID,
        "regulation_id": REG_ID,
        "article": "O.Purp_6",
        "title": "Zweckbenennung",
        "regulation": parent_regulation,
    }
    attributes.update(overrides or {})

    requirement = MagicMock()
    for name, value in attributes.items():
        setattr(requirement, name, value)
    return requirement
# ---------------------------------------------------------------------------
# Helper: RAG mock
# ---------------------------------------------------------------------------
def patch_rag(results_per_call=None):
    """Create a patcher that swaps the route module's RAG client for a fake.

    Every `search` call on the fake returns the same list: `results_per_call`
    if given (an empty list is respected), otherwise one default RAG result.
    """
    if results_per_call is None:
        canned = [make_rag_result()]
    else:
        canned = results_per_call

    async def fake_search(*args, **kwargs):
        return canned

    rag_client = MagicMock()
    rag_client.search = fake_search
    return patch(
        "compliance.api.extraction_routes.get_rag_client",
        return_value=rag_client,
    )
# ---------------------------------------------------------------------------
# Basic endpoint tests
# ---------------------------------------------------------------------------
class TestExtractRequirementsBasic:
    """Core create / skip / dry-run behaviour of the extraction endpoint."""

    URL = "/compliance/extract-requirements-from-rag"
    REG_REPO_PATH = "compliance.api.extraction_routes.RegulationRepository"
    REQ_REPO_PATH = "compliance.api.extraction_routes.RequirementRepository"

    def test_empty_rag_results(self):
        """An empty RAG answer yields zero created and zero skipped entries."""
        with patch_rag(results_per_call=[]):
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["created"] == 0
        assert payload["skipped_duplicates"] == 0
        assert payload["dry_run"] is False

    def test_dry_run_does_not_write_db(self):
        """With dry_run=true, RequirementRepository.create is never invoked."""
        with patch_rag([make_rag_result()]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            resp = client.post(self.URL, json={"dry_run": True})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["dry_run"] is True
        assert payload["created"] == 1  # would-create count
        # DB create must NOT be called
        req_repo.create.assert_not_called()

    def test_creates_requirement_new_regulation(self):
        """An unknown regulation and its new requirement are both created."""
        hit = make_rag_result()
        created_regulation = make_regulation()
        created_requirement = make_requirement()
        with patch_rag([hit]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            # Regulation doesn't exist yet → auto-create
            reg_repo.get_by_code.return_value = None
            reg_repo.create.return_value = created_regulation
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = created_requirement
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["created"] == 1
        assert payload["skipped_duplicates"] == 0
        reg_repo.create.assert_called_once()
        req_repo.create.assert_called_once()

    def test_skips_duplicate_requirement(self):
        """An article already stored for the regulation is skipped, not re-created."""
        hit = make_rag_result()
        stored = make_requirement({"article": "O.Purp_6"})
        with patch_rag([hit]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = [stored]
            req_repo.create.return_value = MagicMock()
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["skipped_duplicates"] == 1
        assert payload["created"] == 0
        req_repo.create.assert_not_called()

    def test_result_items_contain_expected_fields(self):
        """Each entry of the requirements list carries the documented fields."""
        with patch_rag([make_rag_result()]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        payload = resp.json()
        assert len(payload["requirements"]) == 1
        entry = payload["requirements"][0]
        assert entry["regulation_code"] == "BSI-TR-03161-1"
        assert entry["article"] == "O.Purp_6"
        assert entry["action"] == "created"
        assert "title" in entry
        assert "score" in entry
# ---------------------------------------------------------------------------
# Collection and query filtering
# ---------------------------------------------------------------------------
class TestExtractionFilters:
    """Tests for collection/regulation filters."""

    def test_custom_collections_passed(self):
        """Should only search specified collections.

        Uses an AsyncMock for `search` so the collections actually queried can
        be inspected, not just the value echoed back in the response.
        """
        rag_client = MagicMock()
        rag_client.search = AsyncMock(return_value=[])
        with patch(
            "compliance.api.extraction_routes.get_rag_client",
            return_value=rag_client,
        ):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"collections": ["bp_compliance_ce"], "dry_run": True},
            )
        assert response.status_code == 200
        data = response.json()
        assert data["collections_searched"] == ["bp_compliance_ce"]
        searched = {call.kwargs["collection"] for call in rag_client.search.await_args_list}
        assert searched == {"bp_compliance_ce"}

    def test_regulation_code_filter(self):
        """Results from other regulation_codes should be excluded."""
        bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"})
        gdpr_result = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.",
        })
        with patch_rag([bsi_result, gdpr_result]), \
                patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
                patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
            MockRegRepo.return_value.get_by_code.return_value = make_regulation()
            MockReqRepo.return_value.get_by_regulation.return_value = []
            MockReqRepo.return_value.create.return_value = make_requirement()
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"regulation_codes": ["BSI-TR-03161-1"]},
            )
        assert response.status_code == 200
        data = response.json()
        # Guard against a vacuous pass: the BSI result must survive the filter.
        assert data["requirements"], "expected the BSI result to be processed"
        # Only BSI result should be in output
        for req in data["requirements"]:
            assert req["regulation_code"] == "BSI-TR-03161-1"

    def test_custom_queries_passed(self):
        """custom search_queries should be used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"search_queries": ["custom query"], "dry_run": True},
            )
        assert response.status_code == 200
        data = response.json()
        assert "custom query" in data["queries_used"]

    def test_default_queries_used_when_none(self):
        """When no queries given, DEFAULT_QUERIES are used."""
        with patch_rag([]):
            response = client.post(
                "/compliance/extract-requirements-from-rag",
                json={"dry_run": True},
            )
        assert response.status_code == 200
        data = response.json()
        assert len(data["queries_used"]) > 0
# ---------------------------------------------------------------------------
# Deduplication + article extraction
# ---------------------------------------------------------------------------
class TestArticleExtraction:
    """Article normalization and deduplication of RAG chunks."""

    URL = "/compliance/extract-requirements-from-rag"
    REG_REPO_PATH = "compliance.api.extraction_routes.RegulationRepository"
    REQ_REPO_PATH = "compliance.api.extraction_routes.RequirementRepository"

    def test_result_without_article_field_uses_text_pattern(self):
        """With an empty article field, the BSI identifier is taken from the text."""
        hit = make_rag_result({"article": "", "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein."})
        with patch_rag([hit]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        payload = resp.json()
        created_entries = [
            entry for entry in payload["requirements"] if entry["action"] == "created"
        ]
        assert len(created_entries) == 1
        assert created_entries[0]["article"] == "O.Auth_2"

    def test_result_with_no_article_at_all_is_skipped(self):
        """A result with neither article field nor identifier in the text is skipped."""
        hit = make_rag_result({
            "article": "",
            "text": "General text without any structured identifier in the document.",
        })
        with patch_rag([hit]):
            resp = client.post(self.URL, json={})
        payload = resp.json()
        assert payload["skipped_no_article"] >= 1
        assert payload["created"] == 0

    def test_intra_batch_deduplication(self):
        """Two hits for the same regulation and article create a single requirement."""
        first = make_rag_result({"article": "O.Purp_6"})
        second = make_rag_result({"article": "O.Purp_6", "text": "O.Purp_6 Additional text about the same requirement."})
        with patch_rag([first, second]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        payload = resp.json()
        # Only one should be created despite two results with same article
        assert payload["created"] == 1
# ---------------------------------------------------------------------------
# Regulation auto-creation
# ---------------------------------------------------------------------------
class TestRegulationAutoCreate:
    """Auto-creation of regulation stubs during extraction."""

    URL = "/compliance/extract-requirements-from-rag"
    REG_REPO_PATH = "compliance.api.extraction_routes.RegulationRepository"
    REQ_REPO_PATH = "compliance.api.extraction_routes.RequirementRepository"

    def test_existing_regulation_not_recreated(self):
        """A regulation that is found by code must not be created again."""
        with patch_rag([make_rag_result()]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        reg_repo.create.assert_not_called()

    def test_unknown_regulation_is_auto_created(self):
        """A regulation_code without a DB entry triggers RegulationRepository.create."""
        with patch_rag([make_rag_result()]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = None
            reg_repo.create.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["created"] == 1
        reg_repo.create.assert_called_once()

    def test_multiple_regulations_in_one_run(self):
        """Hits from two different regulations each lead to a created requirement."""
        bsi_hit = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"})
        gdpr_hit = make_rag_result({
            "regulation_code": "GDPR",
            "article": "Art. 32",
            "text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.",
        })
        with patch_rag([bsi_hit, gdpr_hit]), \
                patch(self.REG_REPO_PATH) as reg_repo_cls, \
                patch(self.REQ_REPO_PATH) as req_repo_cls:
            reg_repo = reg_repo_cls.return_value
            req_repo = req_repo_cls.return_value
            reg_repo.get_by_code.return_value = make_regulation()
            req_repo.get_by_regulation.return_value = []
            req_repo.create.return_value = make_requirement()
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        # Both requirements should be created
        assert payload["created"] == 2
# ---------------------------------------------------------------------------
# Response structure
# ---------------------------------------------------------------------------
class TestResponseStructure:
    """Shape of the response body: field presence and summary message."""

    URL = "/compliance/extract-requirements-from-rag"
    EXPECTED_FIELDS = (
        "created", "skipped_duplicates", "skipped_no_article",
        "failed", "collections_searched", "queries_used",
        "requirements", "dry_run", "message",
    )

    def test_all_fields_present(self):
        """Every documented top-level field appears in the response."""
        with patch_rag([]):
            resp = client.post(self.URL, json={})
        assert resp.status_code == 200
        payload = resp.json()
        for field in self.EXPECTED_FIELDS:
            assert field in payload, f"Missing field: {field}"

    def test_message_contains_summary(self):
        """The message carries the created-count summary."""
        with patch_rag([]):
            resp = client.post(self.URL, json={})
        payload = resp.json()
        assert "Erstellt:" in payload["message"]

    def test_dry_run_message_prefix(self):
        """Dry-run responses are marked as such in the message."""
        with patch_rag([]):
            resp = client.post(self.URL, json={"dry_run": True})
        payload = resp.json()
        assert "[DRY RUN]" in payload["message"]