- Add QDRANT_API_KEY to config.py (empty string = no auth) - Pass api_key to QdrantClient constructor (None when empty) - Add QDRANT_API_KEY to coolify compose and env example Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
264 lines
8.8 KiB
Python
264 lines
8.8 KiB
Python
import logging
|
||
import uuid
|
||
from typing import Any, Optional
|
||
|
||
from qdrant_client import QdrantClient
|
||
from qdrant_client.http import models as qmodels
|
||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||
|
||
from config import settings
|
||
|
||
logger = logging.getLogger("rag-service.qdrant")

# ------------------------------------------------------------------
# Default collections, keyed by name, mapped to their vector dimension
# ------------------------------------------------------------------

# Lehrer / EH collections (OpenAI-style 1536-dim embeddings)
_LEHRER_COLLECTIONS = dict.fromkeys(
    (
        "bp_eh",
        "bp_nibis_eh",
        "bp_nibis",
        "bp_vocab",
    ),
    1536,
)

# Compliance / Legal collections (1024-dim embeddings)
_COMPLIANCE_COLLECTIONS = dict.fromkeys(
    (
        "bp_legal_templates",
        "bp_compliance_gdpr",
        "bp_compliance_schulrecht",
        "bp_compliance_datenschutz",
        "bp_compliance_gesetze",
        "bp_compliance_ce",
        "bp_dsfa_templates",
        "bp_dsfa_risks",
    ),
    1024,
)

