""" DSFA RAG API Endpoints. Provides REST API for searching DSFA corpus with full source attribution. Endpoints: - GET /api/v1/dsfa-rag/search - Semantic search with attribution - GET /api/v1/dsfa-rag/sources - List all registered sources - POST /api/v1/dsfa-rag/sources/{code}/ingest - Trigger source ingestion - GET /api/v1/dsfa-rag/chunks/{id} - Get single chunk with attribution - GET /api/v1/dsfa-rag/stats - Get corpus statistics """ import os import uuid import logging from typing import List, Optional from dataclasses import dataclass, asdict import httpx from fastapi import APIRouter, HTTPException, Query, Depends from pydantic import BaseModel, Field logger = logging.getLogger(__name__) # Embedding service configuration EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087") # Import from ingestion module from dsfa_corpus_ingestion import ( DSFACorpusStore, DSFAQdrantService, DSFASearchResult, LICENSE_REGISTRY, DSFA_SOURCES, generate_attribution_notice, get_license_label, DSFA_COLLECTION, chunk_document ) router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"]) # ============================================================================= # Pydantic Models # ============================================================================= class DSFASourceResponse(BaseModel): """Response model for DSFA source.""" id: str source_code: str name: str full_name: Optional[str] = None organization: Optional[str] = None source_url: Optional[str] = None license_code: str license_name: str license_url: Optional[str] = None attribution_required: bool attribution_text: str document_type: Optional[str] = None language: str = "de" class DSFAChunkResponse(BaseModel): """Response model for a single chunk with attribution.""" chunk_id: str content: str section_title: Optional[str] = None page_number: Optional[int] = None category: Optional[str] = None # Document info document_id: str document_title: Optional[str] = None # Attribution (always 
included) source_id: str source_code: str source_name: str attribution_text: str license_code: str license_name: str license_url: Optional[str] = None attribution_required: bool source_url: Optional[str] = None document_type: Optional[str] = None class DSFASearchResultResponse(BaseModel): """Response model for search result.""" chunk_id: str content: str score: float # Attribution source_code: str source_name: str attribution_text: str license_code: str license_name: str license_url: Optional[str] = None attribution_required: bool source_url: Optional[str] = None # Metadata document_type: Optional[str] = None category: Optional[str] = None section_title: Optional[str] = None page_number: Optional[int] = None class DSFASearchResponse(BaseModel): """Response model for search endpoint.""" query: str results: List[DSFASearchResultResponse] total_results: int # Aggregated licenses for footer licenses_used: List[str] attribution_notice: str class DSFASourceStatsResponse(BaseModel): """Response model for source statistics.""" source_id: str source_code: str name: str organization: Optional[str] = None license_code: str document_type: Optional[str] = None document_count: int chunk_count: int last_indexed_at: Optional[str] = None class DSFACorpusStatsResponse(BaseModel): """Response model for corpus statistics.""" sources: List[DSFASourceStatsResponse] total_sources: int total_documents: int total_chunks: int qdrant_collection: str qdrant_points_count: int qdrant_status: str class IngestRequest(BaseModel): """Request model for ingestion.""" document_url: Optional[str] = None document_text: Optional[str] = None title: Optional[str] = None class IngestResponse(BaseModel): """Response model for ingestion.""" source_code: str document_id: Optional[str] = None chunks_created: int message: str class LicenseInfo(BaseModel): """License information.""" code: str name: str url: Optional[str] = None attribution_required: bool modification_allowed: bool commercial_use: bool # 
# =============================================================================
# Dependency Injection
# =============================================================================

# Database pool (will be set from main.py)
_db_pool = None


def set_db_pool(pool):
    """Set the database pool for API endpoints (called once from main.py)."""
    global _db_pool
    _db_pool = pool


async def get_store() -> DSFACorpusStore:
    """Get DSFA corpus store.

    Raises:
        HTTPException: 503 when the database pool has not been set yet.
    """
    if _db_pool is None:
        raise HTTPException(status_code=503, detail="Database not initialized")
    return DSFACorpusStore(_db_pool)


async def get_qdrant() -> DSFAQdrantService:
    """Get Qdrant service (stateless, constructed per request)."""
    return DSFAQdrantService()


