Files
breakpilot-core/rag-service/qdrant_client_wrapper.py
T
Benjamin Admin 8510af46eb feat(pipeline): MC Quality Overhaul — 74.5% → 92.8% accuracy, 5.3K → 13.6K MCs
Phase 0: Quality Audit script (Claude Sonnet, 1750 samples)
Phase 1: Object ontology expanded 31 → 74 tokens with descriptions + boundaries
Phase 2: 174K controls re-classified via Haiku (10 batches, $50)
  - Generic tokens removed (documentation, procedure, process)
  - L2 sub-topics added (108K + 64K controls)
  - Bad subtopics fixed (stakeholder_*, escalation fragments)
Phase 3: Re-clustering K=18704 (37K objects → 16.7K groups)
Phase 4: Direct MC generation from canonical tokens (gpre2_direct_mc.py)
Phase 5: Regulation-source split (gpre3, dry-run tested)

New features:
- Tenant-isolated document upload API (rag-service)
- BAuA crawler (Playwright, 131 PDFs downloaded)
- OSHA Technical Manual crawler (23 chapters)
- CE obligation extractor (6141 obligations from Qdrant)

RAG ingestion:
- 126 BAuA PDFs (TRBS/TRGS/ASR): 27,664 chunks
- OSHA Technical Manual: 7,241 chunks
- OSHA 1910 Subpart O (full): 745 chunks
- EuGH C-588/21 P: 216 chunks
- EU 2018/1725: 842 chunks

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-10 15:08:15 +02:00

332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import uuid
from typing import Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models as qmodels
from qdrant_client.http.exceptions import UnexpectedResponse
from config import settings
logger = logging.getLogger("rag-service.qdrant")

# ------------------------------------------------------------------
# Default collections and the vector dimension each one is created with
# ------------------------------------------------------------------

# Lehrer / EH collections: every entry uses 1536-dimensional vectors
# (OpenAI-style embedding size).
_LEHRER_COLLECTIONS = dict.fromkeys(
    (
        "bp_eh",
        "bp_nibis_eh",
        "bp_nibis",
        "bp_vocab",
    ),
    1536,
)

# Compliance / legal collections: every entry uses 1024-dimensional vectors.
_COMPLIANCE_COLLECTIONS = dict.fromkeys(
    (
        "bp_legal_templates",
        "bp_compliance_gdpr",
        "bp_compliance_schulrecht",
        "bp_compliance_datenschutz",
        "bp_compliance_gesetze",
        "bp_compliance_ce",
        "bp_dsfa_templates",
        "bp_dsfa_risks",
    ),
    1024,
)

