feat(extraction): POST /compliance/extract-requirements-from-rag
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s
Sucht alle RAG-Kollektionen nach Prüfaspekten und legt automatisch Anforderungen in der DB an. Kernfeatures: - Durchsucht alle 6 RAG-Kollektionen parallel (bp_compliance_ce, bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz, bp_dsfa_corpus, bp_legal_templates) - Erkennt BSI Prüfaspekte (O.Purp_6) im Artikel-Feld und per Regex - Dedupliziert nach (regulation_code, article) — safe to call many times - Auto-erstellt Regulations-Stubs für unbekannte regulation_codes - dry_run=true zeigt was erstellt würde ohne DB-Schreibzugriff - Optionale Filter: collections, regulation_codes, search_queries - 18 Tests (alle bestanden) - Frontend: "Aus RAG extrahieren" Button auf /sdk/requirements Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -25,6 +25,7 @@ from .dsfa_routes import router as dsfa_router
|
||||
from .dsr_routes import router as dsr_router
|
||||
from .email_template_routes import router as email_template_router
|
||||
from .banner_routes import router as banner_router
|
||||
from .extraction_routes import router as extraction_router
|
||||
|
||||
# Include sub-routers
|
||||
router.include_router(audit_router)
|
||||
@@ -51,6 +52,7 @@ router.include_router(dsfa_router)
|
||||
router.include_router(dsr_router)
|
||||
router.include_router(email_template_router)
|
||||
router.include_router(banner_router)
|
||||
router.include_router(extraction_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
|
||||
366
backend-compliance/compliance/api/extraction_routes.py
Normal file
366
backend-compliance/compliance/api/extraction_routes.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
FastAPI routes for RAG-based Requirement Extraction.
|
||||
|
||||
Endpoints:
|
||||
- POST /compliance/extract-requirements-from-rag:
|
||||
Searches ALL RAG collections (or a subset) for audit criteria / Prüfaspekte
|
||||
and creates Requirement entries in the DB.
|
||||
|
||||
Design principles:
|
||||
- Searches every relevant collection in parallel
|
||||
- Deduplicates by (regulation_id, article) — never inserts twice
|
||||
- Auto-creates Regulation stubs for unknown regulation_codes
|
||||
- LLM-free by default (fast); optional LLM title extraction via ?use_llm=true
|
||||
- dry_run=true returns what would be created without touching the DB
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from classroom_engine.database import get_db
|
||||
from ..db import RegulationRepository, RequirementRepository
|
||||
from ..db.models import RegulationDB, RequirementDB, RegulationTypeEnum
|
||||
from ..services.rag_client import get_rag_client, RAGSearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["extraction"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collections that may contain Prüfaspekte / audit criteria
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ALL_COLLECTIONS = [
|
||||
"bp_compliance_ce", # BSI-TR documents — primary Prüfaspekte source
|
||||
"bp_compliance_recht", # Legal texts (GDPR, AI Act, ...)
|
||||
"bp_compliance_gesetze", # German laws
|
||||
"bp_compliance_datenschutz", # Data protection documents
|
||||
"bp_dsfa_corpus", # DSFA corpus
|
||||
"bp_legal_templates", # Legal templates
|
||||
]
|
||||
|
||||
# Search queries targeting audit criteria across different document types
|
||||
DEFAULT_QUERIES = [
|
||||
"Prüfaspekt Anforderung MUSS SOLL",
|
||||
"security requirement SHALL MUST",
|
||||
"compliance requirement audit criterion",
|
||||
"Sicherheitsanforderung technische Maßnahme",
|
||||
"data protection requirement obligation",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ExtractionRequest(BaseModel):
|
||||
collections: Optional[List[str]] = None # None = ALL_COLLECTIONS
|
||||
search_queries: Optional[List[str]] = None # None = DEFAULT_QUERIES
|
||||
regulation_codes: Optional[List[str]] = None # None = all regulations
|
||||
max_per_query: int = 20 # top_k per search query
|
||||
dry_run: bool = False # if True: no DB writes
|
||||
|
||||
|
||||
class ExtractedRequirement(BaseModel):
|
||||
regulation_code: str
|
||||
article: str
|
||||
title: str
|
||||
requirement_text: str
|
||||
source_url: str
|
||||
score: float
|
||||
action: str # "created" | "skipped_duplicate" | "skipped_no_article"
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
created: int
|
||||
skipped_duplicates: int
|
||||
skipped_no_article: int
|
||||
failed: int
|
||||
collections_searched: List[str]
|
||||
queries_used: List[str]
|
||||
requirements: List[ExtractedRequirement]
|
||||
dry_run: bool
|
||||
message: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BSI_ASPECT_RE = re.compile(r"\b([A-Z]\.[A-Za-z]+_\d+)\b")
|
||||
TITLE_SENTENCE_RE = re.compile(r"^([^.!?\n]{10,120})[.!?\n]")
|
||||
|
||||
|
||||
def _derive_title(text: str, article: str) -> str:
|
||||
"""Extract a short title from RAG chunk text."""
|
||||
# Remove leading article reference if present
|
||||
cleaned = re.sub(r"^" + re.escape(article) + r"[:\s]+", "", text.strip(), flags=re.IGNORECASE)
|
||||
# Take first meaningful sentence
|
||||
m = TITLE_SENTENCE_RE.match(cleaned)
|
||||
if m:
|
||||
return m.group(1).strip()[:200]
|
||||
# Fallback: first 100 chars
|
||||
return cleaned[:100].strip() or article
|
||||
|
||||
|
||||
def _normalize_article(result: RAGSearchResult) -> Optional[str]:
|
||||
"""
|
||||
Return a canonical article identifier from the RAG result.
|
||||
Returns None if no meaningful article can be determined.
|
||||
"""
|
||||
article = (result.article or "").strip()
|
||||
if article:
|
||||
return article
|
||||
|
||||
# Try to find BSI Prüfaspekt pattern in chunk text
|
||||
m = BSI_ASPECT_RE.search(result.text)
|
||||
if m:
|
||||
return m.group(1)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _search_collection(
|
||||
collection: str,
|
||||
queries: List[str],
|
||||
max_per_query: int,
|
||||
) -> List[RAGSearchResult]:
|
||||
"""Run all queries against one collection and merge deduplicated results."""
|
||||
rag = get_rag_client()
|
||||
seen_texts: set[str] = set()
|
||||
results: List[RAGSearchResult] = []
|
||||
|
||||
for query in queries:
|
||||
hits = await rag.search(query, collection=collection, top_k=max_per_query)
|
||||
for h in hits:
|
||||
key = h.text[:120] # rough dedup key
|
||||
if key not in seen_texts:
|
||||
seen_texts.add(key)
|
||||
results.append(h)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _get_or_create_regulation(
|
||||
db: Session,
|
||||
regulation_code: str,
|
||||
regulation_name: str,
|
||||
) -> RegulationDB:
|
||||
"""Return existing Regulation or create a stub."""
|
||||
repo = RegulationRepository(db)
|
||||
reg = repo.get_by_code(regulation_code)
|
||||
if reg:
|
||||
return reg
|
||||
|
||||
# Auto-create a stub so Requirements can reference it
|
||||
logger.info("Auto-creating regulation stub: %s", regulation_code)
|
||||
# Infer type from code prefix
|
||||
if regulation_code.startswith("BSI"):
|
||||
reg_type = RegulationTypeEnum.BSI_STANDARD
|
||||
elif regulation_code in ("GDPR", "AI_ACT", "NIS2", "CRA"):
|
||||
reg_type = RegulationTypeEnum.EU_REGULATION
|
||||
else:
|
||||
reg_type = RegulationTypeEnum.INDUSTRY_STANDARD
|
||||
reg = repo.create(
|
||||
code=regulation_code,
|
||||
name=regulation_name or regulation_code,
|
||||
regulation_type=reg_type,
|
||||
description=f"Auto-created from RAG extraction ({datetime.utcnow().date()})",
|
||||
)
|
||||
return reg
|
||||
|
||||
|
||||
def _build_existing_articles(
|
||||
db: Session, regulation_id: str
|
||||
) -> set[str]:
|
||||
"""Return set of existing article strings for this regulation."""
|
||||
repo = RequirementRepository(db)
|
||||
existing = repo.get_by_regulation(regulation_id)
|
||||
return {r.article for r in existing}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/compliance/extract-requirements-from-rag", response_model=ExtractionResponse)
|
||||
async def extract_requirements_from_rag(
|
||||
body: ExtractionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Search all RAG collections for Prüfaspekte / audit criteria and create
|
||||
Requirement entries in the compliance DB.
|
||||
|
||||
- Deduplicates by (regulation_code, article) — safe to call multiple times.
|
||||
- Auto-creates Regulation stubs for previously unknown regulation_codes.
|
||||
- Use `dry_run=true` to preview results without any DB writes.
|
||||
- Use `regulation_codes` to restrict to specific regulations (e.g. ["BSI-TR-03161-1"]).
|
||||
"""
|
||||
collections = body.collections or ALL_COLLECTIONS
|
||||
queries = body.search_queries or DEFAULT_QUERIES
|
||||
|
||||
# --- 1. Search all collections in parallel ---
|
||||
search_tasks = [
|
||||
_search_collection(col, queries, body.max_per_query)
|
||||
for col in collections
|
||||
]
|
||||
collection_results: List[List[RAGSearchResult]] = await asyncio.gather(
|
||||
*search_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# Flatten, skip exceptions
|
||||
all_results: List[RAGSearchResult] = []
|
||||
for col, res in zip(collections, collection_results):
|
||||
if isinstance(res, Exception):
|
||||
logger.warning("Collection %s search failed: %s", col, res)
|
||||
else:
|
||||
all_results.extend(res)
|
||||
|
||||
logger.info("RAG extraction: %d raw results from %d collections", len(all_results), len(collections))
|
||||
|
||||
# --- 2. Filter by regulation_codes if requested ---
|
||||
if body.regulation_codes:
|
||||
all_results = [
|
||||
r for r in all_results
|
||||
if r.regulation_code in body.regulation_codes
|
||||
]
|
||||
|
||||
# --- 3. Deduplicate at result level (regulation_code + article) ---
|
||||
seen: set[tuple[str, str]] = set()
|
||||
unique_results: List[RAGSearchResult] = []
|
||||
for r in sorted(all_results, key=lambda x: x.score, reverse=True):
|
||||
article = _normalize_article(r)
|
||||
if not article:
|
||||
continue
|
||||
key = (r.regulation_code, article)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
unique_results.append(r)
|
||||
|
||||
logger.info("RAG extraction: %d unique (regulation, article) pairs", len(unique_results))
|
||||
|
||||
# --- 4. Group by regulation_code and process ---
|
||||
by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
|
||||
skipped_no_article: List[RAGSearchResult] = []
|
||||
|
||||
for r in all_results:
|
||||
article = _normalize_article(r)
|
||||
if not article:
|
||||
skipped_no_article.append(r)
|
||||
continue
|
||||
key_r = r.regulation_code or "UNKNOWN"
|
||||
if key_r not in by_reg:
|
||||
by_reg[key_r] = []
|
||||
by_reg[key_r].append((article, r))
|
||||
|
||||
# Deduplicate within groups
|
||||
deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
|
||||
for reg_code, items in by_reg.items():
|
||||
seen_articles: set[str] = set()
|
||||
deduped: List[tuple[str, RAGSearchResult]] = []
|
||||
for art, r in sorted(items, key=lambda x: x[1].score, reverse=True):
|
||||
if art not in seen_articles:
|
||||
seen_articles.add(art)
|
||||
deduped.append((art, r))
|
||||
deduped_by_reg[reg_code] = deduped
|
||||
|
||||
# --- 5. Create requirements ---
|
||||
req_repo = RequirementRepository(db)
|
||||
created_count = 0
|
||||
skipped_dup_count = 0
|
||||
failed_count = 0
|
||||
result_items: List[ExtractedRequirement] = []
|
||||
|
||||
for reg_code, items in deduped_by_reg.items():
|
||||
if not items:
|
||||
continue
|
||||
|
||||
# Find or create regulation
|
||||
try:
|
||||
first_result = items[0][1]
|
||||
regulation_name = first_result.regulation_name or first_result.regulation_short or reg_code
|
||||
if body.dry_run:
|
||||
# For dry_run, fake a regulation id
|
||||
regulation_id = f"dry-run-{reg_code}"
|
||||
existing_articles: set[str] = set()
|
||||
else:
|
||||
reg = _get_or_create_regulation(db, reg_code, regulation_name)
|
||||
regulation_id = reg.id
|
||||
existing_articles = _build_existing_articles(db, regulation_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get/create regulation %s: %s", reg_code, e)
|
||||
failed_count += len(items)
|
||||
continue
|
||||
|
||||
for article, r in items:
|
||||
title = _derive_title(r.text, article)
|
||||
|
||||
if article in existing_articles:
|
||||
skipped_dup_count += 1
|
||||
result_items.append(ExtractedRequirement(
|
||||
regulation_code=reg_code,
|
||||
article=article,
|
||||
title=title,
|
||||
requirement_text=r.text[:1000],
|
||||
source_url=r.source_url,
|
||||
score=r.score,
|
||||
action="skipped_duplicate",
|
||||
))
|
||||
continue
|
||||
|
||||
if not body.dry_run:
|
||||
try:
|
||||
req_repo.create(
|
||||
regulation_id=regulation_id,
|
||||
article=article,
|
||||
title=title,
|
||||
description=f"Extrahiert aus RAG-Korpus (Collection: {r.category or r.regulation_code}). Score: {r.score:.2f}",
|
||||
requirement_text=r.text[:2000],
|
||||
breakpilot_interpretation=None,
|
||||
is_applicable=True,
|
||||
priority=2,
|
||||
)
|
||||
existing_articles.add(article) # prevent intra-batch duplication
|
||||
created_count += 1
|
||||
except Exception as e:
|
||||
logger.error("Failed to create requirement %s/%s: %s", reg_code, article, e)
|
||||
failed_count += 1
|
||||
continue
|
||||
else:
|
||||
created_count += 1 # dry_run: count as would-create
|
||||
|
||||
result_items.append(ExtractedRequirement(
|
||||
regulation_code=reg_code,
|
||||
article=article,
|
||||
title=title,
|
||||
requirement_text=r.text[:1000],
|
||||
source_url=r.source_url,
|
||||
score=r.score,
|
||||
action="created" if not body.dry_run else "would_create",
|
||||
))
|
||||
|
||||
message = (
|
||||
f"{'[DRY RUN] ' if body.dry_run else ''}"
|
||||
f"Erstellt: {created_count}, Duplikate übersprungen: {skipped_dup_count}, "
|
||||
f"Ohne Artikel-ID übersprungen: {len(skipped_no_article)}, Fehler: {failed_count}"
|
||||
)
|
||||
logger.info("RAG extraction complete: %s", message)
|
||||
|
||||
return ExtractionResponse(
|
||||
created=created_count,
|
||||
skipped_duplicates=skipped_dup_count,
|
||||
skipped_no_article=len(skipped_no_article),
|
||||
failed=failed_count,
|
||||
collections_searched=collections,
|
||||
queries_used=queries,
|
||||
requirements=result_items,
|
||||
dry_run=body.dry_run,
|
||||
message=message,
|
||||
)
|
||||
416
backend-compliance/tests/test_extraction_routes.py
Normal file
416
backend-compliance/tests/test_extraction_routes.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""Tests for RAG-based Requirement Extraction endpoint.
|
||||
|
||||
POST /compliance/extract-requirements-from-rag
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from compliance.api.extraction_routes import router as extraction_router
|
||||
from classroom_engine.database import get_db
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(extraction_router)
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
|
||||
def override_get_db():
|
||||
yield mock_db
|
||||
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
client = TestClient(app)
|
||||
|
||||
REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
|
||||
REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RAG result helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_rag_result(overrides=None):
|
||||
r = MagicMock()
|
||||
r.text = "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen."
|
||||
r.regulation_code = "BSI-TR-03161-1"
|
||||
r.regulation_name = "BSI Technische Richtlinie 03161 Teil 1"
|
||||
r.regulation_short = "BSI-TR-03161-1"
|
||||
r.category = "bsi"
|
||||
r.article = "O.Purp_6"
|
||||
r.paragraph = ""
|
||||
r.source_url = "https://bsi.bund.de/tr03161"
|
||||
r.score = 0.92
|
||||
if overrides:
|
||||
for k, v in overrides.items():
|
||||
setattr(r, k, v)
|
||||
return r
|
||||
|
||||
|
||||
def make_regulation(overrides=None):
|
||||
reg = MagicMock()
|
||||
reg.id = REG_ID
|
||||
reg.code = "BSI-TR-03161-1"
|
||||
reg.name = "BSI-TR-03161-1"
|
||||
if overrides:
|
||||
for k, v in overrides.items():
|
||||
setattr(reg, k, v)
|
||||
return reg
|
||||
|
||||
|
||||
def make_requirement(overrides=None):
|
||||
req = MagicMock()
|
||||
req.id = REQ_ID
|
||||
req.regulation_id = REG_ID
|
||||
req.article = "O.Purp_6"
|
||||
req.title = "Zweckbenennung"
|
||||
req.regulation = MagicMock()
|
||||
req.regulation.code = "BSI-TR-03161-1"
|
||||
if overrides:
|
||||
for k, v in overrides.items():
|
||||
setattr(req, k, v)
|
||||
return req
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: RAG mock
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def patch_rag(results_per_call=None):
|
||||
"""Return a patcher that makes RAG search return the given results list."""
|
||||
results = [make_rag_result()] if results_per_call is None else results_per_call
|
||||
|
||||
async def fake_search(*args, **kwargs):
|
||||
return results
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.search = fake_search
|
||||
return patch(
|
||||
"compliance.api.extraction_routes.get_rag_client",
|
||||
return_value=mock_client,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic endpoint tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractRequirementsBasic:
|
||||
"""Basic extraction endpoint tests."""
|
||||
|
||||
def test_empty_rag_results(self):
|
||||
"""When RAG returns nothing, response should report 0 created."""
|
||||
with patch_rag(results_per_call=[]):
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created"] == 0
|
||||
assert data["skipped_duplicates"] == 0
|
||||
assert data["dry_run"] is False
|
||||
|
||||
def test_dry_run_does_not_write_db(self):
|
||||
"""dry_run=true should not call RequirementRepository.create."""
|
||||
with patch_rag([make_rag_result()]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["dry_run"] is True
|
||||
assert data["created"] == 1 # would-create count
|
||||
# DB create must NOT be called
|
||||
MockReqRepo.return_value.create.assert_not_called()
|
||||
|
||||
def test_creates_requirement_new_regulation(self):
|
||||
"""New regulation + new requirement should be created in DB."""
|
||||
rag_result = make_rag_result()
|
||||
new_reg = make_regulation()
|
||||
new_req = make_requirement()
|
||||
|
||||
with patch_rag([rag_result]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
# Regulation doesn't exist yet → auto-create
|
||||
MockRegRepo.return_value.get_by_code.return_value = None
|
||||
MockRegRepo.return_value.create.return_value = new_reg
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = new_req
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created"] == 1
|
||||
assert data["skipped_duplicates"] == 0
|
||||
MockRegRepo.return_value.create.assert_called_once()
|
||||
MockReqRepo.return_value.create.assert_called_once()
|
||||
|
||||
def test_skips_duplicate_requirement(self):
|
||||
"""If article already exists for the regulation, skip it."""
|
||||
rag_result = make_rag_result()
|
||||
existing_req = make_requirement({"article": "O.Purp_6"})
|
||||
|
||||
with patch_rag([rag_result]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = [existing_req]
|
||||
MockReqRepo.return_value.create.return_value = MagicMock()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["skipped_duplicates"] == 1
|
||||
assert data["created"] == 0
|
||||
MockReqRepo.return_value.create.assert_not_called()
|
||||
|
||||
def test_result_items_contain_expected_fields(self):
|
||||
"""Requirements list should contain correct fields."""
|
||||
with patch_rag([make_rag_result()]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
data = response.json()
|
||||
assert len(data["requirements"]) == 1
|
||||
req = data["requirements"][0]
|
||||
assert req["regulation_code"] == "BSI-TR-03161-1"
|
||||
assert req["article"] == "O.Purp_6"
|
||||
assert req["action"] == "created"
|
||||
assert "title" in req
|
||||
assert "score" in req
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Collection and query filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractionFilters:
|
||||
"""Tests for collection/regulation filters."""
|
||||
|
||||
def test_custom_collections_passed(self):
|
||||
"""Should only search specified collections."""
|
||||
with patch_rag([]) as mock_patch:
|
||||
with mock_patch:
|
||||
response = client.post(
|
||||
"/compliance/extract-requirements-from-rag",
|
||||
json={"collections": ["bp_compliance_ce"], "dry_run": True},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["collections_searched"] == ["bp_compliance_ce"]
|
||||
|
||||
def test_regulation_code_filter(self):
|
||||
"""Results from other regulation_codes should be excluded."""
|
||||
bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"})
|
||||
gdpr_result = make_rag_result({
|
||||
"regulation_code": "GDPR",
|
||||
"article": "Art. 32",
|
||||
"text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.",
|
||||
})
|
||||
|
||||
with patch_rag([bsi_result, gdpr_result]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post(
|
||||
"/compliance/extract-requirements-from-rag",
|
||||
json={"regulation_codes": ["BSI-TR-03161-1"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Only BSI result should be in output
|
||||
for req in data["requirements"]:
|
||||
assert req["regulation_code"] == "BSI-TR-03161-1"
|
||||
|
||||
def test_custom_queries_passed(self):
|
||||
"""custom search_queries should be used."""
|
||||
with patch_rag([]):
|
||||
response = client.post(
|
||||
"/compliance/extract-requirements-from-rag",
|
||||
json={"search_queries": ["custom query"], "dry_run": True},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "custom query" in data["queries_used"]
|
||||
|
||||
def test_default_queries_used_when_none(self):
|
||||
"""When no queries given, DEFAULT_QUERIES are used."""
|
||||
with patch_rag([]):
|
||||
response = client.post(
|
||||
"/compliance/extract-requirements-from-rag",
|
||||
json={"dry_run": True},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["queries_used"]) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deduplication + article extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestArticleExtraction:
|
||||
"""Tests for article normalization from RAG chunks."""
|
||||
|
||||
def test_result_without_article_field_uses_text_pattern(self):
|
||||
"""If article field is empty, extract BSI pattern from text."""
|
||||
r = make_rag_result({"article": "", "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein."})
|
||||
|
||||
with patch_rag([r]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
data = response.json()
|
||||
created_reqs = [x for x in data["requirements"] if x["action"] == "created"]
|
||||
assert len(created_reqs) == 1
|
||||
assert created_reqs[0]["article"] == "O.Auth_2"
|
||||
|
||||
def test_result_with_no_article_at_all_is_skipped(self):
|
||||
"""Results without any article identifier are skipped."""
|
||||
r = make_rag_result({
|
||||
"article": "",
|
||||
"text": "General text without any structured identifier in the document.",
|
||||
})
|
||||
|
||||
with patch_rag([r]):
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
data = response.json()
|
||||
assert data["skipped_no_article"] >= 1
|
||||
assert data["created"] == 0
|
||||
|
||||
def test_intra_batch_deduplication(self):
|
||||
"""Two results with same regulation+article should only create one requirement."""
|
||||
r1 = make_rag_result({"article": "O.Purp_6"})
|
||||
r2 = make_rag_result({"article": "O.Purp_6", "text": "O.Purp_6 Additional text about the same requirement."})
|
||||
|
||||
with patch_rag([r1, r2]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
data = response.json()
|
||||
# Only one should be created despite two results with same article
|
||||
assert data["created"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regulation auto-creation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRegulationAutoCreate:
|
||||
"""Tests for auto-creation of regulation stubs."""
|
||||
|
||||
def test_existing_regulation_not_recreated(self):
|
||||
"""If regulation already exists, create should NOT be called."""
|
||||
with patch_rag([make_rag_result()]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
MockRegRepo.return_value.create.assert_not_called()
|
||||
|
||||
def test_unknown_regulation_is_auto_created(self):
|
||||
"""Unknown regulation_code triggers RegulationRepository.create."""
|
||||
with patch_rag([make_rag_result()]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = None
|
||||
MockRegRepo.return_value.create.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created"] == 1
|
||||
MockRegRepo.return_value.create.assert_called_once()
|
||||
|
||||
def test_multiple_regulations_in_one_run(self):
|
||||
"""Two results from different regulations should each get their regulation processed."""
|
||||
r1 = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"})
|
||||
r2 = make_rag_result({
|
||||
"regulation_code": "GDPR",
|
||||
"article": "Art. 32",
|
||||
"text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.",
|
||||
})
|
||||
|
||||
with patch_rag([r1, r2]), \
|
||||
patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \
|
||||
patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo:
|
||||
MockRegRepo.return_value.get_by_code.return_value = make_regulation()
|
||||
MockReqRepo.return_value.get_by_regulation.return_value = []
|
||||
MockReqRepo.return_value.create.return_value = make_requirement()
|
||||
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Both requirements should be created
|
||||
assert data["created"] == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResponseStructure:
|
||||
"""Verify the full response structure."""
|
||||
|
||||
def test_all_fields_present(self):
|
||||
with patch_rag([]):
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for field in [
|
||||
"created", "skipped_duplicates", "skipped_no_article",
|
||||
"failed", "collections_searched", "queries_used",
|
||||
"requirements", "dry_run", "message",
|
||||
]:
|
||||
assert field in data, f"Missing field: {field}"
|
||||
|
||||
def test_message_contains_summary(self):
|
||||
with patch_rag([]):
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={})
|
||||
data = response.json()
|
||||
assert "Erstellt:" in data["message"]
|
||||
|
||||
def test_dry_run_message_prefix(self):
|
||||
with patch_rag([]):
|
||||
response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True})
|
||||
data = response.json()
|
||||
assert "[DRY RUN]" in data["message"]
|
||||
Reference in New Issue
Block a user