Files
breakpilot-compliance/backend-compliance/compliance/api/extraction_routes.py
Benjamin Admin 3ed8300daf
All checks were successful
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-ai-compliance (push) Successful in 34s
CI / test-python-backend-compliance (push) Successful in 31s
CI / test-python-document-crawler (push) Successful in 35s
CI / test-python-dsms-gateway (push) Successful in 17s
feat(extraction): POST /compliance/extract-requirements-from-rag
Sucht alle RAG-Kollektionen nach Prüfaspekten und legt automatisch
Anforderungen in der DB an. Kernfeatures:

- Durchsucht alle 6 RAG-Kollektionen parallel (bp_compliance_ce,
  bp_compliance_recht, bp_compliance_gesetze, bp_compliance_datenschutz,
  bp_dsfa_corpus, bp_legal_templates)
- Erkennt BSI Prüfaspekte (O.Purp_6) im Artikel-Feld und per Regex
- Dedupliziert nach (regulation_code, article) — safe to call many times
- Auto-erstellt Regulations-Stubs für unbekannte regulation_codes
- dry_run=true zeigt was erstellt würde ohne DB-Schreibzugriff
- Optionale Filter: collections, regulation_codes, search_queries
- 18 Tests (alle bestanden)
- Frontend: "Aus RAG extrahieren" Button auf /sdk/requirements

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-05 15:11:10 +01:00

367 lines
13 KiB
Python

"""
FastAPI routes for RAG-based Requirement Extraction.
Endpoints:
- POST /compliance/extract-requirements-from-rag:
Searches ALL RAG collections (or a subset) for audit criteria / Prüfaspekte
and creates Requirement entries in the DB.
Design principles:
- Searches every relevant collection in parallel
- Deduplicates by (regulation_id, article) — never inserts twice
- Auto-creates Regulation stubs for unknown regulation_codes
- LLM-free by default (fast); optional LLM title extraction via ?use_llm=true
- dry_run=true returns what would be created without touching the DB
"""
import asyncio
import logging
import re
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from classroom_engine.database import get_db
from ..db import RegulationRepository, RequirementRepository
from ..db.models import RegulationDB, RequirementDB, RegulationTypeEnum
from ..services.rag_client import get_rag_client, RAGSearchResult
# Module-level logger, one per module per stdlib convention.
logger = logging.getLogger(__name__)
# Router is mounted by the application; no prefix, tagged for OpenAPI grouping.
router = APIRouter(tags=["extraction"])
# ---------------------------------------------------------------------------
# Collections that may contain Prüfaspekte / audit criteria
# ---------------------------------------------------------------------------
ALL_COLLECTIONS = [
    "bp_compliance_ce",  # BSI-TR documents — primary Prüfaspekte source
    "bp_compliance_recht",  # Legal texts (GDPR, AI Act, ...)
    "bp_compliance_gesetze",  # German laws
    "bp_compliance_datenschutz",  # Data protection documents
    "bp_dsfa_corpus",  # DSFA corpus
    "bp_legal_templates",  # Legal templates
]
# Search queries targeting audit criteria across different document types
# (mixed German/English on purpose — the corpus contains both languages).
DEFAULT_QUERIES = [
    "Prüfaspekt Anforderung MUSS SOLL",
    "security requirement SHALL MUST",
    "compliance requirement audit criterion",
    "Sicherheitsanforderung technische Maßnahme",
    "data protection requirement obligation",
]
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class ExtractionRequest(BaseModel):
    """Request body for POST /compliance/extract-requirements-from-rag."""
    collections: Optional[List[str]] = None  # None = ALL_COLLECTIONS
    search_queries: Optional[List[str]] = None  # None = DEFAULT_QUERIES
    regulation_codes: Optional[List[str]] = None  # None = all regulations
    max_per_query: int = 20  # top_k per search query
    dry_run: bool = False  # if True: no DB writes
class ExtractedRequirement(BaseModel):
    """One per-requirement outcome row in the extraction response."""
    regulation_code: str
    article: str
    title: str
    requirement_text: str
    source_url: str
    score: float
    # "created" | "would_create" (dry_run) | "skipped_duplicate" | "skipped_no_article"
    action: str
class ExtractionResponse(BaseModel):
    """Aggregate counts plus per-requirement detail for an extraction run."""
    created: int
    skipped_duplicates: int
    skipped_no_article: int
    failed: int
    collections_searched: List[str]
    queries_used: List[str]
    requirements: List[ExtractedRequirement]
    dry_run: bool
    message: str
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
BSI_ASPECT_RE = re.compile(r"\b([A-Z]\.[A-Za-z]+_\d+)\b")
TITLE_SENTENCE_RE = re.compile(r"^([^.!?\n]{10,120})[.!?\n]")
def _derive_title(text: str, article: str) -> str:
"""Extract a short title from RAG chunk text."""
# Remove leading article reference if present
cleaned = re.sub(r"^" + re.escape(article) + r"[:\s]+", "", text.strip(), flags=re.IGNORECASE)
# Take first meaningful sentence
m = TITLE_SENTENCE_RE.match(cleaned)
if m:
return m.group(1).strip()[:200]
# Fallback: first 100 chars
return cleaned[:100].strip() or article
def _normalize_article(result: RAGSearchResult) -> Optional[str]:
    """Produce a canonical article identifier for a RAG hit.

    Prefers the explicit ``article`` field on the result; otherwise scans
    the chunk text for a BSI Prüfaspekt pattern (e.g. ``O.Purp_6``).
    Returns ``None`` when neither source yields anything usable.
    """
    explicit = (result.article or "").strip()
    if explicit:
        return explicit
    match = BSI_ASPECT_RE.search(result.text)
    return match.group(1) if match else None
async def _search_collection(
    collection: str,
    queries: List[str],
    max_per_query: int,
) -> List[RAGSearchResult]:
    """Query one collection with every search query and merge the hits.

    Hits are treated as duplicates when their first 120 characters of chunk
    text match — a cheap heuristic that collapses near-identical chunks.
    """
    rag = get_rag_client()
    merged: List[RAGSearchResult] = []
    fingerprints: set[str] = set()
    for q in queries:
        for hit in await rag.search(q, collection=collection, top_k=max_per_query):
            fingerprint = hit.text[:120]  # rough dedup key
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            merged.append(hit)
    return merged
def _get_or_create_regulation(
    db: Session,
    regulation_code: str,
    regulation_name: str,
) -> RegulationDB:
    """Return the Regulation for *regulation_code*, creating a stub if absent.

    Auto-created stubs let extracted Requirements reference regulations the
    DB has never seen before. The regulation type is inferred from the code.

    Args:
        db: Active SQLAlchemy session.
        regulation_code: Canonical code, e.g. "BSI-TR-03161-1" or "GDPR".
        regulation_name: Human-readable name; falls back to the code itself.

    Returns:
        The existing or newly created RegulationDB row.
    """
    repo = RegulationRepository(db)
    reg = repo.get_by_code(regulation_code)
    if reg:
        return reg
    # Auto-create a stub so Requirements can reference it
    logger.info("Auto-creating regulation stub: %s", regulation_code)
    # Infer type from code prefix / well-known codes.
    if regulation_code.startswith("BSI"):
        reg_type = RegulationTypeEnum.BSI_STANDARD
    elif regulation_code in ("GDPR", "AI_ACT", "NIS2", "CRA"):
        reg_type = RegulationTypeEnum.EU_REGULATION
    else:
        reg_type = RegulationTypeEnum.INDUSTRY_STANDARD
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC datetime
    # yields the identical calendar date here.
    reg = repo.create(
        code=regulation_code,
        name=regulation_name or regulation_code,
        regulation_type=reg_type,
        description=f"Auto-created from RAG extraction ({datetime.now(timezone.utc).date()})",
    )
    return reg
def _build_existing_articles(
    db: Session, regulation_id: str
) -> set[str]:
    """Collect the article identifiers already stored for one regulation."""
    requirements = RequirementRepository(db).get_by_regulation(regulation_id)
    return {req.article for req in requirements}
