""" FastAPI routes for RAG-based Requirement Extraction. Endpoints: - POST /compliance/extract-requirements-from-rag: Searches ALL RAG collections (or a subset) for audit criteria / Prüfaspekte and creates Requirement entries in the DB. Design principles: - Searches every relevant collection in parallel - Deduplicates by (regulation_id, article) — never inserts twice - Auto-creates Regulation stubs for unknown regulation_codes - LLM-free by default (fast); optional LLM title extraction via ?use_llm=true - dry_run=true returns what would be created without touching the DB """ import logging import re import asyncio from typing import Optional, List, Dict from datetime import datetime from fastapi import APIRouter, Depends from pydantic import BaseModel from sqlalchemy.orm import Session from classroom_engine.database import get_db from ..db import RegulationRepository, RequirementRepository from ..db.models import RegulationDB, RegulationTypeEnum from ..services.rag_client import get_rag_client, RAGSearchResult logger = logging.getLogger(__name__) router = APIRouter(tags=["extraction"]) # --------------------------------------------------------------------------- # Collections that may contain Prüfaspekte / audit criteria # --------------------------------------------------------------------------- ALL_COLLECTIONS = [ "bp_compliance_ce", # BSI-TR documents — primary Prüfaspekte source "bp_compliance_gesetze", # German laws "bp_compliance_datenschutz", # Data protection documents "bp_dsfa_corpus", # DSFA corpus "bp_legal_templates", # Legal templates ] # Search queries targeting audit criteria across different document types DEFAULT_QUERIES = [ "Prüfaspekt Anforderung MUSS SOLL", "security requirement SHALL MUST", "compliance requirement audit criterion", "Sicherheitsanforderung technische Maßnahme", "data protection requirement obligation", ] # --------------------------------------------------------------------------- # Schemas # 
class ExtractionRequest(BaseModel):
    """Request body for the RAG requirement-extraction endpoint."""

    collections: Optional[List[str]] = None  # None = ALL_COLLECTIONS
    search_queries: Optional[List[str]] = None  # None = DEFAULT_QUERIES
    regulation_codes: Optional[List[str]] = None  # None = all regulations
    max_per_query: int = 20  # top_k per search query
    dry_run: bool = False  # if True: no DB writes


class ExtractedRequirement(BaseModel):
    """One extracted requirement plus the action taken for it."""

    regulation_code: str
    article: str
    title: str
    requirement_text: str
    source_url: str
    score: float
    # "created" | "would_create" (dry_run) | "skipped_duplicate" | "skipped_no_article"
    action: str


class ExtractionResponse(BaseModel):
    """Summary of a full extraction run."""

    created: int
    skipped_duplicates: int
    skipped_no_article: int
    failed: int
    collections_searched: List[str]
    queries_used: List[str]
    requirements: List[ExtractedRequirement]
    dry_run: bool
    message: str


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Matches BSI Prüfaspekt identifiers embedded in chunk text (e.g. "A.Sec_1").
BSI_ASPECT_RE = re.compile(r"\b([A-Z]\.[A-Za-z]+_\d+)\b")
# Captures the first sentence (10-120 chars) of a chunk as a title candidate.
TITLE_SENTENCE_RE = re.compile(r"^([^.!?\n]{10,120})[.!?\n]")


def _derive_title(text: str, article: str) -> str:
    """Derive a short title from the text of a RAG chunk.

    Strips a leading article reference if present, then prefers the first
    sentence; falls back to the first 100 characters, then to the article id.
    """
    stripped = text.strip()
    prefix = r"^" + re.escape(article) + r"[:\s]+"
    body = re.sub(prefix, "", stripped, flags=re.IGNORECASE)

    match = TITLE_SENTENCE_RE.match(body)
    if match is not None:
        return match.group(1).strip()[:200]

    fallback = body[:100].strip()
    return fallback if fallback else article


def _normalize_article(result: RAGSearchResult) -> Optional[str]:
    """Return a canonical article identifier for *result*, or None.

    Prefers the explicit ``article`` field; otherwise scans the chunk text
    for a BSI Prüfaspekt pattern.
    """
    explicit = (result.article or "").strip()
    if explicit:
        return explicit

    hit = BSI_ASPECT_RE.search(result.text)
    return hit.group(1) if hit else None


async def _search_collection(
    collection: str,
    queries: List[str],
    max_per_query: int,
) -> List[RAGSearchResult]:
    """Run every query against one collection and merge the hits,
    dropping near-duplicates (same first 120 chars of chunk text)."""
    rag = get_rag_client()
    merged: List[RAGSearchResult] = []
    fingerprints: set[str] = set()

    for query in queries:
        hits = await rag.search(query, collection=collection, top_k=max_per_query)
        for hit in hits:
            fingerprint = hit.text[:120]  # rough dedup key
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            merged.append(hit)

    return merged


