"""
DSFA RAG API Route Handlers.

Endpoint implementations for search, sources, ingestion, stats, and init.
All endpoints are registered on ``router`` under the ``/api/v1/dsfa-rag``
prefix. The database pool must be injected via ``set_db_pool`` before any
endpoint that depends on ``get_store`` is called.
"""

import logging
from typing import List, Optional

from fastapi import APIRouter, HTTPException, Query, Depends

from dsfa_corpus_ingestion import (
    DSFACorpusStore,
    DSFAQdrantService,
    DSFASearchResult,
    LICENSE_REGISTRY,
    DSFA_SOURCES,
    generate_attribution_notice,
    get_license_label,
    DSFA_COLLECTION,
    chunk_document,
)
from dsfa_rag_models import (
    DSFASourceResponse,
    DSFAChunkResponse,
    DSFASearchResultResponse,
    DSFASearchResponse,
    DSFASourceStatsResponse,
    DSFACorpusStatsResponse,
    IngestRequest,
    IngestResponse,
    LicenseInfo,
)
from dsfa_rag_embedding import (
    get_embedding,
    get_embeddings_batch,
    extract_text_from_url,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])


# =============================================================================
# Dependency Injection
# =============================================================================

# Module-level database pool, set once at startup via set_db_pool().
_db_pool = None


def set_db_pool(pool):
    """Set the database pool for API endpoints."""
    global _db_pool
    _db_pool = pool


async def get_store() -> DSFACorpusStore:
    """Get DSFA corpus store.

    Raises:
        HTTPException: 503 if ``set_db_pool`` has not been called yet.
    """
    if _db_pool is None:
        raise HTTPException(status_code=503, detail="Database not initialized")
    return DSFACorpusStore(_db_pool)


async def get_qdrant() -> DSFAQdrantService:
    """Get Qdrant service (a new service instance per request)."""
    return DSFAQdrantService()


# =============================================================================
# Internal helpers
# =============================================================================

def _source_to_response(source) -> DSFASourceResponse:
    """Build a ``DSFASourceResponse`` from a dict-like source row.

    The license name and URL are resolved from ``LICENSE_REGISTRY``. If the
    license code is unknown, the code itself is used as the name.
    """
    license_code = source.get("license_code", "")
    license_info = LICENSE_REGISTRY.get(license_code, {})
    return DSFASourceResponse(
        id=str(source["id"]),
        source_code=source["source_code"],
        name=source["name"],
        full_name=source.get("full_name"),
        organization=source.get("organization"),
        source_url=source.get("source_url"),
        license_code=license_code,
        license_name=license_info.get("name", license_code),
        license_url=license_info.get("url"),
        attribution_required=source.get("attribution_required", True),
        attribution_text=source.get("attribution_text", ""),
        document_type=source.get("document_type"),
        language=source.get("language", "de"),
    )


def _to_search_result(r: DSFASearchResultResponse) -> DSFASearchResult:
    """Convert an API search result into the ingestion-layer ``DSFASearchResult``.

    The conversion is needed because ``generate_attribution_notice`` consumes
    ``DSFASearchResult`` objects. Missing document type and category are
    normalized to "".
    """
    return DSFASearchResult(
        chunk_id=r.chunk_id,
        content=r.content,
        score=r.score,
        source_code=r.source_code,
        source_name=r.source_name,
        attribution_text=r.attribution_text,
        license_code=r.license_code,
        license_url=r.license_url,
        attribution_required=r.attribution_required,
        source_url=r.source_url,
        document_type=r.document_type or "",
        category=r.category or "",
        section_title=r.section_title,
        page_number=r.page_number,
    )


# =============================================================================
# API Endpoints
# =============================================================================