# ---------------------------------------------------------------------------
# Endpoint
# ---------------------------------------------------------------------------
@router.post("/compliance/extract-requirements-from-rag", response_model=ExtractionResponse)
async def extract_requirements_from_rag(
    body: ExtractionRequest,
    db: Session = Depends(get_db),
):
    """
    Search all RAG collections for Prüfaspekte / audit criteria and create
    Requirement entries in the compliance DB.
    - Deduplicates by (regulation_code, article) — safe to call multiple times.
    - Auto-creates Regulation stubs for previously unknown regulation_codes.
    - Use `dry_run=true` to preview results without any DB writes.
    - Use `regulation_codes` to restrict to specific regulations (e.g. ["BSI-TR-03161-1"]).
    """
    # Fall back to module-level defaults when the caller omits filters.
    collections = body.collections or ALL_COLLECTIONS
    queries = body.search_queries or DEFAULT_QUERIES
    # --- 1. Search all collections in parallel ---
    search_tasks = [
        _search_collection(col, queries, body.max_per_query)
        for col in collections
    ]
    # return_exceptions=True: one failing collection must not abort the rest.
    collection_results: List[List[RAGSearchResult]] = await asyncio.gather(
        *search_tasks, return_exceptions=True
    )
    # Flatten, skip exceptions (failed collections are logged, not fatal).
    all_results: List[RAGSearchResult] = []
    for col, res in zip(collections, collection_results):
        if isinstance(res, Exception):
            logger.warning("Collection %s search failed: %s", col, res)
        else:
            all_results.extend(res)
    logger.info("RAG extraction: %d raw results from %d collections", len(all_results), len(collections))
    # --- 2. Filter by regulation_codes if requested ---
    if body.regulation_codes:
        all_results = [
            r for r in all_results
            if r.regulation_code in body.regulation_codes
        ]
    # --- 3. Deduplicate at result level (regulation_code + article) ---
    # NOTE(review): unique_results only feeds the log line below — step 4
    # re-groups and re-deduplicates from all_results independently, so this
    # pass is purely informational. Consider folding the two dedup passes.
    seen: set[tuple[str, str]] = set()
    unique_results: List[RAGSearchResult] = []
    for r in sorted(all_results, key=lambda x: x.score, reverse=True):
        article = _normalize_article(r)
        if not article:
            continue
        key = (r.regulation_code, article)
        if key not in seen:
            seen.add(key)
            unique_results.append(r)
    logger.info("RAG extraction: %d unique (regulation, article) pairs", len(unique_results))
    # --- 4. Group by regulation_code and process ---
    by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    skipped_no_article: List[RAGSearchResult] = []
    for r in all_results:
        article = _normalize_article(r)
        if not article:
            # No explicit article and no BSI pattern in the text — reported
            # in the response counts, never written to the DB.
            skipped_no_article.append(r)
            continue
        key_r = r.regulation_code or "UNKNOWN"
        if key_r not in by_reg:
            by_reg[key_r] = []
        by_reg[key_r].append((article, r))
    # Deduplicate within groups — the highest-scoring chunk wins per article.
    deduped_by_reg: Dict[str, List[tuple[str, RAGSearchResult]]] = {}
    for reg_code, items in by_reg.items():
        seen_articles: set[str] = set()
        deduped: List[tuple[str, RAGSearchResult]] = []
        for art, r in sorted(items, key=lambda x: x[1].score, reverse=True):
            if art not in seen_articles:
                seen_articles.add(art)
                deduped.append((art, r))
        deduped_by_reg[reg_code] = deduped
    # --- 5. Create requirements ---
    req_repo = RequirementRepository(db)
    created_count = 0
    skipped_dup_count = 0
    failed_count = 0
    result_items: List[ExtractedRequirement] = []
    for reg_code, items in deduped_by_reg.items():
        if not items:
            continue
        # Find or create regulation
        try:
            first_result = items[0][1]
            regulation_name = first_result.regulation_name or first_result.regulation_short or reg_code
            if body.dry_run:
                # For dry_run, fake a regulation id — no DB access at all,
                # so existing_articles starts empty and nothing is deduped
                # against the database.
                regulation_id = f"dry-run-{reg_code}"
                existing_articles: set[str] = set()
            else:
                reg = _get_or_create_regulation(db, reg_code, regulation_name)
                regulation_id = reg.id
                existing_articles = _build_existing_articles(db, regulation_id)
        except Exception as e:
            # A regulation failure sinks the whole group: count all its items
            # as failed and move on to the next regulation.
            logger.error("Failed to get/create regulation %s: %s", reg_code, e)
            failed_count += len(items)
            continue
        for article, r in items:
            title = _derive_title(r.text, article)
            if article in existing_articles:
                # Already present from an earlier run — report, don't re-insert.
                skipped_dup_count += 1
                result_items.append(ExtractedRequirement(
                    regulation_code=reg_code,
                    article=article,
                    title=title,
                    requirement_text=r.text[:1000],
                    source_url=r.source_url,
                    score=r.score,
                    action="skipped_duplicate",
                ))
                continue
            if not body.dry_run:
                try:
                    req_repo.create(
                        regulation_id=regulation_id,
                        article=article,
                        title=title,
                        description=f"Extrahiert aus RAG-Korpus (Collection: {r.category or r.regulation_code}). Score: {r.score:.2f}",
                        requirement_text=r.text[:2000],
                        breakpilot_interpretation=None,
                        is_applicable=True,
                        priority=2,
                    )
                    existing_articles.add(article)  # prevent intra-batch duplication
                    created_count += 1
                except Exception as e:
                    # Per-item failure: skip the result row as well (continue).
                    logger.error("Failed to create requirement %s/%s: %s", reg_code, article, e)
                    failed_count += 1
                    continue
            else:
                created_count += 1  # dry_run: count as would-create
            result_items.append(ExtractedRequirement(
                regulation_code=reg_code,
                article=article,
                title=title,
                requirement_text=r.text[:1000],
                source_url=r.source_url,
                score=r.score,
                action="created" if not body.dry_run else "would_create",
            ))
    # Human-readable German summary — also mirrored to the log.
    message = (
        f"{'[DRY RUN] ' if body.dry_run else ''}"
        f"Erstellt: {created_count}, Duplikate übersprungen: {skipped_dup_count}, "
        f"Ohne Artikel-ID übersprungen: {len(skipped_no_article)}, Fehler: {failed_count}"
    )
    logger.info("RAG extraction complete: %s", message)
    return ExtractionResponse(
        created=created_count,
        skipped_duplicates=skipped_dup_count,
        skipped_no_article=len(skipped_no_article),
        failed=failed_count,
        collections_searched=collections,
        queries_used=queries,
        requirements=result_items,
        dry_run=body.dry_run,
        message=message,
    )