Files
breakpilot-core/rag-service/qdrant_client_wrapper.py
Benjamin Admin 7c932c441f
All checks were successful
CI / go-lint (push) Has been skipped
CI / test-go-consent (push) Successful in 35s
CI / test-python-voice (push) Successful in 50s
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-bqas (push) Successful in 33s
CI / deploy-hetzner (push) Successful in 39s
feat(rag): Add bp_compliance_gesetze + bp_compliance_ce collections
Required for Verbraucherschutz + EU law ingestion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 20:41:26 +01:00

249 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import uuid
from typing import Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from qdrant_client.http.exceptions import UnexpectedResponse
from config import settings
logger = logging.getLogger("rag-service.qdrant")

# ------------------------------------------------------------------
# Default collections and the vector dimension each is created with
# ------------------------------------------------------------------

# Lehrer / EH collections: all share OpenAI-style 1536-dim embeddings.
_LEHRER_COLLECTIONS = dict.fromkeys(
    (
        "bp_eh",
        "bp_nibis_eh",
        "bp_nibis",
        "bp_vocab",
    ),
    1536,
)

# Compliance / legal collections: all share 1024-dim embeddings
# (e.g. produced by a smaller embedding model).
_COMPLIANCE_COLLECTIONS = dict.fromkeys(
    (
        "bp_legal_templates",
        "bp_compliance_gdpr",
        "bp_compliance_schulrecht",
        "bp_compliance_datenschutz",
        "bp_compliance_gesetze",
        "bp_compliance_ce",
        "bp_dsfa_templates",
        "bp_dsfa_risks",
    ),
    1024,
)