@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
    query: str = Query(..., min_length=3, description="Search query"),
    source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
    document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
    categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    include_attribution: bool = Query(True, description="Include attribution in results"),
    # Not used directly. Resolving it still enforces get_store's 503
    # "Database not initialized" check for this endpoint.
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Search DSFA corpus with full attribution.

    Returns matching chunks with source/license information for compliance.
    ``licenses_used`` lists each distinct license code once, in the order it
    first appears in the ranked results.
    """
    query_embedding = await get_embedding(query)

    raw_results = await qdrant.search(
        query_embedding=query_embedding,
        source_codes=source_codes,
        document_types=document_types,
        categories=categories,
        limit=limit
    )

    results: List[DSFASearchResultResponse] = []
    # dict used as an insertion-ordered set so the response is deterministic.
    licenses_used: dict = {}

    for r in raw_results:
        license_code = r.get("license_code", "")
        license_info = LICENSE_REGISTRY.get(license_code, {})

        results.append(DSFASearchResultResponse(
            chunk_id=r.get("chunk_id", ""),
            content=r.get("content", ""),
            score=r.get("score", 0.0),
            source_code=r.get("source_code", ""),
            source_name=r.get("source_name", ""),
            attribution_text=r.get("attribution_text", ""),
            license_code=license_code,
            license_name=license_info.get("name", license_code),
            license_url=license_info.get("url"),
            attribution_required=r.get("attribution_required", True),
            source_url=r.get("source_url"),
            document_type=r.get("document_type"),
            category=r.get("category"),
            section_title=r.get("section_title"),
            page_number=r.get("page_number")
        ))
        licenses_used.setdefault(license_code, None)

    # Only build the attribution input when a notice was actually requested.
    attribution_notice = ""
    if include_attribution:
        attribution_notice = generate_attribution_notice(
            [_to_search_result(r) for r in results]
        )

    return DSFASearchResponse(
        query=query,
        results=results,
        total_results=len(results),
        licenses_used=list(licenses_used),
        attribution_notice=attribution_notice
    )


@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
    document_type: Optional[str] = Query(None, description="Filter by document type"),
    license_code: Optional[str] = Query(None, description="Filter by license"),
    store: DSFACorpusStore = Depends(get_store)
):
    """List all registered DSFA sources with license info.

    Empty or omitted filters match every source.
    """
    sources = await store.list_sources()
    return [
        _source_to_response(s)
        for s in sources
        if (not document_type or s.get("document_type") == document_type)
        and (not license_code or s.get("license_code") == license_code)
    ]


# Registered before "/sources/{source_code}" so "available" is not captured
# as a source code.
@router.get("/sources/available")
async def list_available_sources():
    """List all available source definitions (from DSFA_SOURCES constant)."""
    return [
        {
            "source_code": s["source_code"],
            "name": s["name"],
            "organization": s.get("organization"),
            "license_code": s["license_code"],
            "document_type": s.get("document_type")
        }
        for s in DSFA_SOURCES
    ]


@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get details for a specific source.

    Raises:
        HTTPException: 404 if no source with ``source_code`` is registered.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
    return _source_to_response(source)


@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Trigger ingestion for a specific source.

    Can provide document via URL or direct text. ``document_text`` takes
    precedence; the URL is only fetched when no text is given.

    Flow:
    1. Validate the input and ensure the Qdrant collection exists.
    2. Chunk the text and compute embeddings.
    3. Persist the document and its chunks, then index the chunks in Qdrant.

    Embeddings are computed before anything is written to the store, so an
    embedding failure does not leave an orphan document row.

    Raises:
        HTTPException: 404 if the source is unknown. 400 if no input is given,
            URL extraction yields nothing, or the text is shorter than 50
            characters. 500 if the embedding batch size does not match the
            chunk count.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )

    await qdrant.ensure_collection()

    text_content = request.document_text

    if request.document_url and not text_content:
        logger.info("Extracting text from URL: %s", request.document_url)
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )

    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")

    source_id = str(source["id"])
    chunks = chunk_document(text_content, source_code)

    embeddings = []
    if chunks:
        chunk_texts = [chunk["content"] for chunk in chunks]
        logger.info("Generating embeddings for %d chunks...", len(chunk_texts))
        embeddings = await get_embeddings_batch(chunk_texts)
        # zip() would silently drop chunks on a short batch while the store
        # still recorded len(chunks); fail loudly instead.
        if len(embeddings) != len(chunks):
            raise HTTPException(
                status_code=500,
                detail=(
                    f"Embedding count mismatch: got {len(embeddings)} "
                    f"embeddings for {len(chunks)} chunks"
                )
            )

    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=source_id,
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )

    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )

    chunk_records = []
    for i, chunk in enumerate(chunks):
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=source_id,
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )

        # Payload stored alongside each vector; carries attribution fields so
        # search results can be attributed without a database lookup.
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": source_id,
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source.get("attribution_text", ""),
            "license_code": source.get("license_code", ""),
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })

    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
    await store.update_document_indexed(document_id, len(chunks))

    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )


@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get single chunk with full source attribution.

    Raises:
        HTTPException: 404 if the chunk does not exist.
    """
    chunk = await store.get_chunk_with_attribution(chunk_id)
    if not chunk:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")

    license_code = chunk.get("license_code", "")
    license_info = LICENSE_REGISTRY.get(license_code, {})

    return DSFAChunkResponse(
        chunk_id=str(chunk["chunk_id"]),
        content=chunk.get("content", ""),
        section_title=chunk.get("section_title"),
        page_number=chunk.get("page_number"),
        category=chunk.get("category"),
        document_id=str(chunk.get("document_id", "")),
        document_title=chunk.get("document_title"),
        source_id=str(chunk.get("source_id", "")),
        source_code=chunk.get("source_code", ""),
        source_name=chunk.get("source_name", ""),
        attribution_text=chunk.get("attribution_text", ""),
        license_code=license_code,
        license_name=license_info.get("name", license_code),
        license_url=license_info.get("url"),
        attribution_required=chunk.get("attribution_required", True),
        source_url=chunk.get("source_url"),
        document_type=chunk.get("document_type")
    )


