Files
breakpilot-core/rag-service/qdrant_client_wrapper.py
Benjamin Admin be45adb975
All checks were successful
CI / go-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-consent (push) Successful in 33s
CI / test-python-voice (push) Successful in 36s
CI / deploy-hetzner (push) Successful in 38s
CI / python-lint (push) Has been skipped
CI / test-bqas (push) Successful in 31s
fix(rag): Auto-create Qdrant collection on first index
Collections may not exist if init_collections() failed at startup
(e.g. Qdrant not ready). Now index_documents() ensures the
collection exists before upserting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 21:02:05 +01:00

260 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import uuid
from typing import Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from qdrant_client.http.exceptions import UnexpectedResponse
from config import settings
logger = logging.getLogger("rag-service.qdrant")

# ------------------------------------------------------------------
# Default collections, each mapped to the dimension of its vectors
# ------------------------------------------------------------------

# Lehrer / EH collections (OpenAI-style 1536-dim embeddings)
_LEHRER_COLLECTIONS: dict[str, int] = dict.fromkeys(
    (
        "bp_eh",
        "bp_nibis_eh",
        "bp_nibis",
        "bp_vocab",
    ),
    1536,
)

# Compliance / Legal collections (1024-dim embeddings, e.g. from a smaller model)
_COMPLIANCE_COLLECTIONS: dict[str, int] = dict.fromkeys(
    (
        "bp_legal_templates",
        "bp_compliance_gdpr",
        "bp_compliance_schulrecht",
        "bp_compliance_datenschutz",
        "bp_compliance_gesetze",
        "bp_compliance_ce",
        "bp_dsfa_templates",
        "bp_dsfa_risks",
    ),
    1024,
)

# Every default collection: the Lehrer group first, then the compliance group.
ALL_DEFAULT_COLLECTIONS: dict[str, int] = (
    _LEHRER_COLLECTIONS | _COMPLIANCE_COLLECTIONS
)
class QdrantClientWrapper:
    """Thin wrapper around QdrantClient with BreakPilot-specific helpers.

    The underlying ``QdrantClient`` is built lazily on first access of
    ``client``. The public methods are coroutines, but they call the
    synchronous client directly, so each call runs to completion before
    control returns to the event loop.
    """

    def __init__(self) -> None:
        # Built on first use, so constructing the wrapper does not
        # construct a QdrantClient.
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Return the shared client, constructing it on first access."""
        if self._client is None:
            self._client = QdrantClient(url=settings.QDRANT_URL, timeout=30)
            logger.info("Connected to Qdrant at %s", settings.QDRANT_URL)
        return self._client

    @staticmethod
    def _build_filter(conditions: dict[str, Any]) -> qmodels.Filter:
        """Translate a ``{field: value}`` dict into a Qdrant ``must`` filter.

        A list value matches any of its elements (``MatchAny``); every
        other value has to match exactly (``MatchValue``).
        """
        must_conditions = [
            qmodels.FieldCondition(
                key=key,
                match=(
                    qmodels.MatchAny(any=value)
                    if isinstance(value, list)
                    else qmodels.MatchValue(value=value)
                ),
            )
            for key, value in conditions.items()
        ]
        return qmodels.Filter(must=must_conditions)

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------
    async def init_collections(self) -> None:
        """Create all default collections if they do not already exist."""
        for name, dim in ALL_DEFAULT_COLLECTIONS.items():
            await self.create_collection(name, dim)
        logger.info(
            "All default collections initialised (%d total)",
            len(ALL_DEFAULT_COLLECTIONS),
        )

    async def create_collection(
        self,
        name: str,
        vector_size: int,
        distance: qmodels.Distance = qmodels.Distance.COSINE,
    ) -> bool:
        """Create a single collection unless it already exists.

        Args:
            name: Collection name.
            vector_size: Dimension of the vectors stored in the collection.
            distance: Similarity metric used by the collection.

        Returns:
            True if the collection was newly created, False if it was
            already present.

        Raises:
            Exception: Whatever the client raises when the existence check
                or the creation request fails; creation failures are
                logged before being re-raised.
        """
        # Ask explicitly instead of treating every get_collection() error
        # as "missing", which also hid connection and auth failures.
        if self.client.collection_exists(name):
            logger.debug("Collection '%s' already exists skipping", name)
            return False

        try:
            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=distance,
                ),
                optimizers_config=qmodels.OptimizersConfigDiff(
                    indexing_threshold=20_000,
                ),
            )
        except Exception as exc:
            logger.error("Failed to create collection '%s': %s", name, exc)
            raise

        logger.info(
            "Created collection '%s' (dims=%d, distance=%s)",
            name,
            vector_size,
            distance.value,
        )
        return True

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------
    async def ensure_collection(self, name: str, vector_size: int = 1024) -> None:
        """Create collection if it doesn't exist."""
        # create_collection() already performs the existence check, so a
        # second lookup here would only duplicate the round trip.
        await self.create_collection(name, vector_size)

    async def index_documents(
        self,
        collection: str,
        vectors: list[list[float]],
        payloads: list[dict[str, Any]],
        ids: Optional[list[str]] = None,
    ) -> int:
        """Batch-upsert vectors + payloads.

        The collection is created on demand, sized after the first vector.
        Empty input returns 0 without contacting Qdrant, because there is
        no vector to derive a dimension from.

        Args:
            collection: Target collection name.
            vectors: One embedding per document.
            payloads: One payload dict per vector, in the same order.
            ids: Optional point ids, one per vector; random UUID4 strings
                are generated when omitted.

        Returns:
            Number of points upserted.

        Raises:
            ValueError: If ``payloads`` or ``ids`` differ in length from
                ``vectors``.
        """
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors ({len(vectors)}) and payloads ({len(payloads)}) must have equal length"
            )
        # zip() below would otherwise silently drop the surplus entries.
        if ids is not None and len(ids) != len(vectors):
            raise ValueError(
                f"ids ({len(ids)}) and vectors ({len(vectors)}) must have equal length"
            )
        if not vectors:
            return 0

        # Auto-create collection if missing
        await self.ensure_collection(collection, len(vectors[0]))

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        points = [
            qmodels.PointStruct(id=pid, vector=vec, payload=pay)
            for pid, vec, pay in zip(ids, vectors, payloads)
        ]

        batch_size = 100
        total_upserted = 0
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            self.client.upsert(collection_name=collection, points=batch, wait=True)
            total_upserted += len(batch)

        logger.info(
            "Upserted %d points into '%s'", total_upserted, collection
        )
        return total_upserted

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------
    async def search(
        self,
        collection: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> list[dict[str, Any]]:
        """Semantic search.

        Args:
            collection: Collection to query.
            query_vector: Embedding of the query.
            limit: Maximum number of hits to return.
            filters: Optional ``{field: value}`` restrictions; a list value
                matches any of its elements. None or an empty dict means
                no filtering.
            score_threshold: Optional score cut-off passed to Qdrant.

        Returns:
            A list of ``{"id", "score", "payload"}`` dicts, one per hit.
        """
        qdrant_filter = self._build_filter(filters) if filters else None
        results = self.client.query_points(
            collection_name=collection,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
            score_threshold=score_threshold,
            with_payload=True,
        )
        return [
            {
                "id": str(hit.id),
                "score": hit.score,
                "payload": hit.payload or {},
            }
            for hit in results.points
        ]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------
    async def delete_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> bool:
        """Delete all points matching the given filter dict.

        Args:
            collection: Collection to delete from.
            filter_conditions: ``{field: value}`` restrictions; a list
                value matches any of its elements. Must not be empty.

        Returns:
            True once the delete request has completed.

        Raises:
            ValueError: If ``filter_conditions`` is empty.
        """
        # A filter without conditions restricts nothing, so it would
        # select every point in the collection for deletion.
        if not filter_conditions:
            raise ValueError(
                "filter_conditions must not be empty; "
                f"refusing to delete every point in '{collection}'"
            )

        self.client.delete(
            collection_name=collection,
            points_selector=qmodels.FilterSelector(
                filter=self._build_filter(filter_conditions)
            ),
            wait=True,
        )
        logger.info("Deleted points from '%s' with filter %s", collection, filter_conditions)
        return True

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------
    async def get_collection_info(self, collection: str) -> dict[str, Any]:
        """Return basic stats for a collection.

        Returns:
            A dict with ``name``, ``vectors_count``, ``points_count``,
            ``status`` and ``vector_size``. ``vectors_count`` and
            ``vector_size`` are None when the client response does not
            provide them.

        Raises:
            Exception: Whatever the client raises for the lookup; the
                failure is logged before being re-raised.
        """
        try:
            info = self.client.get_collection(collection)
            return {
                "name": collection,
                # Deprecated upstream and not present on every
                # qdrant-client version, so read it defensively.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown",
                # Only a single unnamed vector config exposes ``size``.
                "vector_size": getattr(info.config.params.vectors, "size", None),
            }
        except Exception as exc:
            logger.error("Failed to get info for '%s': %s", collection, exc)
            raise
# Module-level singleton instance. Instantiating the wrapper only sets
# `_client` to None; the QdrantClient is built on first access of `.client`.
qdrant_wrapper = QdrantClientWrapper()