# =============================================================================
# Embedding Service Integration
# =============================================================================

async def get_embedding(text: str) -> List[float]:
    """
    Get embedding for text using the embedding-service.

    Uses BGE-M3 model which produces 1024-dimensional vectors. On any HTTP
    error, falls back to a deterministic hash-based pseudo-embedding so
    development can proceed without the service.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed-single",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("embedding", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service error: {e}")
            # Fallback to hash-based pseudo-embedding for development
            return _generate_fallback_embedding(text)


async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """
    Get embeddings for multiple texts in one batch call.

    Falls back to per-text pseudo-embeddings on HTTP failure.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("embeddings", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service batch error: {e}")
            # Fallback
            return [_generate_fallback_embedding(t) for t in texts]


async def extract_text_from_url(url: str) -> str:
    """
    Extract text from a document URL (PDF, HTML, etc.).

    Tries the embedding-service's /extract-pdf endpoint first; on failure,
    fetches the URL directly and strips HTML if the content-type is HTML.
    Returns "" when no text could be extracted.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            # First try to use the embedding-service's extract-pdf endpoint
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/extract-pdf",
                json={"url": url}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("text", "")
        except httpx.HTTPError as e:
            logger.error(f"PDF extraction error for {url}: {e}")
            # Fallback: try to fetch HTML content directly
            try:
                response = await client.get(url, follow_redirects=True)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                if "html" in content_type:
                    # Simple HTML text extraction
                    import re
                    html = response.text
                    # BUGFIX: the patterns had lost their tag delimiters
                    # (r']*>.*?') and so never removed <script>/<style> blocks
                    # while deleting unrelated text containing ']*>'.
                    html = re.sub(r'<script[^>]*>.*?</script>', '', html,
                                  flags=re.DOTALL | re.IGNORECASE)
                    html = re.sub(r'<style[^>]*>.*?</style>', '', html,
                                  flags=re.DOTALL | re.IGNORECASE)
                    # Remove remaining tags
                    text = re.sub(r'<[^>]+>', ' ', html)
                    # Clean whitespace
                    text = re.sub(r'\s+', ' ', text).strip()
                    return text
                else:
                    return ""
            except Exception as fetch_err:
                logger.error(f"Fallback fetch error for {url}: {fetch_err}")
                return ""


def _generate_fallback_embedding(text: str) -> List[float]:
    """
    Generate a deterministic pseudo-embedding from a SHA-256 hash of the text.

    Used as a fallback when the embedding service is unavailable. The 32 hash
    bytes seed 8 floats in [0.0, 1.0), which are repeated to fill 1024 dims.

    BUGFIX: the previous version reinterpreted raw hash bytes as IEEE-754
    floats via struct.unpack('f', ...); arbitrary byte patterns can decode to
    NaN or inf, and NaN % 1.0 stays NaN, which poisons similarity search.
    Deriving values from unsigned integers always yields finite numbers.
    """
    import hashlib

    hash_bytes = hashlib.sha256(text.encode()).digest()
    # 4-byte big-endian unsigned ints, normalized into [0, 1).
    embedding = [
        int.from_bytes(hash_bytes[i:i + 4], "big") / 2 ** 32
        for i in range(0, len(hash_bytes), 4)
    ]
    # Pad to 1024 dimensions by repeating the seed values
    while len(embedding) < 1024:
        embedding.extend(embedding[:1024 - len(embedding)])
    return embedding[:1024]
""" import hashlib import struct hash_bytes = hashlib.sha256(text.encode()).digest() embedding = [] for i in range(0, min(len(hash_bytes), 128), 4): val = struct.unpack('f', hash_bytes[i:i+4])[0] embedding.append(val % 1.0) # Pad to 1024 dimensions while len(embedding) < 1024: embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))]) return embedding[:1024] # ============================================================================= # API Endpoints # ============================================================================= @router.get("/search", response_model=DSFASearchResponse) async def search_dsfa_corpus( query: str = Query(..., min_length=3, description="Search query"), source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"), document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"), categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"), limit: int = Query(10, ge=1, le=50, description="Maximum results"), include_attribution: bool = Query(True, description="Include attribution in results"), store: DSFACorpusStore = Depends(get_store), qdrant: DSFAQdrantService = Depends(get_qdrant) ): """ Search DSFA corpus with full attribution. Returns matching chunks with source/license information for compliance. 
""" # Get query embedding query_embedding = await get_embedding(query) # Search Qdrant raw_results = await qdrant.search( query_embedding=query_embedding, source_codes=source_codes, document_types=document_types, categories=categories, limit=limit ) # Transform results results = [] licenses_used = set() for r in raw_results: license_code = r.get("license_code", "") license_info = LICENSE_REGISTRY.get(license_code, {}) result = DSFASearchResultResponse( chunk_id=r.get("chunk_id", ""), content=r.get("content", ""), score=r.get("score", 0.0), source_code=r.get("source_code", ""), source_name=r.get("source_name", ""), attribution_text=r.get("attribution_text", ""), license_code=license_code, license_name=license_info.get("name", license_code), license_url=license_info.get("url"), attribution_required=r.get("attribution_required", True), source_url=r.get("source_url"), document_type=r.get("document_type"), category=r.get("category"), section_title=r.get("section_title"), page_number=r.get("page_number") ) results.append(result) licenses_used.add(license_code) # Generate attribution notice search_results = [ DSFASearchResult( chunk_id=r.chunk_id, content=r.content, score=r.score, source_code=r.source_code, source_name=r.source_name, attribution_text=r.attribution_text, license_code=r.license_code, license_url=r.license_url, attribution_required=r.attribution_required, source_url=r.source_url, document_type=r.document_type or "", category=r.category or "", section_title=r.section_title, page_number=r.page_number ) for r in results ] attribution_notice = generate_attribution_notice(search_results) if include_attribution else "" return DSFASearchResponse( query=query, results=results, total_results=len(results), licenses_used=list(licenses_used), attribution_notice=attribution_notice ) @router.get("/sources", response_model=List[DSFASourceResponse]) async def list_dsfa_sources( document_type: Optional[str] = Query(None, description="Filter by document type"), license_code: 
@router.get("/sources/available")
async def list_available_sources():
    """List all available source definitions (from DSFA_SOURCES constant).

    These are the predefined sources that /init can register; no DB access.
    """
    return [
        {
            "source_code": s["source_code"],
            "name": s["name"],
            "organization": s.get("organization"),
            "license_code": s["license_code"],
            "document_type": s.get("document_type")
        }
        for s in DSFA_SOURCES
    ]


@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get details for a specific source.

    Raises:
        HTTPException: 404 when no source with that code is registered.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
    license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
    return DSFASourceResponse(
        id=str(source["id"]),
        source_code=source["source_code"],
        name=source["name"],
        full_name=source.get("full_name"),
        organization=source.get("organization"),
        source_url=source.get("source_url"),
        license_code=source.get("license_code", ""),
        license_name=license_info.get("name", source.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=source.get("attribution_required", True),
        attribution_text=source.get("attribution_text", ""),
        document_type=source.get("document_type"),
        language=source.get("language", "de")
    )


@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Trigger ingestion for a specific source.

    Can provide document via URL or direct text. Pipeline: validate source ->
    obtain text -> create document row -> chunk -> batch-embed -> persist
    chunks in PostgreSQL -> index vectors in Qdrant -> mark doc as indexed.

    Raises:
        HTTPException: 404 unknown source; 400 missing/too-short/unextractable text.

    NOTE(review): not transactional — a failure after create_document leaves a
    document row without chunks; confirm whether that is acceptable.
    """
    # Get source
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
    # Need either URL or text
    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )
    # Ensure Qdrant collection exists
    await qdrant.ensure_collection()
    # Get text content; direct text takes precedence over the URL
    text_content = request.document_text
    if request.document_url and not text_content:
        # Download and extract text from URL
        logger.info(f"Extracting text from URL: {request.document_url}")
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )
    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
    # Create document record
    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=str(source["id"]),
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )
    # Chunk the document
    chunks = chunk_document(text_content, source_code)
    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )
    # Generate embeddings in batch for efficiency
    chunk_texts = [chunk["content"] for chunk in chunks]
    logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
    embeddings = await get_embeddings_batch(chunk_texts)
    # Create chunk records in PostgreSQL and prepare for Qdrant
    chunk_records = []
    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
        # Create chunk in PostgreSQL
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=str(source["id"]),
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )
        # Qdrant payload carries denormalized attribution so search results
        # need no extra DB lookup (keys consumed by search_dsfa_corpus).
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": str(source["id"]),
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source["attribution_text"],
            "license_code": source["license_code"],
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })
    # Index in Qdrant
    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
    # Update document record
    await store.update_document_indexed(document_id, len(chunks))
    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get single chunk with full source attribution.

    Raises:
        HTTPException: 404 when the chunk does not exist.
    """
    chunk = await store.get_chunk_with_attribution(chunk_id)
    if not chunk:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
    # Resolve license display info from the registry (code kept as fallback).
    license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
    return DSFAChunkResponse(
        chunk_id=str(chunk["chunk_id"]),
        content=chunk.get("content", ""),
        section_title=chunk.get("section_title"),
        page_number=chunk.get("page_number"),
        category=chunk.get("category"),
        document_id=str(chunk.get("document_id", "")),
        document_title=chunk.get("document_title"),
        source_id=str(chunk.get("source_id", "")),
        source_code=chunk.get("source_code", ""),
        source_name=chunk.get("source_name", ""),
        attribution_text=chunk.get("attribution_text", ""),
        license_code=chunk.get("license_code", ""),
        license_name=license_info.get("name", chunk.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=chunk.get("attribution_required", True),
        source_url=chunk.get("source_url"),
        document_type=chunk.get("document_type")
    )


@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Get corpus statistics for dashboard.

    Aggregates per-source counts from PostgreSQL and collection status from
    Qdrant into one response.
    """
    # Get PostgreSQL stats
    source_stats = await store.get_source_stats()
    total_docs = 0
    total_chunks = 0
    stats_response = []
    for s in source_stats:
        # `or 0` guards NULL counts coming back from the DB aggregation.
        doc_count = s.get("document_count", 0) or 0
        chunk_count = s.get("chunk_count", 0) or 0
        total_docs += doc_count
        total_chunks += chunk_count
        # presumably a datetime (isoformat() is called below) — TODO confirm
        last_indexed = s.get("last_indexed_at")
        stats_response.append(DSFASourceStatsResponse(
            source_id=str(s.get("source_id", "")),
            source_code=s.get("source_code", ""),
            name=s.get("name", ""),
            organization=s.get("organization"),
            license_code=s.get("license_code", ""),
            document_type=s.get("document_type"),
            document_count=doc_count,
            chunk_count=chunk_count,
            last_indexed_at=last_indexed.isoformat() if last_indexed else None
        ))
    # Get Qdrant stats
    qdrant_stats = await qdrant.get_stats()
    return DSFACorpusStatsResponse(
        sources=stats_response,
        total_sources=len(source_stats),
        total_documents=total_docs,
        total_chunks=total_chunks,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=qdrant_stats.get("points_count", 0),
        qdrant_status=qdrant_stats.get("status", "unknown")
    )
@router.get("/licenses")
async def list_licenses():
    """List all supported licenses with their terms.

    Pure transformation of LICENSE_REGISTRY; permissive defaults are applied
    for terms a registry entry does not specify.
    """
    return [
        LicenseInfo(
            code=code,
            name=info.get("name", code),
            url=info.get("url"),
            attribution_required=info.get("attribution_required", True),
            modification_allowed=info.get("modification_allowed", True),
            commercial_use=info.get("commercial_use", True)
        )
        for code, info in LICENSE_REGISTRY.items()
    ]


@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Creates Qdrant collection
    - Registers all predefined sources

    Registration is best-effort: a failure on one source is logged and the
    remaining sources are still attempted.
    """
    # Ensure Qdrant collection exists
    qdrant_ok = await qdrant.ensure_collection()

    # Register all sources
    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
            registered += 1
        except Exception:
            # FIX: was print(); use the module logger like every other handler
            # so the failure (and its traceback) lands in the service logs.
            logger.exception("Error registering source %s", source["source_code"])

    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }