import logging
import uuid
from typing import Any, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from qdrant_client.http.exceptions import UnexpectedResponse

from config import settings

logger = logging.getLogger("rag-service.qdrant")

# ------------------------------------------------------------------
# Default collections with their vector dimensions
# ------------------------------------------------------------------

# Lehrer / EH collections (OpenAI-style 1536-dim embeddings)
_LEHRER_COLLECTIONS = {
    "bp_eh": 1536,
    "bp_nibis_eh": 1536,
    "bp_nibis": 1536,
    "bp_vocab": 1536,
}

# Compliance / Legal collections (1024-dim embeddings, e.g. from a smaller model)
_COMPLIANCE_COLLECTIONS = {
    "bp_legal_templates": 1024,
    "bp_compliance_gdpr": 1024,
    "bp_compliance_schulrecht": 1024,
    "bp_compliance_datenschutz": 1024,
    "bp_dsfa_templates": 1024,
    "bp_dsfa_risks": 1024,
}

# Collection name -> vector dimension for everything created at start-up.
ALL_DEFAULT_COLLECTIONS: dict[str, int] = {
    **_LEHRER_COLLECTIONS,
    **_COMPLIANCE_COLLECTIONS,
}


class QdrantClientWrapper:
    """Thin wrapper around QdrantClient with BreakPilot-specific helpers.

    The underlying client is created lazily on first access of ``client``.

    NOTE: the public methods are declared ``async`` but call the synchronous
    ``QdrantClient`` without awaiting anything, so each call occupies the
    event loop until the request returns.
    """

    def __init__(self) -> None:
        # Created on first use so that importing this module performs no I/O.
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Return the shared QdrantClient, creating it on first access."""
        if self._client is None:
            self._client = QdrantClient(url=settings.QDRANT_URL, timeout=30)
            logger.info("Connected to Qdrant at %s", settings.QDRANT_URL)
        return self._client

    @staticmethod
    def _build_filter(conditions: dict[str, Any]) -> qmodels.Filter:
        """Translate a ``{field: value}`` dict into a Qdrant ``must`` filter.

        A list value becomes a ``MatchAny`` condition (field equals any of the
        listed values); any other value becomes a ``MatchValue`` condition.
        """
        must_conditions = [
            qmodels.FieldCondition(
                key=key,
                match=(
                    qmodels.MatchAny(any=value)
                    if isinstance(value, list)
                    else qmodels.MatchValue(value=value)
                ),
            )
            for key, value in conditions.items()
        ]
        return qmodels.Filter(must=must_conditions)

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    async def init_collections(self) -> None:
        """Create all default collections if they do not already exist."""
        for name, dim in ALL_DEFAULT_COLLECTIONS.items():
            await self.create_collection(name, dim)
        logger.info(
            "All default collections initialised (%d total)",
            len(ALL_DEFAULT_COLLECTIONS),
        )

    async def create_collection(
        self,
        name: str,
        vector_size: int,
        distance: qmodels.Distance = qmodels.Distance.COSINE,
    ) -> bool:
        """Create a single collection. Returns True if newly created.

        Args:
            name: Collection name.
            vector_size: Dimension of the vectors stored in the collection.
            distance: Distance metric used for similarity.

        Returns:
            ``False`` if the collection already exists, ``True`` if it was
            created by this call.

        Raises:
            UnexpectedResponse: If the existence check fails with anything
                other than HTTP 404.
            Exception: Any error raised while creating the collection is
                logged and re-raised.
        """
        try:
            self.client.get_collection(name)
        except UnexpectedResponse as exc:
            # Only a 404 means "collection is missing". Any other status is a
            # real failure and must not be mistaken for a missing collection.
            if exc.status_code != 404:
                logger.error("Failed to look up collection '%s': %s", name, exc)
                raise
        else:
            logger.debug("Collection '%s' already exists – skipping", name)
            return False

        try:
            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=distance,
                ),
                optimizers_config=qmodels.OptimizersConfigDiff(
                    indexing_threshold=20_000,
                ),
            )
        except Exception as exc:
            logger.error("Failed to create collection '%s': %s", name, exc)
            raise

        logger.info(
            "Created collection '%s' (dims=%d, distance=%s)",
            name,
            vector_size,
            distance.value,
        )
        return True

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    async def index_documents(
        self,
        collection: str,
        vectors: list[list[float]],
        payloads: list[dict[str, Any]],
        ids: Optional[list[str]] = None,
        *,
        batch_size: int = 100,
    ) -> int:
        """Batch-upsert vectors + payloads. Returns number of points upserted.

        Args:
            collection: Target collection name.
            vectors: One embedding per point.
            payloads: One payload dict per point, aligned with ``vectors``.
            ids: Optional point ids aligned with ``vectors``. Random UUID4
                strings are generated when omitted.
            batch_size: Maximum number of points sent per upsert request.

        Raises:
            ValueError: If ``vectors``, ``payloads`` and (when given) ``ids``
                differ in length, or if ``batch_size`` is not positive.
        """
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors ({len(vectors)}) and payloads ({len(payloads)}) must have equal length"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        elif len(ids) != len(vectors):
            # zip() below would otherwise drop the surplus silently.
            raise ValueError(
                f"ids ({len(ids)}) and vectors ({len(vectors)}) must have equal length"
            )

        points = [
            qmodels.PointStruct(id=pid, vector=vec, payload=pay)
            for pid, vec, pay in zip(ids, vectors, payloads)
        ]

        total_upserted = 0
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            self.client.upsert(collection_name=collection, points=batch, wait=True)
            total_upserted += len(batch)

        logger.info(
            "Upserted %d points into '%s'", total_upserted, collection
        )
        return total_upserted

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        collection: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> list[dict[str, Any]]:
        """Semantic search. Returns list of {id, score, payload}.

        Args:
            collection: Collection to search.
            query_vector: Query embedding.
            limit: Maximum number of hits to return.
            filters: Optional ``{field: value}`` conditions, all of which must
                match. A list value matches any of its elements. ``None`` or
                an empty dict applies no filter.
            score_threshold: Optional score threshold passed through to Qdrant.

        Returns:
            One dict per hit with the id as a string, the score, and the
            payload (an empty dict when the hit has none).
        """
        qdrant_filter = self._build_filter(filters) if filters else None

        results = self.client.search(
            collection_name=collection,
            query_vector=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
            score_threshold=score_threshold,
        )

        return [
            {
                "id": str(hit.id),
                "score": hit.score,
                "payload": hit.payload or {},
            }
            for hit in results
        ]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> bool:
        """Delete all points matching the given filter dict.

        Args:
            collection: Collection to delete from.
            filter_conditions: ``{field: value}`` conditions, all of which
                must match. A list value matches any of its elements.

        Returns:
            ``True`` once the delete request has completed.

        Raises:
            ValueError: If ``filter_conditions`` is empty.
        """
        if not filter_conditions:
            # Safety guard: a filter without conditions would not narrow the
            # selection at all, so refuse instead of risking a mass delete.
            raise ValueError("filter_conditions must not be empty")

        self.client.delete(
            collection_name=collection,
            points_selector=qmodels.FilterSelector(
                filter=self._build_filter(filter_conditions)
            ),
            wait=True,
        )
        logger.info("Deleted points from '%s' with filter %s", collection, filter_conditions)
        return True

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    async def get_collection_info(self, collection: str) -> dict[str, Any]:
        """Return basic stats for a collection.

        Returns:
            A dict with ``name``, ``vectors_count``, ``points_count``,
            ``status`` and ``vector_size``. ``vectors_count`` is ``None`` when
            the client model does not provide it; ``vector_size`` is ``None``
            when the vectors config has no ``size`` attribute.

        Raises:
            Exception: Any error from the lookup is logged and re-raised.
        """
        try:
            info = self.client.get_collection(collection)
            vectors_config = info.config.params.vectors
            return {
                "name": collection,
                # Read defensively so a client model without this field does
                # not break the whole stats call.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown",
                "vector_size": getattr(vectors_config, "size", None),
            }
        except Exception as exc:
            logger.error("Failed to get info for '%s': %s", collection, exc)
            raise


# Singleton
qdrant_wrapper = QdrantClientWrapper()