@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Get corpus statistics for dashboard.

    Document and chunk totals are summed from the per-source stats. A NULL or
    missing count is treated as 0.
    """
    source_stats = await store.get_source_stats()

    total_docs = 0
    total_chunks = 0
    stats_response = []

    for s in source_stats:
        # "or 0" also covers keys that are present but None.
        doc_count = s.get("document_count", 0) or 0
        chunk_count = s.get("chunk_count", 0) or 0
        total_docs += doc_count
        total_chunks += chunk_count

        last_indexed = s.get("last_indexed_at")

        stats_response.append(DSFASourceStatsResponse(
            source_id=str(s.get("source_id", "")),
            source_code=s.get("source_code", ""),
            name=s.get("name", ""),
            organization=s.get("organization"),
            license_code=s.get("license_code", ""),
            document_type=s.get("document_type"),
            document_count=doc_count,
            chunk_count=chunk_count,
            last_indexed_at=last_indexed.isoformat() if last_indexed else None
        ))

    qdrant_stats = await qdrant.get_stats()

    return DSFACorpusStatsResponse(
        sources=stats_response,
        total_sources=len(source_stats),
        total_documents=total_docs,
        total_chunks=total_chunks,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=qdrant_stats.get("points_count", 0),
        qdrant_status=qdrant_stats.get("status", "unknown")
    )


@router.get("/licenses")
async def list_licenses():
    """List all supported licenses with their terms.

    Permission flags that are missing from a registry entry default to True.
    """
    return [
        LicenseInfo(
            code=code,
            name=info.get("name", code),
            url=info.get("url"),
            attribution_required=info.get("attribution_required", True),
            modification_allowed=info.get("modification_allowed", True),
            commercial_use=info.get("commercial_use", True)
        )
        for code, info in LICENSE_REGISTRY.items()
    ]


@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Creates Qdrant collection
    - Registers all predefined sources

    Registration is best-effort: a failing source is logged with its
    traceback and skipped, and the remaining sources are still attempted.
    """
    qdrant_ok = await qdrant.ensure_collection()

    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
        except Exception:
            # Boundary handler: log via the module logger (not print), and
            # use .get() so a malformed entry cannot raise again here.
            logger.exception(
                "Error registering source %s", source.get("source_code")
            )
        else:
            registered += 1

    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }