import logging
from typing import Any, Optional

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field

from api.auth import optional_jwt_auth
from embedding_client import embedding_client
from qdrant_client_wrapper import qdrant_wrapper

logger = logging.getLogger("rag-service.api.search")
router = APIRouter(prefix="/api/v1")


# ---- Request / Response models --------------------------------------------

# Body of POST /search. `limit` is validated to 1..100; `filters` and
# `score_threshold` are forwarded unchanged to the vector store wrapper.
class SemanticSearchRequest(BaseModel):
    query: str
    collection: str = "bp_eh"
    limit: int = Field(default=10, ge=1, le=100)
    filters: Optional[dict[str, Any]] = None
    score_threshold: Optional[float] = None


# Body of POST /search/hybrid.
class HybridSearchRequest(BaseModel):
    query: str
    collection: str = "bp_eh"
    limit: int = Field(default=10, ge=1, le=100)
    filters: Optional[dict[str, Any]] = None
    score_threshold: Optional[float] = None
    # Maximum amount added to a hit's score when every query term occurs in
    # its chunk_text (scaled by the fraction of terms that occur).
    keyword_boost: float = Field(default=0.3, ge=0.0, le=1.0)
    rerank: bool = True
    # Number of leading candidates whose text is sent to the re-ranker.
    rerank_top_k: int = Field(default=10, ge=1, le=50)


# Body of POST /rerank.
class RerankRequest(BaseModel):
    query: str
    documents: list[str]
    top_k: int = Field(default=10, ge=1, le=100)


# One search hit as returned to the client.
class SearchResult(BaseModel):
    id: str
    score: float
    payload: dict[str, Any] = Field(default_factory=dict)


# Response envelope shared by both search endpoints.
class SearchResponse(BaseModel):
    results: list[SearchResult]
    count: int
    query: str
    collection: str


# ---- Internal helpers -----------------------------------------------------

async def _embed_query(query: str) -> Any:
    """Embed *query*; raise HTTP 502 if the embedding client fails."""
    try:
        return await embedding_client.generate_single_embedding(query)
    except Exception as exc:
        logger.error("Failed to embed query: %s", exc)
        raise HTTPException(
            status_code=502, detail=f"Embedding service error: {exc}"
        ) from exc


async def _vector_search(
    collection: str,
    query_vector: Any,
    limit: int,
    filters: Optional[dict[str, Any]],
    score_threshold: Optional[float],
) -> list[dict[str, Any]]:
    """Run the vector search; raise HTTP 500 if the Qdrant wrapper fails."""
    try:
        return await qdrant_wrapper.search(
            collection=collection,
            query_vector=query_vector,
            limit=limit,
            filters=filters,
            score_threshold=score_threshold,
        )
    except Exception as exc:
        logger.error("Qdrant search failed: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Vector search failed: {exc}"
        ) from exc


def _chunk_text(result: dict[str, Any]) -> str:
    """Return the hit's ``payload["chunk_text"]``, or ``""`` if absent or not a string."""
    payload = result.get("payload")
    text = payload.get("chunk_text") if isinstance(payload, dict) else None
    return text if isinstance(text, str) else ""


def _apply_keyword_boost(
    results: list[dict[str, Any]], query: str, keyword_boost: float
) -> None:
    """Add a keyword bonus to each hit's score and sort *results* in place, best first.

    The bonus is ``keyword_boost`` scaled by the fraction of whitespace-separated,
    lower-cased query terms that occur as substrings of the hit's chunk text.
    """
    terms = query.lower().split()
    # max(..., 1) keeps the division defined for an empty / whitespace-only query.
    denominator = max(len(terms), 1)
    for result in results:
        text = _chunk_text(result).lower()
        hits = sum(1 for term in terms if term in text)
        result["score"] = result["score"] + (hits / denominator) * keyword_boost
    results.sort(key=lambda hit: hit["score"], reverse=True)


def _merge_reranked(
    ranked: list[dict[str, Any]],
    reranked: list[dict[str, Any]],
    window: int,
) -> list[dict[str, Any]]:
    """Combine re-ranker output with the candidates it did not return.

    *reranked* items refer to positions in ``ranked[:window]`` through their
    ``"index"`` key. Items with a missing, non-integer, out-of-window or
    repeated index are ignored. Every candidate the re-ranker did not return
    is appended afterwards in its existing order, so the caller can still
    fill the requested limit when the re-rank window is smaller than it.
    """
    used: set[int] = set()
    merged: list[dict[str, Any]] = []
    for item in reranked:
        idx = item.get("index")
        if not isinstance(idx, int) or not 0 <= idx < window or idx in used:
            continue
        used.add(idx)
        entry = ranked[idx].copy()
        entry["score"] = item.get("score", entry["score"])
        merged.append(entry)
    # The appended tail keeps its vector+keyword scores; they are not on the
    # same scale as the re-ranker scores above, only the ordering is meaningful.
    merged.extend(hit for pos, hit in enumerate(ranked) if pos not in used)
    return merged


def _build_response(
    hits: list[dict[str, Any]], query: str, collection: str
) -> SearchResponse:
    """Wrap raw hit dicts in the public response model."""
    return SearchResponse(
        results=[SearchResult(**hit) for hit in hits],
        count=len(hits),
        query=query,
        collection=collection,
    )


# ---- Endpoints ------------------------------------------------------------

# Errors: HTTP 502 when embedding the query fails, HTTP 500 when the vector
# search fails.
@router.post("/search", response_model=SearchResponse)
async def semantic_search(
    body: SemanticSearchRequest, request: Request
) -> SearchResponse:
    """
    Pure semantic (vector) search.

    Embeds the query, then searches Qdrant for nearest neighbours.
    """
    optional_jwt_auth(request)

    query_vector = await _embed_query(body.query)
    results = await _vector_search(
        collection=body.collection,
        query_vector=query_vector,
        limit=body.limit,
        filters=body.filters,
        score_threshold=body.score_threshold,
    )
    return _build_response(results, body.query, body.collection)


# Errors: HTTP 502 when embedding the query fails, HTTP 500 when the vector
# search fails. A re-ranking failure is logged and the vector+keyword order
# is returned instead.
@router.post("/search/hybrid", response_model=SearchResponse)
async def hybrid_search(
    body: HybridSearchRequest, request: Request
) -> SearchResponse:
    """
    Hybrid search: vector search + keyword filtering + optional re-ranking.

    1. Embed query and do vector search with a higher initial limit
    2. Apply keyword matching on chunk_text to boost relevant results
    3. Optionally re-rank the top results via the embedding service
    """
    optional_jwt_auth(request)

    # --- Step 1: Vector search (fetch more than needed for re-ranking) ---
    fetch_limit = max(body.limit * 3, 30)
    query_vector = await _embed_query(body.query)
    candidates = await _vector_search(
        collection=body.collection,
        query_vector=query_vector,
        limit=fetch_limit,
        filters=body.filters,
        score_threshold=body.score_threshold,
    )
    if not candidates:
        return _build_response([], body.query, body.collection)

    # --- Step 2: Keyword boost (also sorts by boosted score) ---
    _apply_keyword_boost(candidates, body.query, body.keyword_boost)

    # --- Step 3: Optional re-ranking ---
    if body.rerank and len(candidates) > 1:
        documents = [_chunk_text(hit) for hit in candidates[: body.rerank_top_k]]
        try:
            reranked = await embedding_client.rerank_documents(
                query=body.query,
                documents=documents,
                top_k=body.limit,
            )
            # Only reassigned on success, so any failure leaves the
            # vector+keyword ordering untouched.
            candidates = _merge_reranked(candidates, reranked, len(documents))
        except Exception as exc:
            logger.warning(
                "Re-ranking failed, falling back to vector+keyword scores: %s", exc
            )

    # Trim to requested limit
    return _build_response(candidates[: body.limit], body.query, body.collection)


# Errors: HTTP 502 when the embedding service call fails.
@router.post("/rerank")
async def rerank(body: RerankRequest, request: Request):
    """
    Standalone re-ranking endpoint.

    Sends query + documents to the embedding service for re-ranking.
    """
    optional_jwt_auth(request)

    if not body.documents:
        # Same keys as the non-empty response so clients can parse both alike.
        return {"results": [], "count": 0, "query": body.query}

    try:
        results = await embedding_client.rerank_documents(
            query=body.query,
            documents=body.documents,
            top_k=body.top_k,
        )
        return {"results": results, "count": len(results), "query": body.query}
    except Exception as exc:
        logger.error("Re-ranking failed: %s", exc)
        raise HTTPException(
            status_code=502, detail=f"Re-ranking failed: {exc}"
        ) from exc