# Combined name -> dimension table; Lehrer entries come first, then the
# compliance entries, each group in the order listed above.
ALL_DEFAULT_COLLECTIONS: dict[str, int] = _LEHRER_COLLECTIONS | _COMPLIANCE_COLLECTIONS
class QdrantClientWrapper:
    """Thin wrapper around QdrantClient with BreakPilot-specific helpers.

    The underlying ``QdrantClient`` is created lazily on first access of the
    ``client`` property, using ``settings.QDRANT_URL`` and a 30 s timeout.

    NOTE(review): the ``async`` methods below call the synchronous
    ``QdrantClient`` directly, so every Qdrant round-trip blocks the event
    loop for its duration. Consider ``AsyncQdrantClient`` or
    ``asyncio.to_thread`` if this runs inside a busy async server.
    """

    def __init__(self) -> None:
        # Created lazily by the ``client`` property, not at construction time.
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Return the shared ``QdrantClient``, creating it on first access."""
        if self._client is None:
            self._client = QdrantClient(url=settings.QDRANT_URL, timeout=30)
            logger.info("Connected to Qdrant at %s", settings.QDRANT_URL)
        return self._client

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    async def init_collections(self) -> None:
        """Create every collection in ``ALL_DEFAULT_COLLECTIONS`` that is missing.

        Existing collections are left untouched. The first creation failure
        propagates and aborts the remaining ones.
        """
        for name, dim in ALL_DEFAULT_COLLECTIONS.items():
            await self.create_collection(name, dim)
        logger.info(
            "All default collections initialised (%d total)",
            len(ALL_DEFAULT_COLLECTIONS),
        )

    async def create_collection(
        self,
        name: str,
        vector_size: int,
        distance: qmodels.Distance = qmodels.Distance.COSINE,
    ) -> bool:
        """Create a single collection if it does not exist yet.

        Args:
            name: Collection name.
            vector_size: Dimension of the vectors stored in the collection.
            distance: Similarity metric (cosine by default).

        Returns:
            True if the collection was newly created, False if it already
            existed.

        Raises:
            Whatever the Qdrant client raises while checking for or creating
            the collection; creation failures are logged before re-raising.
        """
        # ``collection_exists`` only asks whether the name is known. The old
        # approach (``get_collection`` inside a catch-all ``except``) also
        # treated connection errors and response-parsing failures as
        # "missing", which then surfaced as a misleading create error, or as
        # a conflict when the collection did in fact exist.
        if self.client.collection_exists(name):
            logger.debug("Collection '%s' already exists, skipping", name)
            return False

        try:
            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=distance,
                ),
                # Threshold at which Qdrant starts building a vector index for
                # a segment; see the Qdrant optimizer docs for the exact unit.
                optimizers_config=qmodels.OptimizersConfigDiff(
                    indexing_threshold=20_000,
                ),
            )
        except Exception as exc:
            logger.error("Failed to create collection '%s': %s", name, exc)
            raise

        logger.info(
            "Created collection '%s' (dims=%d, distance=%s)",
            name,
            vector_size,
            distance.value,
        )
        return True

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    async def index_documents(
        self,
        collection: str,
        vectors: list[list[float]],
        payloads: list[dict[str, Any]],
        ids: Optional[list[str]] = None,
        *,
        batch_size: int = 100,
    ) -> int:
        """Batch-upsert vectors + payloads into ``collection``.

        Args:
            collection: Target collection name.
            vectors: One embedding per point.
            payloads: One payload dict per point, aligned with ``vectors``.
            ids: Optional point ids aligned with ``vectors``; random UUID4
                strings are generated when omitted.
            batch_size: Number of points sent per ``upsert`` call.

        Returns:
            Number of points upserted.

        Raises:
            ValueError: If ``vectors``, ``payloads`` and ``ids`` differ in
                length, or ``batch_size`` is not positive.
        """
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors ({len(vectors)}) and payloads ({len(payloads)}) must have equal length"
            )
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]
        elif len(ids) != len(vectors):
            # ``zip`` would otherwise silently drop the unmatched points.
            raise ValueError(
                f"ids ({len(ids)}) and vectors ({len(vectors)}) must have equal length"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        points = [
            qmodels.PointStruct(id=pid, vector=vec, payload=pay)
            for pid, vec, pay in zip(ids, vectors, payloads)
        ]

        total_upserted = 0
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            self.client.upsert(collection_name=collection, points=batch, wait=True)
            total_upserted += len(batch)

        logger.info("Upserted %d points into '%s'", total_upserted, collection)
        return total_upserted

    # ------------------------------------------------------------------
    # Filters
    # ------------------------------------------------------------------

    @staticmethod
    def _build_filter(conditions: dict[str, Any]) -> qmodels.Filter:
        """Translate a flat ``{payload_key: value}`` dict into a ``must`` filter.

        List/tuple values become ``MatchAny`` (the field equals any of the
        values); every other value becomes an exact ``MatchValue``.
        """
        must_conditions = []
        for key, value in conditions.items():
            if isinstance(value, (list, tuple)):
                match = qmodels.MatchAny(any=list(value))
            else:
                match = qmodels.MatchValue(value=value)
            must_conditions.append(qmodels.FieldCondition(key=key, match=match))
        return qmodels.Filter(must=must_conditions)

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        collection: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> list[dict[str, Any]]:
        """Run a semantic search against ``collection``.

        Args:
            collection: Collection to query.
            query_vector: Query embedding.
            limit: Maximum number of hits.
            filters: Optional ``{payload_key: value}`` constraints, all of
                which must hold; list/tuple values match any element. An empty
                or missing dict applies no filter.
            score_threshold: Optional minimum score passed through to Qdrant.

        Returns:
            A list of ``{"id": str, "score": ..., "payload": dict}`` dicts.
        """
        qdrant_filter = self._build_filter(filters) if filters else None

        results = self.client.query_points(
            collection_name=collection,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
            score_threshold=score_threshold,
            with_payload=True,
        )
        return [
            {
                "id": str(hit.id),
                "score": hit.score,
                "payload": hit.payload or {},
            }
            for hit in results.points
        ]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> bool:
        """Delete all points in ``collection`` matching ``filter_conditions``.

        Args:
            collection: Collection to delete from.
            filter_conditions: ``{payload_key: value}`` constraints, all of
                which must hold; list/tuple values match any element.

        Returns:
            True once the delete request has completed (``wait=True``).

        Raises:
            ValueError: If ``filter_conditions`` is empty. A filter with no
                conditions matches every point, so it would wipe the whole
                collection.
        """
        if not filter_conditions:
            raise ValueError(
                "filter_conditions must not be empty; an empty filter would "
                "delete every point in the collection"
            )

        self.client.delete(
            collection_name=collection,
            points_selector=qmodels.FilterSelector(
                filter=self._build_filter(filter_conditions)
            ),
            wait=True,
        )
        logger.info("Deleted points from '%s' with filter %s", collection, filter_conditions)
        return True

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    async def get_collection_info(self, collection: str) -> dict[str, Any]:
        """Return basic stats for a collection.

        Returns:
            A dict with ``name``, ``vectors_count``, ``points_count``,
            ``status`` (``"unknown"`` when missing) and ``vector_size``.
            ``vector_size`` is ``None`` when the vectors config has no single
            ``size`` (presumably a named-vectors mapping; verify).

        Raises:
            Whatever the Qdrant client raises; the error is logged first.
        """
        try:
            info = self.client.get_collection(collection)
            vectors_config = info.config.params.vectors
            return {
                "name": collection,
                # NOTE(review): ``vectors_count`` is deprecated in newer
                # qdrant-client releases and may be absent or None; getattr
                # keeps this working across versions.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown",
                "vector_size": (
                    vectors_config.size if hasattr(vectors_config, "size") else None
                ),
            }
        except Exception as exc:
            logger.error("Failed to get info for '%s': %s", collection, exc)
            raise


# Module-level singleton; the Qdrant connection itself is created lazily.
qdrant_wrapper = QdrantClientWrapper()