Some checks failed
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 42s
CI / test-go-edu-search (push) Successful in 34s
CI / test-python-klausur (push) Failing after 2m51s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 29s
sed replacement left orphaned hostname references in story page and empty lines in getApiBase functions. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
158 lines
5.0 KiB
Python
158 lines
5.0 KiB
Python
"""
DSFA Qdrant Service — Vector store operations.

Contains:
- DSFAQdrantService: Qdrant client wrapper for DSFA corpus
"""
|
|
|
|
import os
|
|
import uuid
|
|
from typing import List, Dict, Optional
|
|
from dataclasses import asdict
|
|
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import (
|
|
VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue
|
|
)
|
|
|
|
from dsfa_corpus_store import DSFAChunkPayload
|
|
|
|
# Base URL of the Qdrant server; the QDRANT_URL environment variable wins
# over the built-in default when it is set.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://qdrant:6333")

# Name of the Qdrant collection that holds the DSFA corpus.
DSFA_COLLECTION = "bp_dsfa_corpus"

# Dimensionality of the stored vectors (BGE-M3).
VECTOR_SIZE = 1024
|
|
|
|
|
|
class DSFAQdrantService:
    """Qdrant operations for the DSFA corpus.

    Wraps a lazily created ``QdrantClient`` and offers collection setup,
    chunk indexing, filtered vector search and collection statistics, all
    against the ``DSFA_COLLECTION`` collection.
    """

    def __init__(self, url: Optional[str] = None):
        """Remember the server URL; the client is only created on first use.

        Args:
            url: Qdrant base URL. ``QDRANT_URL`` is used when this is
                omitted or empty.
        """
        self.url = url or QDRANT_URL
        self._client = None

    @property
    def client(self) -> QdrantClient:
        """Return the cached ``QdrantClient``, building it on first access."""
        if self._client is None:
            self._client = QdrantClient(url=self.url, check_compatibility=False)
        return self._client

    async def ensure_collection(self) -> bool:
        """Ensure the DSFA collection exists, creating it when missing.

        Returns:
            ``True`` when the collection exists after the call, ``False``
            when any error occurred (the error is printed, not raised).
        """
        try:
            collections = self.client.get_collections().collections
            collection_names = [c.name for c in collections]

            if DSFA_COLLECTION not in collection_names:
                self.client.create_collection(
                    collection_name=DSFA_COLLECTION,
                    vectors_config=VectorParams(
                        size=VECTOR_SIZE,
                        distance=Distance.COSINE,
                    ),
                )
                print(f"Created collection: {DSFA_COLLECTION}")
            return True
        except Exception as e:
            # Best-effort by design: callers only get a success flag.
            print(f"Error ensuring collection: {e}")
            return False

    async def index_chunks(
        self,
        chunks: List[Dict],
        embeddings: List[List[float]]
    ) -> int:
        """Index chunks into Qdrant.

        Args:
            chunks: Chunk dictionaries. The keys ``chunk_id``,
                ``document_id``, ``source_id``, ``content``, ``source_code``,
                ``source_name``, ``attribution_text`` and ``license_code``
                are required; the remaining payload fields are optional.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Returns:
            Number of points sent to Qdrant (``0`` for empty input).

        Raises:
            KeyError: If a chunk lacks one of the required keys.
        """
        if not chunks or not embeddings:
            return 0

        points = []
        # NOTE(review): zip() stops at the shorter sequence, so a length
        # mismatch between chunks and embeddings silently drops the surplus.
        for chunk, embedding in zip(chunks, embeddings):
            # uuid5 is deterministic, so re-indexing the same chunk_id
            # overwrites the existing point instead of duplicating it.
            point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, chunk["chunk_id"]))

            payload = DSFAChunkPayload(
                chunk_id=chunk["chunk_id"],
                document_id=chunk["document_id"],
                source_id=chunk["source_id"],
                content=chunk["content"],
                section_title=chunk.get("section_title"),
                source_code=chunk["source_code"],
                source_name=chunk["source_name"],
                attribution_text=chunk["attribution_text"],
                license_code=chunk["license_code"],
                attribution_required=chunk.get("attribution_required", True),
                document_type=chunk.get("document_type", ""),
                category=chunk.get("category", ""),
                language=chunk.get("language", "de"),
                page_number=chunk.get("page_number"),
            )

            points.append(
                PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload=asdict(payload),
                )
            )

        self.client.upsert(collection_name=DSFA_COLLECTION, points=points)
        return len(points)

    @staticmethod
    def _match_any(key: str, values: List[str]):
        """Build a condition satisfied when ``key`` equals any of ``values``."""
        alternatives = [
            FieldCondition(key=key, match=MatchValue(value=value))
            for value in values
        ]
        # A single value needs no wrapper; several values of the same field
        # are OR-ed through a nested ``should`` filter.
        if len(alternatives) == 1:
            return alternatives[0]
        return Filter(should=alternatives)

    async def search(
        self,
        query_embedding: List[float],
        source_codes: Optional[List[str]] = None,
        document_types: Optional[List[str]] = None,
        categories: Optional[List[str]] = None,
        limit: int = 10
    ) -> List[Dict]:
        """Search the DSFA corpus with optional payload filters.

        Values inside one filter list are alternatives (OR); the different
        filter lists are combined with AND. Requiring every listed value at
        once could never match, because each chunk is indexed with a single
        ``source_code``, ``document_type`` and ``category``.

        Args:
            query_embedding: Query vector.
            source_codes: Accepted ``source_code`` values, if any.
            document_types: Accepted ``document_type`` values, if any.
            categories: Accepted ``category`` values, if any.
            limit: Maximum number of hits to return.

        Returns:
            One dict per hit with ``id``, ``score`` and the point's payload
            fields merged in.
        """
        requested = (
            ("source_code", source_codes),
            ("document_type", document_types),
            ("category", categories),
        )
        must_conditions = [
            self._match_any(key, values)
            for key, values in requested
            if values
        ]

        query_filter = Filter(must=must_conditions) if must_conditions else None

        results = self.client.query_points(
            collection_name=DSFA_COLLECTION,
            query=query_embedding,
            query_filter=query_filter,
            limit=limit,
        )

        return [
            {
                "id": str(r.id),
                "score": r.score,
                # A point may carry no payload; unpacking None would raise.
                **(r.payload or {}),
            }
            for r in results.points
        ]

    async def get_stats(self) -> Dict:
        """Get collection statistics.

        Returns:
            A dict with ``collection``, ``vectors_count``, ``points_count``
            and ``status``; on failure a dict with ``error`` and
            ``collection`` instead.
        """
        try:
            info = self.client.get_collection(DSFA_COLLECTION)
            return {
                "collection": DSFA_COLLECTION,
                # NOTE(review): vectors_count appears to be deprecated in
                # newer Qdrant releases and may be absent — confirm. Reading
                # it defensively keeps the other statistics available.
                "vectors_count": getattr(info, "vectors_count", None),
                "points_count": info.points_count,
                "status": info.status.value,
            }
        except Exception as e:
            return {"error": str(e), "collection": DSFA_COLLECTION}
|