diff --git a/admin-compliance/app/sdk/requirements/page.tsx b/admin-compliance/app/sdk/requirements/page.tsx index ad172b6..91f6c8a 100644 --- a/admin-compliance/app/sdk/requirements/page.tsx +++ b/admin-compliance/app/sdk/requirements/page.tsx @@ -404,6 +404,50 @@ export default function RequirementsPage() { const [error, setError] = useState(null) const [showAddForm, setShowAddForm] = useState(false) const [expandedId, setExpandedId] = useState(null) + const [ragExtracting, setRagExtracting] = useState(false) + const [ragResult, setRagResult] = useState<{ created: number; skipped_duplicates: number; message: string } | null>(null) + + const extractFromRAG = async () => { + setRagExtracting(true) + setRagResult(null) + try { + const res = await fetch('/api/sdk/v1/compliance/extract-requirements-from-rag', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ max_per_query: 20 }), + }) + if (res.ok) { + const data = await res.json() + setRagResult({ created: data.created, skipped_duplicates: data.skipped_duplicates, message: data.message }) + // Reload requirements list + const listRes = await fetch('/api/sdk/v1/compliance/requirements') + if (listRes.ok) { + const listData = await listRes.json() + const reqs = listData.requirements || listData + if (Array.isArray(reqs) && reqs.length > 0) { + const mapped = reqs.map((r: Record) => ({ + id: (r.requirement_id || r.id) as string, + regulation: (r.regulation_code || r.regulation || '') as string, + article: (r.article || '') as string, + title: (r.title || '') as string, + description: (r.description || '') as string, + criticality: ((r.criticality || r.priority || 'MEDIUM') as string).toUpperCase() as import('@/lib/sdk').RiskSeverity, + applicableModules: [] as string[], + status: 'NOT_STARTED' as import('@/lib/sdk').RequirementStatus, + controls: [] as string[], + })) + dispatch({ type: 'SET_STATE', payload: { requirements: mapped } }) + } + } + } else { + setRagResult({ created: 0, skipped_duplicates: 0, message: 'RAG-Extraktion fehlgeschlagen' }) + } + } catch { + setRagResult({ created: 0, skipped_duplicates: 0, message: 'RAG-Extraktion nicht erreichbar' }) + } finally { + setRagExtracting(false) + } + } // Fetch requirements from backend on mount useEffect(() => { @@ -626,17 +670,46 @@ export default function RequirementsPage() { explanation={stepInfo.explanation} tips={stepInfo.tips} > - +
+ + +
+ {/* RAG Extraction Result Banner */} + {ragResult && ( +
0 ? 'bg-green-50 border-green-200' : 'bg-blue-50 border-blue-200'}`}> + + {ragResult.created > 0 ? '✅' : 'ℹ️'} {ragResult.message} + + +
+ )} + {/* Add Form */} {showAddForm && ( str: + """Extract a short title from RAG chunk text.""" + # Remove leading article reference if present + cleaned = re.sub(r"^" + re.escape(article) + r"[:\s]+", "", text.strip(), flags=re.IGNORECASE) + # Take first meaningful sentence + m = TITLE_SENTENCE_RE.match(cleaned) + if m: + return m.group(1).strip()[:200] + # Fallback: first 100 chars + return cleaned[:100].strip() or article + + +def _normalize_article(result: RAGSearchResult) -> Optional[str]: + """ + Return a canonical article identifier from the RAG result. + Returns None if no meaningful article can be determined. + """ + article = (result.article or "").strip() + if article: + return article + + # Try to find BSI Prüfaspekt pattern in chunk text + m = BSI_ASPECT_RE.search(result.text) + if m: + return m.group(1) + + return None + + +async def _search_collection( + collection: str, + queries: List[str], + max_per_query: int, +) -> List[RAGSearchResult]: + """Run all queries against one collection and merge deduplicated results.""" + rag = get_rag_client() + seen_texts: set[str] = set() + results: List[RAGSearchResult] = [] + + for query in queries: + hits = await rag.search(query, collection=collection, top_k=max_per_query) + for h in hits: + key = h.text[:120] # rough dedup key + if key not in seen_texts: + seen_texts.add(key) + results.append(h) + + return results + + +def _get_or_create_regulation( + db: Session, + regulation_code: str, + regulation_name: str, +) -> RegulationDB: + """Return existing Regulation or create a stub.""" + repo = RegulationRepository(db) + reg = repo.get_by_code(regulation_code) + if reg: + return reg + + # Auto-create a stub so Requirements can reference it + logger.info("Auto-creating regulation stub: %s", regulation_code) + # Infer type from code prefix + if regulation_code.startswith("BSI"): + reg_type = RegulationTypeEnum.BSI_STANDARD + elif regulation_code in ("GDPR", "AI_ACT", "NIS2", "CRA"): + reg_type = RegulationTypeEnum.EU_REGULATION + else: + reg_type = RegulationTypeEnum.INDUSTRY_STANDARD + reg = repo.create( + code=regulation_code, + name=regulation_name or regulation_code, + regulation_type=reg_type, + description=f"Auto-created from RAG extraction ({datetime.utcnow().date()})", + ) + return reg + + +def _build_existing_articles( + db: Session, regulation_id: str +) -> set[str]: + """Return set of existing article strings for this regulation.""" + repo = RequirementRepository(db) + existing = repo.get_by_regulation(regulation_id) + return {r.article for r in existing} + + +# --------------------------------------------------------------------------- +# Endpoint +# --------------------------------------------------------------------------- + +@router.post("/compliance/extract-requirements-from-rag", response_model=ExtractionResponse) +async def extract_requirements_from_rag( + body: ExtractionRequest, + db: Session = Depends(get_db), +): + """ + Search all RAG collections for Prüfaspekte / audit criteria and create + Requirement entries in the compliance DB. + + - Deduplicates by (regulation_code, article) — safe to call multiple times. + - Auto-creates Regulation stubs for previously unknown regulation_codes. + - Use `dry_run=true` to preview results without any DB writes. + - Use `regulation_codes` to restrict to specific regulations (e.g. ["BSI-TR-03161-1"]). + """ + collections = body.collections or ALL_COLLECTIONS + queries = body.search_queries or DEFAULT_QUERIES + + # --- 1. Search all collections in parallel --- + search_tasks = [ + _search_collection(col, queries, body.max_per_query) + for col in collections + ] + collection_results: List[List[RAGSearchResult]] = await asyncio.gather( + *search_tasks, return_exceptions=True + ) + + # Flatten, skip exceptions + all_results: List[RAGSearchResult] = [] + for col, res in zip(collections, collection_results): + if isinstance(res, Exception): + logger.warning("Collection %s search failed: %s", col, res) + else: + all_results.extend(res) + + logger.info("RAG extraction: %d raw results from %d collections", len(all_results), len(collections)) + + # --- 2. Filter by regulation_codes if requested --- + if body.regulation_codes: + all_results = [ + r for r in all_results + if r.regulation_code in body.regulation_codes + ] + + # --- 3. Deduplicate at result level (regulation_code + article) --- + seen: set[tuple[str, str]] = set() + unique_results: List[RAGSearchResult] = [] + for r in sorted(all_results, key=lambda x: x.score, reverse=True): + article = _normalize_article(r) + if not article: + continue + key = (r.regulation_code, article) + if key not in seen: + seen.add(key) + unique_results.append(r) + + logger.info("RAG extraction: %d unique (regulation, article) pairs", len(unique_results)) + + # --- 4. Group by regulation_code and process --- + by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {} + skipped_no_article: List[RAGSearchResult] = [] + + for r in all_results: + article = _normalize_article(r) + if not article: + skipped_no_article.append(r) + continue + key_r = r.regulation_code or "UNKNOWN" + if key_r not in by_reg: + by_reg[key_r] = [] + by_reg[key_r].append((article, r)) + + # Deduplicate within groups + deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {} + for reg_code, items in by_reg.items(): + seen_articles: set[str] = set() + deduped: List[tuple[str, RAGSearchResult]] = [] + for art, r in sorted(items, key=lambda x: x[1].score, reverse=True): + if art not in seen_articles: + seen_articles.add(art) + deduped.append((art, r)) + deduped_by_reg[reg_code] = deduped + + # --- 5. Create requirements --- + req_repo = RequirementRepository(db) + created_count = 0 + skipped_dup_count = 0 + failed_count = 0 + result_items: List[ExtractedRequirement] = [] + + for reg_code, items in deduped_by_reg.items(): + if not items: + continue + + # Find or create regulation + try: + first_result = items[0][1] + regulation_name = first_result.regulation_name or first_result.regulation_short or reg_code + if body.dry_run: + # For dry_run, fake a regulation id + regulation_id = f"dry-run-{reg_code}" + existing_articles: set[str] = set() + else: + reg = _get_or_create_regulation(db, reg_code, regulation_name) + regulation_id = reg.id + existing_articles = _build_existing_articles(db, regulation_id) + except Exception as e: + logger.error("Failed to get/create regulation %s: %s", reg_code, e) + failed_count += len(items) + continue + + for article, r in items: + title = _derive_title(r.text, article) + + if article in existing_articles: + skipped_dup_count += 1 + result_items.append(ExtractedRequirement( + regulation_code=reg_code, + article=article, + title=title, + requirement_text=r.text[:1000], + source_url=r.source_url, + score=r.score, + action="skipped_duplicate", + )) + continue + + if not body.dry_run: + try: + req_repo.create( + regulation_id=regulation_id, + article=article, + title=title, + description=f"Extrahiert aus RAG-Korpus (Collection: {r.category or r.regulation_code}). Score: {r.score:.2f}", + requirement_text=r.text[:2000], + breakpilot_interpretation=None, + is_applicable=True, + priority=2, + ) + existing_articles.add(article) # prevent intra-batch duplication + created_count += 1 + except Exception as e: + logger.error("Failed to create requirement %s/%s: %s", reg_code, article, e) + failed_count += 1 + continue + else: + created_count += 1 # dry_run: count as would-create + + result_items.append(ExtractedRequirement( + regulation_code=reg_code, + article=article, + title=title, + requirement_text=r.text[:1000], + source_url=r.source_url, + score=r.score, + action="created" if not body.dry_run else "would_create", + )) + + message = ( + f"{'[DRY RUN] ' if body.dry_run else ''}" + f"Erstellt: {created_count}, Duplikate übersprungen: {skipped_dup_count}, " + f"Ohne Artikel-ID übersprungen: {len(skipped_no_article)}, Fehler: {failed_count}" + ) + logger.info("RAG extraction complete: %s", message) + + return ExtractionResponse( + created=created_count, + skipped_duplicates=skipped_dup_count, + skipped_no_article=len(skipped_no_article), + failed=failed_count, + collections_searched=collections, + queries_used=queries, + requirements=result_items, + dry_run=body.dry_run, + message=message, + ) diff --git a/backend-compliance/tests/test_extraction_routes.py b/backend-compliance/tests/test_extraction_routes.py new file mode 100644 index 0000000..04ab812 --- /dev/null +++ b/backend-compliance/tests/test_extraction_routes.py @@ -0,0 +1,416 @@ +"""Tests for RAG-based Requirement Extraction endpoint. + +POST /compliance/extract-requirements-from-rag +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from compliance.api.extraction_routes import router as extraction_router +from classroom_engine.database import get_db + +# --------------------------------------------------------------------------- +# App setup +# --------------------------------------------------------------------------- + +app = FastAPI() +app.include_router(extraction_router) + +mock_db = MagicMock() + + +def override_get_db(): + yield mock_db + + +app.dependency_overrides[get_db] = override_get_db +client = TestClient(app) + +REG_ID = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa" +REQ_ID = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb" + +# --------------------------------------------------------------------------- +# RAG result helpers +# --------------------------------------------------------------------------- + +def make_rag_result(overrides=None): + r = MagicMock() + r.text = "O.Purp_6 MUSS: Die Anwendung MUSS den Zweck der Verarbeitung klar benennen." + r.regulation_code = "BSI-TR-03161-1" + r.regulation_name = "BSI Technische Richtlinie 03161 Teil 1" + r.regulation_short = "BSI-TR-03161-1" + r.category = "bsi" + r.article = "O.Purp_6" + r.paragraph = "" + r.source_url = "https://bsi.bund.de/tr03161" + r.score = 0.92 + if overrides: + for k, v in overrides.items(): + setattr(r, k, v) + return r + + +def make_regulation(overrides=None): + reg = MagicMock() + reg.id = REG_ID + reg.code = "BSI-TR-03161-1" + reg.name = "BSI-TR-03161-1" + if overrides: + for k, v in overrides.items(): + setattr(reg, k, v) + return reg + + +def make_requirement(overrides=None): + req = MagicMock() + req.id = REQ_ID + req.regulation_id = REG_ID + req.article = "O.Purp_6" + req.title = "Zweckbenennung" + req.regulation = MagicMock() + req.regulation.code = "BSI-TR-03161-1" + if overrides: + for k, v in overrides.items(): + setattr(req, k, v) + return req + + +# --------------------------------------------------------------------------- +# Helper: RAG mock +# --------------------------------------------------------------------------- + +def patch_rag(results_per_call=None): + """Return a patcher that makes RAG search return the given results list.""" + results = [make_rag_result()] if results_per_call is None else results_per_call + + async def fake_search(*args, **kwargs): + return results + + mock_client = MagicMock() + mock_client.search = fake_search + return patch( + "compliance.api.extraction_routes.get_rag_client", + return_value=mock_client, + ) + + +# --------------------------------------------------------------------------- +# Basic endpoint tests +# --------------------------------------------------------------------------- + +class TestExtractRequirementsBasic: + """Basic extraction endpoint tests.""" + + def test_empty_rag_results(self): + """When RAG returns nothing, response should report 0 created.""" + with patch_rag(results_per_call=[]): + response = client.post("/compliance/extract-requirements-from-rag", json={}) + assert response.status_code == 200 + data = response.json() + assert data["created"] == 0 + assert data["skipped_duplicates"] == 0 + assert data["dry_run"] is False + + def test_dry_run_does_not_write_db(self): + """dry_run=true should not call RequirementRepository.create.""" + with patch_rag([make_rag_result()]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + + response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True}) + + assert response.status_code == 200 + data = response.json() + assert data["dry_run"] is True + assert data["created"] == 1 # would-create count + # DB create must NOT be called + MockReqRepo.return_value.create.assert_not_called() + + def test_creates_requirement_new_regulation(self): + """New regulation + new requirement should be created in DB.""" + rag_result = make_rag_result() + new_reg = make_regulation() + new_req = make_requirement() + + with patch_rag([rag_result]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + # Regulation doesn't exist yet → auto-create + MockRegRepo.return_value.get_by_code.return_value = None + MockRegRepo.return_value.create.return_value = new_reg + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = new_req + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + assert response.status_code == 200 + data = response.json() + assert data["created"] == 1 + assert data["skipped_duplicates"] == 0 + MockRegRepo.return_value.create.assert_called_once() + MockReqRepo.return_value.create.assert_called_once() + + def test_skips_duplicate_requirement(self): + """If article already exists for the regulation, skip it.""" + rag_result = make_rag_result() + existing_req = make_requirement({"article": "O.Purp_6"}) + + with patch_rag([rag_result]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [existing_req] + MockReqRepo.return_value.create.return_value = MagicMock() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + assert response.status_code == 200 + data = response.json() + assert data["skipped_duplicates"] == 1 + assert data["created"] == 0 + MockReqRepo.return_value.create.assert_not_called() + + def test_result_items_contain_expected_fields(self): + """Requirements list should contain correct fields.""" + with patch_rag([make_rag_result()]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + data = response.json() + assert len(data["requirements"]) == 1 + req = data["requirements"][0] + assert req["regulation_code"] == "BSI-TR-03161-1" + assert req["article"] == "O.Purp_6" + assert req["action"] == "created" + assert "title" in req + assert "score" in req + + +# --------------------------------------------------------------------------- +# Collection and query filtering +# --------------------------------------------------------------------------- + +class TestExtractionFilters: + """Tests for collection/regulation filters.""" + + def test_custom_collections_passed(self): + """Should only search specified collections.""" + with patch_rag([]) as mock_patch: + with mock_patch: + response = client.post( + "/compliance/extract-requirements-from-rag", + json={"collections": ["bp_compliance_ce"], "dry_run": True}, + ) + assert response.status_code == 200 + data = response.json() + assert data["collections_searched"] == ["bp_compliance_ce"] + + def test_regulation_code_filter(self): + """Results from other regulation_codes should be excluded.""" + bsi_result = make_rag_result({"regulation_code": "BSI-TR-03161-1"}) + gdpr_result = make_rag_result({ + "regulation_code": "GDPR", + "article": "Art. 32", + "text": "Art. 32 Sicherheit der Verarbeitung. Der Verantwortliche MUSS geeignete Maßnahmen treffen.", + }) + + with patch_rag([bsi_result, gdpr_result]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post( + "/compliance/extract-requirements-from-rag", + json={"regulation_codes": ["BSI-TR-03161-1"]}, + ) + + assert response.status_code == 200 + data = response.json() + # Only BSI result should be in output + for req in data["requirements"]: + assert req["regulation_code"] == "BSI-TR-03161-1" + + def test_custom_queries_passed(self): + """custom search_queries should be used.""" + with patch_rag([]): + response = client.post( + "/compliance/extract-requirements-from-rag", + json={"search_queries": ["custom query"], "dry_run": True}, + ) + assert response.status_code == 200 + data = response.json() + assert "custom query" in data["queries_used"] + + def test_default_queries_used_when_none(self): + """When no queries given, DEFAULT_QUERIES are used.""" + with patch_rag([]): + response = client.post( + "/compliance/extract-requirements-from-rag", + json={"dry_run": True}, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["queries_used"]) > 0 + + +# --------------------------------------------------------------------------- +# Deduplication + article extraction +# --------------------------------------------------------------------------- + +class TestArticleExtraction: + """Tests for article normalization from RAG chunks.""" + + def test_result_without_article_field_uses_text_pattern(self): + """If article field is empty, extract BSI pattern from text.""" + r = make_rag_result({"article": "", "text": "O.Auth_2 MUSS: Passwörter MÜSSEN gehasht sein."}) + + with patch_rag([r]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + data = response.json() + created_reqs = [x for x in data["requirements"] if x["action"] == "created"] + assert len(created_reqs) == 1 + assert created_reqs[0]["article"] == "O.Auth_2" + + def test_result_with_no_article_at_all_is_skipped(self): + """Results without any article identifier are skipped.""" + r = make_rag_result({ + "article": "", + "text": "General text without any structured identifier in the document.", + }) + + with patch_rag([r]): + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + data = response.json() + assert data["skipped_no_article"] >= 1 + assert data["created"] == 0 + + def test_intra_batch_deduplication(self): + """Two results with same regulation+article should only create one requirement.""" + r1 = make_rag_result({"article": "O.Purp_6"}) + r2 = make_rag_result({"article": "O.Purp_6", "text": "O.Purp_6 Additional text about the same requirement."}) + + with patch_rag([r1, r2]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + data = response.json() + # Only one should be created despite two results with same article + assert data["created"] == 1 + + +# --------------------------------------------------------------------------- +# Regulation auto-creation +# --------------------------------------------------------------------------- + +class TestRegulationAutoCreate: + """Tests for auto-creation of regulation stubs.""" + + def test_existing_regulation_not_recreated(self): + """If regulation already exists, create should NOT be called.""" + with patch_rag([make_rag_result()]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + assert response.status_code == 200 + MockRegRepo.return_value.create.assert_not_called() + + def test_unknown_regulation_is_auto_created(self): + """Unknown regulation_code triggers RegulationRepository.create.""" + with patch_rag([make_rag_result()]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = None + MockRegRepo.return_value.create.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + assert response.status_code == 200 + data = response.json() + assert data["created"] == 1 + MockRegRepo.return_value.create.assert_called_once() + + def test_multiple_regulations_in_one_run(self): + """Two results from different regulations should each get their regulation processed.""" + r1 = make_rag_result({"regulation_code": "BSI-TR-03161-1", "article": "O.Purp_6"}) + r2 = make_rag_result({ + "regulation_code": "GDPR", + "article": "Art. 32", + "text": "Art. 32 Sicherheit der Verarbeitung MUSS gewährleistet sein.", + }) + + with patch_rag([r1, r2]), \ + patch("compliance.api.extraction_routes.RegulationRepository") as MockRegRepo, \ + patch("compliance.api.extraction_routes.RequirementRepository") as MockReqRepo: + MockRegRepo.return_value.get_by_code.return_value = make_regulation() + MockReqRepo.return_value.get_by_regulation.return_value = [] + MockReqRepo.return_value.create.return_value = make_requirement() + + response = client.post("/compliance/extract-requirements-from-rag", json={}) + + assert response.status_code == 200 + data = response.json() + # Both requirements should be created + assert data["created"] == 2 + + +# --------------------------------------------------------------------------- +# Response structure +# --------------------------------------------------------------------------- + +class TestResponseStructure: + """Verify the full response structure.""" + + def test_all_fields_present(self): + with patch_rag([]): + response = client.post("/compliance/extract-requirements-from-rag", json={}) + assert response.status_code == 200 + data = response.json() + for field in [ + "created", "skipped_duplicates", "skipped_no_article", + "failed", "collections_searched", "queries_used", + "requirements", "dry_run", "message", + ]: + assert field in data, f"Missing field: {field}" + + def test_message_contains_summary(self): + with patch_rag([]): + response = client.post("/compliance/extract-requirements-from-rag", json={}) + data = response.json() + assert "Erstellt:" in data["message"] + + def test_dry_run_message_prefix(self): + with patch_rag([]): + response = client.post("/compliance/extract-requirements-from-rag", json={"dry_run": True}) + data = response.json() + assert "[DRY RUN]" in data["message"]