# Combined name -> dimension mapping (Lehrer entries first, then compliance);
# this is the set of collections iterated by
# QdrantClientWrapper.init_collections().
ALL_DEFAULT_COLLECTIONS: dict[str, int] = _LEHRER_COLLECTIONS | _COMPLIANCE_COLLECTIONS
class QdrantClientWrapper:
    """Thin wrapper around QdrantClient with BreakPilot-specific helpers.

    The underlying ``QdrantClient`` is only instantiated on first access of
    :attr:`client`, so constructing the wrapper does no I/O.

    NOTE: the public methods are declared ``async`` but call the synchronous
    ``QdrantClient`` without awaiting anything, so each call blocks the running
    event loop until Qdrant answers. Moving to ``AsyncQdrantClient`` (or
    ``asyncio.to_thread``) would fix that without changing this interface.
    """

    def __init__(self) -> None:
        # Populated lazily by the ``client`` property.
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Return the shared QdrantClient, creating it on first access.

        Uses ``settings.QDRANT_URL`` and ``settings.QDRANT_API_KEY`` (an empty
        key is passed as ``None``) with ``timeout=30``.
        """
        if self._client is None:
            self._client = QdrantClient(
                url=settings.QDRANT_URL,
                api_key=settings.QDRANT_API_KEY or None,
                timeout=30,
            )
            logger.info("Connected to Qdrant at %s", settings.QDRANT_URL)
        return self._client

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _collection_exists(self, name: str) -> bool:
        """Return True if collection *name* exists.

        A failing lookup (e.g. Qdrant unreachable) is logged and reported as
        "missing", so callers fall back exactly as they would for an absent
        collection (attempt creation, or report nothing) instead of the error
        being swallowed silently.
        """
        try:
            return self.client.collection_exists(name)
        except Exception:
            logger.warning(
                "Could not check whether collection '%s' exists; treating it as missing",
                name,
                exc_info=True,
            )
            return False

    @staticmethod
    def _build_filter(conditions: dict[str, Any]) -> qmodels.Filter:
        """Translate ``{payload_key: value}`` into a Qdrant filter.

        All conditions are AND-ed via ``must``. A list value matches any of its
        elements (``MatchAny``); any other value must match exactly
        (``MatchValue``). An empty dict yields a filter with no conditions.
        """
        return qmodels.Filter(
            must=[
                qmodels.FieldCondition(
                    key=key,
                    match=(
                        qmodels.MatchAny(any=value)
                        if isinstance(value, list)
                        else qmodels.MatchValue(value=value)
                    ),
                )
                for key, value in conditions.items()
            ]
        )

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    async def init_collections(self) -> None:
        """Create all default collections if they do not already exist."""
        for name, dim in ALL_DEFAULT_COLLECTIONS.items():
            await self.create_collection(name, dim)
        logger.info(
            "All default collections initialised (%d total)",
            len(ALL_DEFAULT_COLLECTIONS),
        )

    async def create_collection(
        self,
        name: str,
        vector_size: int,
        distance: qmodels.Distance = qmodels.Distance.COSINE,
    ) -> bool:
        """Create a single collection unless it already exists.

        Args:
            name: Collection name.
            vector_size: Dimension of the vectors the collection will store.
            distance: Similarity metric; cosine by default.

        Returns:
            True if the collection was newly created, False if it already existed.

        Raises:
            Whatever the Qdrant client raises when creation fails (logged first).
        """
        if self._collection_exists(name):
            logger.debug("Collection '%s' already exists skipping", name)
            return False

        try:
            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=distance,
                ),
                optimizers_config=qmodels.OptimizersConfigDiff(
                    indexing_threshold=20_000,
                ),
            )
        except Exception as exc:
            logger.error("Failed to create collection '%s': %s", name, exc)
            raise

        logger.info(
            "Created collection '%s' (dims=%d, distance=%s)",
            name,
            vector_size,
            distance.value,
        )
        return True

    # ------------------------------------------------------------------
    # Indexing
    # ------------------------------------------------------------------

    async def ensure_collection(self, name: str, vector_size: int = 1024) -> None:
        """Create collection *name* with *vector_size* dims if it doesn't exist.

        An existing collection is left untouched, even if its vector size
        differs from *vector_size*.
        """
        # create_collection already performs the existence check.
        await self.create_collection(name, vector_size)

    async def index_documents(
        self,
        collection: str,
        vectors: list[list[float]],
        payloads: list[dict[str, Any]],
        ids: Optional[list[str]] = None,
    ) -> int:
        """Batch-upsert vectors + payloads into *collection*.

        The collection is created on demand, sized from the first vector.
        Points are upserted in batches of 100 with ``wait=True``.

        Args:
            collection: Target collection name.
            vectors: One embedding per point.
            payloads: One payload dict per point, aligned with *vectors*.
            ids: Optional point ids aligned with *vectors*; random UUID4
                strings are generated when omitted.

        Returns:
            Number of points upserted (0 for empty input, in which case no
            collection is created).

        Raises:
            ValueError: if *payloads* or *ids* do not match *vectors* in length.
        """
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors ({len(vectors)}) and payloads ({len(payloads)}) must have equal length"
            )
        if ids is not None and len(ids) != len(vectors):
            # zip() below would otherwise silently drop the unmatched points.
            raise ValueError(
                f"ids ({len(ids)}) and vectors ({len(vectors)}) must have equal length"
            )
        if not vectors:
            # Nothing to store; don't create a collection with a guessed dimension.
            return 0

        # Auto-create collection if missing, sized to the incoming embeddings.
        await self.ensure_collection(collection, len(vectors[0]))

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in vectors]

        points = [
            qmodels.PointStruct(id=pid, vector=vec, payload=pay)
            for pid, vec, pay in zip(ids, vectors, payloads)
        ]

        batch_size = 100
        total_upserted = 0
        for start in range(0, len(points), batch_size):
            batch = points[start : start + batch_size]
            self.client.upsert(collection_name=collection, points=batch, wait=True)
            total_upserted += len(batch)

        logger.info("Upserted %d points into '%s'", total_upserted, collection)
        return total_upserted

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        collection: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Any]] = None,
        score_threshold: Optional[float] = None,
    ) -> list[dict[str, Any]]:
        """Semantic search. Returns a list of ``{id, score, payload}`` dicts.

        *filters* uses the same ``{key: value | [values]}`` format as
        :meth:`delete_by_filter`; a falsy value means "no filter".
        """
        qdrant_filter = self._build_filter(filters) if filters else None

        response = self.client.query_points(
            collection_name=collection,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
            score_threshold=score_threshold,
            with_payload=True,
        )
        return [
            {
                "id": str(point.id),
                "score": point.score,
                "payload": point.payload or {},
            }
            for point in response.points
        ]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> bool:
        """Delete all points matching the given filter dict.

        List values match any of their elements; other values match exactly.

        Returns:
            True once Qdrant has applied the deletion (``wait=True``).

        Raises:
            ValueError: if *filter_conditions* is empty.
        """
        if not filter_conditions:
            # A filter without conditions matches every point, so an empty dict
            # would silently wipe the whole collection.
            raise ValueError("delete_by_filter requires at least one filter condition")

        self.client.delete(
            collection_name=collection,
            points_selector=qmodels.FilterSelector(
                filter=self._build_filter(filter_conditions)
            ),
            wait=True,
        )
        logger.info("Deleted points from '%s' with filter %s", collection, filter_conditions)
        return True

    # ------------------------------------------------------------------
    # Tenant document helpers
    # ------------------------------------------------------------------

    async def get_unique_documents(self, collection: str) -> list[dict]:
        """List the distinct documents stored in *collection*.

        Scrolls all points (payloads only, 100 per page) and groups them by the
        ``document_id`` payload field. ``filename`` and ``file_size`` come from
        the first point seen for each document; ``chunk_count`` counts its
        points. Points without a ``document_id`` are ignored. Returns ``[]``
        if the collection does not exist.
        """
        if not self._collection_exists(collection):
            return []

        docs: dict[str, dict] = {}
        offset = None
        while True:
            points, next_offset = self.client.scroll(
                collection_name=collection,
                scroll_filter=None,
                limit=100,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            for pt in points:
                payload = pt.payload or {}
                doc_id = payload.get("document_id", "")
                if not doc_id:
                    continue
                doc = docs.get(doc_id)
                if doc is None:
                    doc = docs[doc_id] = {
                        "id": doc_id,
                        "filename": payload.get("filename", ""),
                        "file_size": payload.get("file_size", 0),
                        "status": "indexed",
                        "chunk_count": 0,
                        "collection": collection,
                    }
                doc["chunk_count"] += 1
            if next_offset is None:
                break
            offset = next_offset
        return list(docs.values())

    async def count_by_filter(
        self, collection: str, filter_conditions: dict[str, Any]
    ) -> int:
        """Return the exact number of points matching *filter_conditions*.

        Uses the same ``{key: value | [values]}`` format as
        :meth:`delete_by_filter`. Returns 0 if the collection does not exist.
        """
        if not self._collection_exists(collection):
            return 0

        result = self.client.count(
            collection_name=collection,
            count_filter=self._build_filter(filter_conditions),
            exact=True,
        )
        return result.count

    # ------------------------------------------------------------------
    # Info
    # ------------------------------------------------------------------

    async def get_collection_info(self, collection: str) -> dict[str, Any]:
        """Return basic stats for a collection.

        Raises:
            Whatever the Qdrant client raises if the lookup fails (logged first).
        """
        try:
            info = self.client.get_collection(collection)
        except Exception as exc:
            logger.error("Failed to get info for '%s': %s", collection, exc)
            raise

        vectors_config = info.config.params.vectors
        return {
            "name": collection,
            # NOTE(review): newer qdrant-client releases deprecate/drop
            # `vectors_count`; getattr keeps this working either way — confirm
            # against the pinned client version.
            "vectors_count": getattr(info, "vectors_count", None),
            "points_count": info.points_count,
            "status": info.status.value if info.status else "unknown",
            # No `.size` (e.g. a named-vectors mapping) -> None.
            "vector_size": getattr(vectors_config, "size", None),
        }
# Singleton
# Module-level shared wrapper instance. Constructing it does no network I/O:
# QdrantClientWrapper.__init__ only sets `_client = None`, and the real
# QdrantClient is created on first access of the `client` property.
# NOTE(review): presumably the rest of the service imports this object rather
# than building its own wrapper — confirm against callers.
qdrant_wrapper = QdrantClientWrapper()