# Union of both groups; the Lehrer entries keep their place ahead of the
# compliance entries in iteration order.
ALL_DEFAULT_COLLECTIONS: dict[str, int] = (
    _LEHRER_COLLECTIONS | _COMPLIANCE_COLLECTIONS
)
|
||
|
||
|
||
class QdrantClientWrapper:
    """Thin wrapper around QdrantClient with BreakPilot-specific helpers.

    The underlying ``QdrantClient`` is constructed lazily, on first access to
    the ``client`` property.

    NOTE(review): every public method is declared ``async`` but calls the
    synchronous ``QdrantClient`` API directly, so a call does not yield until
    Qdrant has answered. Confirm with the callers that this is acceptable.
    """

    def __init__(self) -> None:
        # Filled in by the ``client`` property on first use.
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Return the shared QdrantClient, constructing it on first access.

        An empty ``settings.QDRANT_API_KEY`` is forwarded as ``None``.
        """
        if self._client is None:
            self._client = QdrantClient(
                url=settings.QDRANT_URL,
                api_key=settings.QDRANT_API_KEY or None,
                timeout=30,
            )
            logger.info("Connected to Qdrant at %s", settings.QDRANT_URL)
        return self._client

    @staticmethod
    def _build_filter(conditions: dict[str, Any]) -> qmodels.Filter:
        """Translate a ``{field: value}`` dict into a Qdrant ``must`` filter.

        A list value becomes a ``MatchAny`` condition; any other value becomes
        a ``MatchValue`` condition. All conditions are combined under ``must``.
        """
        must_conditions = [
            qmodels.FieldCondition(
                key=key,
                match=(
                    qmodels.MatchAny(any=value)
                    if isinstance(value, list)
                    else qmodels.MatchValue(value=value)
                ),
            )
            for key, value in conditions.items()
        ]
        return qmodels.Filter(must=must_conditions)

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    async def init_collections(self) -> None:
        """Create all default collections if they do not already exist."""
        for name, dim in ALL_DEFAULT_COLLECTIONS.items():
            await self.create_collection(name, dim)
        logger.info(
            "All default collections initialised (%d total)",
            len(ALL_DEFAULT_COLLECTIONS),
        )

    async def create_collection(
        self,
        name: str,
        vector_size: int,
        distance: qmodels.Distance = qmodels.Distance.COSINE,
    ) -> bool:
        """Create a single collection unless it already exists.

        Args:
            name: Collection name.
            vector_size: Dimension of the vectors stored in the collection.
            distance: Distance metric for the collection.

        Returns:
            True if the collection was newly created, False if it already
            existed.

        Raises:
            Exception: Whatever the client raises when the existence check or
                the creation fails; the error is logged and re-raised.
        """
        try:
            # Ask for existence explicitly instead of treating *any* failure
            # of a probe request (auth, network, response parsing) as
            # "collection is missing".
            if self.client.collection_exists(name):
                logger.debug("Collection '%s' already exists – skipping", name)
                return False

            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=distance,
                ),
                optimizers_config=qmodels.OptimizersConfigDiff(
                    indexing_threshold=20_000,
                ),
            )
            logger.info(
                "Created collection '%s' (dims=%d, distance=%s)",
                name,
                vector_size,
                distance.value,
            )
            return True
        except Exception as exc:
            logger.error("Failed to create collection '%s': %s", name, exc)
            raise

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    async def ensure_collection(self, name: str, vector_size: int = 1024) -> None:
        """Create the collection if it doesn't exist.

        Delegates to ``create_collection``, which performs the existence check
        itself and is a no-op for an existing collection.
        """
        await self.create_collection(name, vector_size)

    async def index_documents(
        self,
        collection: str,
        vectors: list[list[float]],
        payloads: list[dict[str, Any]],
        ids: Optional[list[str]] = None,
    ) -> int:
        """Batch-upsert vectors + payloads.

        The collection is created on demand, sized after the first vector.

        Args:
            collection: Target collection name.
            vectors: One embedding per point.
            payloads: One payload dict per point, aligned with ``vectors``.
            ids: Optional point ids aligned with ``vectors``; random UUID4
                strings are generated when omitted.

        Returns:
            Number of points upserted (0 for empty input).

        Raises:
            ValueError: If ``payloads`` or ``ids`` differ in length from
                ``vectors``.
        """
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors ({len(vectors)}) and payloads ({len(payloads)}) must have equal length"
            )
        # zip() below would otherwise silently drop the surplus entries.
        if ids is not None and len(ids) != len(vectors):
            raise ValueError(
                f"ids ({len(ids)}) and vectors ({len(vectors)}) must have equal length"
            )

        # Nothing to write, and no vector to read the dimension from: do not
        # create a collection with a guessed size.
        if not vectors:
            return 0

        # Auto-create collection if missing
        await self.ensure_collection(collection, len(vectors[0]))

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        points = [
            qmodels.PointStruct(id=pid, vector=vec, payload=pay)
            for pid, vec, pay in zip(ids, vectors, payloads)
        ]

        batch_size = 100
        total_upserted = 0
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            self.client.upsert(collection_name=collection, points=batch, wait=True)
            total_upserted += len(batch)

        logger.info(
            "Upserted %d points into '%s'", total_upserted, collection
        )
        return total_upserted

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        collection: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> list[dict[str, Any]]:
        """Semantic search.

        Args:
            collection: Collection to query.
            query_vector: Query embedding.
            limit: Maximum number of hits requested.
            filters: Optional ``{field: value}`` restrictions; a list value
                matches any of its elements. ``None`` or an empty dict means
                no filtering.
            score_threshold: Optional score cut-off forwarded to Qdrant.

        Returns:
            A list of ``{"id", "score", "payload"}`` dicts, one per hit.
        """
        qdrant_filter = self._build_filter(filters) if filters else None

        results = self.client.query_points(
            collection_name=collection,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
            score_threshold=score_threshold,
            with_payload=True,
        )

        return [
            {
                "id": str(hit.id),
                "score": hit.score,
                "payload": hit.payload or {},
            }
            for hit in results.points
        ]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> bool:
        """Delete all points matching the given filter dict.

        Args:
            collection: Collection to delete from.
            filter_conditions: ``{field: value}`` restrictions; a list value
                matches any of its elements. Must not be empty.

        Returns:
            True once the delete request has completed.

        Raises:
            ValueError: If ``filter_conditions`` is empty.
        """
        # A filter without conditions does not restrict the selection, so an
        # empty dict would delete every point in the collection. Refuse it
        # rather than wipe the collection by accident.
        if not filter_conditions:
            raise ValueError("filter_conditions must not be empty")

        self.client.delete(
            collection_name=collection,
            points_selector=qmodels.FilterSelector(
                filter=self._build_filter(filter_conditions)
            ),
            wait=True,
        )
        logger.info("Deleted points from '%s' with filter %s", collection, filter_conditions)
        return True

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    async def get_collection_info(self, collection: str) -> dict[str, Any]:
        """Return basic stats for a collection.

        Returns:
            A dict with the keys ``name``, ``vectors_count``, ``points_count``,
            ``status`` and ``vector_size``. ``vector_size`` is ``None`` when
            the vectors config has no ``size`` attribute.

        Raises:
            Exception: Whatever the client raises; the error is logged and
                re-raised.
        """
        try:
            info = self.client.get_collection(collection)
            return {
                "name": collection,
                # NOTE(review): newer qdrant-client releases appear to have
                # dropped ``vectors_count`` from the collection info - confirm.
                # Reading it defensively keeps the key present either way.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value if info.status else "unknown",
                "vector_size": getattr(info.config.params.vectors, "size", None),
            }
        except Exception as exc:
            logger.error("Failed to get info for '%s': %s", collection, exc)
            raise
|
||
|
||
|
||
# Singleton
|
||
# Module-level shared instance. Constructing the wrapper only sets its
# internal client slot to None; the QdrantClient itself is built on the
# first access to `.client`.
qdrant_wrapper = QdrantClientWrapper()
|