def _get_or_create_regulation(
    db: Session,
    regulation_code: str,
    regulation_name: str,
) -> RegulationDB:
    """Look up a Regulation by code, creating a stub when it is missing."""
    repo = RegulationRepository(db)
    existing = repo.get_by_code(regulation_code)
    if existing:
        return existing

    # Auto-create a stub so Requirements can reference it
    logger.info("Auto-creating regulation stub: %s", regulation_code)

    # The code prefix is the only available hint for the regulation type.
    if regulation_code.startswith("BSI"):
        inferred_type = RegulationTypeEnum.BSI_STANDARD
    elif regulation_code in ("GDPR", "AI_ACT", "NIS2", "CRA"):
        inferred_type = RegulationTypeEnum.EU_REGULATION
    else:
        inferred_type = RegulationTypeEnum.INDUSTRY_STANDARD

    return repo.create(
        code=regulation_code,
        name=regulation_name or regulation_code,
        regulation_type=inferred_type,
        description=f"Auto-created from RAG extraction ({datetime.utcnow().date()})",
    )


def _build_existing_articles(db: Session, regulation_id: str) -> set[str]:
    """Collect the article strings already stored for one regulation."""
    requirement_repo = RequirementRepository(db)
    return {req.article for req in requirement_repo.get_by_regulation(regulation_id)}


# ---------------------------------------------------------------------------
# Extraction helpers — independently testable
# ---------------------------------------------------------------------------
def _parse_rag_results(
    all_results: List[RAGSearchResult],
    regulation_codes: Optional[List[str]] = None,
) -> dict:
    """
    Filter, deduplicate, and group RAG search results by regulation code.

    Args:
        all_results: Raw hits merged from all searched collections.
        regulation_codes: Optional allow-list; hits for other regulations
            are dropped.

    Returns a dict with:
        - deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]]
          mapping regulation code (or "UNKNOWN") to (article, result)
          pairs, keeping only the highest-scoring hit per article.
        - skipped_no_article: List[RAGSearchResult] — hits without a
          resolvable article identifier.
        - unique_count: int — total unique (regulation, article) pairs.
    """
    # Filter by regulation_codes if requested
    if regulation_codes:
        all_results = [
            r for r in all_results if r.regulation_code in regulation_codes
        ]

    # Group results with a resolvable article by regulation code.
    by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    skipped_no_article: List[RAGSearchResult] = []
    for r in all_results:
        article = _normalize_article(r)
        if not article:
            skipped_no_article.append(r)
            continue
        by_reg.setdefault(r.regulation_code or "UNKNOWN", []).append((article, r))

    # Within each regulation keep only the highest-scoring hit per article.
    deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    for reg_code, items in by_reg.items():
        seen_articles: set[str] = set()
        deduped: List[tuple[str, RAGSearchResult]] = []
        for art, r in sorted(items, key=lambda x: x[1].score, reverse=True):
            if art not in seen_articles:
                seen_articles.add(art)
                deduped.append((art, r))
        deduped_by_reg[reg_code] = deduped

    # Derive the unique-pair count from the deduped groups themselves.
    # A separate counting pass keyed on the raw regulation_code could
    # disagree with the "UNKNOWN" grouping when the code is empty/None.
    unique_count = sum(len(items) for items in deduped_by_reg.values())

    return {
        "deduped_by_reg": deduped_by_reg,
        "skipped_no_article": skipped_no_article,
        "unique_count": unique_count,
    }


def _store_requirements(
    db: Session,
    deduped_by_reg: Dict[str, List[tuple[str, "RAGSearchResult"]]],
    dry_run: bool,
) -> dict:
    """
    Persist extracted requirements to the database (or simulate in dry_run
    mode).

    Args:
        db: Active SQLAlchemy session.
        deduped_by_reg: Output of `_parse_rag_results` — one (article,
            result) list per regulation code.
        dry_run: If True, no DB reads/writes happen; items are reported
            with action "would_create".

    Returns a dict with:
        - created_count: int
        - skipped_dup_count: int
        - failed_count: int
        - result_items: List[ExtractedRequirement]
    """
    req_repo = RequirementRepository(db)
    created_count = 0
    skipped_dup_count = 0
    failed_count = 0
    result_items: List[ExtractedRequirement] = []

    for reg_code, items in deduped_by_reg.items():
        if not items:
            continue

        # Find or create the owning regulation (skipped entirely in dry_run).
        try:
            first_result = items[0][1]
            regulation_name = (
                first_result.regulation_name
                or first_result.regulation_short
                or reg_code
            )
            if dry_run:
                # For dry_run, fake a regulation id; no duplicate lookup.
                regulation_id = f"dry-run-{reg_code}"
                existing_articles: set[str] = set()
            else:
                reg = _get_or_create_regulation(db, reg_code, regulation_name)
                regulation_id = reg.id
                existing_articles = _build_existing_articles(db, regulation_id)
        except Exception as e:
            logger.error("Failed to get/create regulation %s: %s", reg_code, e)
            failed_count += len(items)
            continue

        for article, r in items:
            title = _derive_title(r.text, article)

            if article in existing_articles:
                skipped_dup_count += 1
                result_items.append(ExtractedRequirement(
                    regulation_code=reg_code,
                    article=article,
                    title=title,
                    requirement_text=r.text[:1000],
                    source_url=r.source_url,
                    score=r.score,
                    action="skipped_duplicate",
                ))
                continue

            if not dry_run:
                try:
                    req_repo.create(
                        regulation_id=regulation_id,
                        article=article,
                        title=title,
                        description=(
                            f"Extrahiert aus RAG-Korpus "
                            f"(Collection: {r.category or r.regulation_code}). "
                            f"Score: {r.score:.2f}"
                        ),
                        requirement_text=r.text[:2000],
                        breakpilot_interpretation=None,
                        is_applicable=True,
                        priority=2,
                    )
                    existing_articles.add(article)  # prevent intra-batch duplication
                    created_count += 1
                except Exception as e:
                    logger.error("Failed to create requirement %s/%s: %s",
                                 reg_code, article, e)
                    failed_count += 1
                    continue
            else:
                created_count += 1  # dry_run: count as would-create

            result_items.append(ExtractedRequirement(
                regulation_code=reg_code,
                article=article,
                title=title,
                requirement_text=r.text[:1000],
                source_url=r.source_url,
                score=r.score,
                action="created" if not dry_run else "would_create",
            ))

    return {
        "created_count": created_count,
        "skipped_dup_count": skipped_dup_count,
        "failed_count": failed_count,
        "result_items": result_items,
    }


# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------

@router.post("/compliance/extract-requirements-from-rag",
             response_model=ExtractionResponse)
async def extract_requirements_from_rag(
    body: ExtractionRequest,
    db: Session = Depends(get_db),
):
    """
    Search all RAG collections for Prüfaspekte / audit criteria and create
    Requirement entries in the compliance DB.

    - Deduplicates by (regulation_code, article) — safe to call multiple times.
    - Auto-creates Regulation stubs for previously unknown regulation_codes.
    - Use `dry_run=true` to preview results without any DB writes.
    - Use `regulation_codes` to restrict to specific regulations
      (e.g. ["BSI-TR-03161-1"]).
    """
    collections = body.collections or ALL_COLLECTIONS
    queries = body.search_queries or DEFAULT_QUERIES

    # --- 1. Search all collections in parallel ---
    search_tasks = [
        _search_collection(col, queries, body.max_per_query)
        for col in collections
    ]
    # NOTE: with return_exceptions=True the result list may contain
    # Exception instances alongside result lists, so it is deliberately
    # left unannotated (the old List[List[RAGSearchResult]] hint was wrong).
    collection_results = await asyncio.gather(*search_tasks, return_exceptions=True)

    # Flatten, skipping collections whose search raised.
    all_results: List[RAGSearchResult] = []
    for col, res in zip(collections, collection_results):
        if isinstance(res, Exception):
            logger.warning("Collection %s search failed: %s", col, res)
        else:
            all_results.extend(res)

    logger.info("RAG extraction: %d raw results from %d collections",
                len(all_results), len(collections))

    # --- 2. Parse, filter, deduplicate, and group ---
    parsed = _parse_rag_results(all_results, body.regulation_codes)
    deduped_by_reg = parsed["deduped_by_reg"]
    skipped_no_article = parsed["skipped_no_article"]
    logger.info("RAG extraction: %d unique (regulation, article) pairs",
                parsed["unique_count"])

    # --- 3. Create requirements ---
    store_result = _store_requirements(db, deduped_by_reg, body.dry_run)
    created_count = store_result["created_count"]
    skipped_dup_count = store_result["skipped_dup_count"]
    failed_count = store_result["failed_count"]
    result_items = store_result["result_items"]

    message = (
        f"{'[DRY RUN] ' if body.dry_run else ''}"
        f"Erstellt: {created_count}, Duplikate übersprungen: {skipped_dup_count}, "
        f"Ohne Artikel-ID übersprungen: {len(skipped_no_article)}, Fehler: {failed_count}"
    )
    logger.info("RAG extraction complete: %s", message)

    return ExtractionResponse(
        created=created_count,
        skipped_duplicates=skipped_dup_count,
        skipped_no_article=len(skipped_no_article),
        failed=failed_count,
        collections_searched=collections,
        queries_used=queries,
        requirements=result_items,
        dry_run=body.dry_run,
        message=message,
    )