"""
Qdrant Vector Database Service for BYOEH

Manages vector storage and semantic search for Erwartungshorizonte.
"""
import os
import uuid
from typing import Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
COLLECTION_NAME = "bp_eh"
VECTOR_SIZE = 1536  # OpenAI text-embedding-3-small

# Module-level singleton client, created lazily on first use.
_client: Optional[QdrantClient] = None


def get_qdrant_client() -> QdrantClient:
    """Get or create Qdrant client singleton."""
    global _client
    if _client is None:
        _client = QdrantClient(url=QDRANT_URL)
    return _client


async def init_qdrant_collection() -> bool:
    """
    Initialize Qdrant collection for BYOEH if not exists.

    Returns:
        True if the collection exists or was created, False on any error.
    """
    try:
        client = get_qdrant_client()

        # Check if collection exists
        collections = client.get_collections().collections
        collection_names = [c.name for c in collections]

        if COLLECTION_NAME not in collection_names:
            client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=VECTOR_SIZE,
                    distance=Distance.COSINE
                )
            )
            print(f"Created Qdrant collection: {COLLECTION_NAME}")
        else:
            print(f"Qdrant collection {COLLECTION_NAME} already exists")

        return True
    except Exception as e:
        print(f"Failed to initialize Qdrant: {e}")
        return False


def _eh_point_id(eh_id: str, chunk_index: int) -> str:
    """
    Build a deterministic, Qdrant-valid point ID for an EH chunk.

    Qdrant only accepts unsigned integers or UUIDs as point IDs, so the
    logical "{eh_id}_{index}" key is mapped to a stable UUIDv5 — the same
    scheme QdrantService.upsert_points uses for string IDs.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{eh_id}_{chunk_index}"))


async def index_eh_chunks(
    eh_id: str,
    tenant_id: str,
    subject: str,
    chunks: List[Dict]
) -> int:
    """
    Index EH chunks in Qdrant.

    Args:
        eh_id: Erwartungshorizont ID
        tenant_id: Tenant/School ID for isolation
        subject: Subject (deutsch, englisch, etc.)
        chunks: List of {text, embedding, encrypted_content}

    Returns:
        Number of indexed chunks
    """
    client = get_qdrant_client()

    points = []
    for i, chunk in enumerate(chunks):
        # BUGFIX: the previous raw string ID f"{eh_id}_{i}" is rejected by
        # Qdrant (point IDs must be unsigned ints or UUIDs); use a
        # deterministic UUID instead. The logical eh_id/chunk_index pair is
        # preserved in the payload for filtering and lookup.
        points.append(
            PointStruct(
                id=_eh_point_id(eh_id, i),
                vector=chunk["embedding"],
                payload={
                    "tenant_id": tenant_id,
                    "eh_id": eh_id,
                    "chunk_index": i,
                    "subject": subject,
                    "encrypted_content": chunk.get("encrypted_content", ""),
                    "training_allowed": False  # ALWAYS FALSE - critical for compliance
                }
            )
        )

    if points:
        client.upsert(collection_name=COLLECTION_NAME, points=points)

    return len(points)


async def search_eh(
    query_embedding: List[float],
    tenant_id: str,
    subject: Optional[str] = None,
    limit: int = 5
) -> List[Dict]:
    """
    Semantic search in tenant's Erwartungshorizonte.

    Args:
        query_embedding: Query vector (1536 dimensions)
        tenant_id: Tenant ID for isolation
        subject: Optional subject filter
        limit: Max results

    Returns:
        List of matching chunks with scores
    """
    client = get_qdrant_client()

    # Build filter conditions — tenant_id is always enforced for isolation.
    must_conditions = [
        FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))
    ]
    if subject:
        must_conditions.append(
            FieldCondition(key="subject", match=MatchValue(value=subject))
        )

    query_filter = Filter(must=must_conditions)

    results = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        query_filter=query_filter,
        limit=limit
    )

    return [
        {
            "id": str(r.id),
            "score": r.score,
            "eh_id": r.payload.get("eh_id"),
            "chunk_index": r.payload.get("chunk_index"),
            "encrypted_content": r.payload.get("encrypted_content"),
            "subject": r.payload.get("subject")
        }
        for r in results
    ]


async def delete_eh_vectors(eh_id: str) -> int:
    """
    Delete all vectors for a specific Erwartungshorizont.

    Args:
        eh_id: Erwartungshorizont ID

    Returns:
        Number of deleted points
    """
    client = get_qdrant_client()

    eh_filter = Filter(
        must=[FieldCondition(key="eh_id", match=MatchValue(value=eh_id))]
    )

    # Count first so the caller gets an accurate tally, then delete directly
    # by filter. BUGFIX: the previous scroll(limit=1000) + delete-by-ID
    # approach silently left points behind for EHs with more than 1000
    # chunks; filter-based delete removes all of them in one call (same
    # pattern as delete_legal_templates_by_source).
    count_result = client.count(
        collection_name=COLLECTION_NAME,
        count_filter=eh_filter
    )

    if count_result.count:
        client.delete(
            collection_name=COLLECTION_NAME,
            points_selector=eh_filter
        )

    return count_result.count


async def get_collection_info() -> Dict:
    """Get collection statistics (or {"error": ...} on failure)."""
    try:
        client = get_qdrant_client()
        info = client.get_collection(COLLECTION_NAME)
        return {
            "name": COLLECTION_NAME,
            "vectors_count": info.vectors_count,
            "points_count": info.points_count,
            "status": info.status.value
        }
    except Exception as e:
        return {"error": str(e)}


# =============================================================================
# QdrantService Class (for NiBiS Ingestion Pipeline)
# =============================================================================

class QdrantService:
    """
    Class-based Qdrant service for flexible collection management.

    Used by nibis_ingestion.py for bulk indexing.
    """

    def __init__(self, url: Optional[str] = None):
        # Fall back to the module-level QDRANT_URL when no URL is given.
        self.url = url or QDRANT_URL
        self._client: Optional[QdrantClient] = None

    @property
    def client(self) -> QdrantClient:
        """Lazily created QdrantClient bound to self.url."""
        if self._client is None:
            self._client = QdrantClient(url=self.url)
        return self._client

    async def ensure_collection(self, collection_name: str, vector_size: int = VECTOR_SIZE) -> bool:
        """
        Ensure collection exists, create if needed.

        Args:
            collection_name: Name of the collection
            vector_size: Dimension of vectors

        Returns:
            True if collection exists/created
        """
        try:
            collections = self.client.get_collections().collections
            collection_names = [c.name for c in collections]

            if collection_name not in collection_names:
                self.client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=vector_size,
                        distance=Distance.COSINE
                    )
                )
                print(f"Created collection: {collection_name}")

            return True
        except Exception as e:
            print(f"Error ensuring collection: {e}")
            return False

    async def upsert_points(self, collection_name: str, points: List[Dict]) -> int:
        """
        Upsert points into collection.

        Args:
            collection_name: Target collection
            points: List of {id, vector, payload}

        Returns:
            Number of upserted points
        """
        if not points:
            return 0

        qdrant_points = []
        for p in points:
            # Convert string ID to UUID for Qdrant compatibility
            point_id = p["id"]
            if isinstance(point_id, str):
                # Use uuid5 with DNS namespace for deterministic UUID from string
                point_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, point_id))

            qdrant_points.append(
                PointStruct(
                    id=point_id,
                    vector=p["vector"],
                    payload={**p.get("payload", {}), "original_id": p["id"]}  # Keep original ID in payload
                )
            )

        self.client.upsert(collection_name=collection_name, points=qdrant_points)
        return len(qdrant_points)

    async def search(
        self,
        collection_name: str,
        query_vector: List[float],
        filter_conditions: Optional[Dict] = None,
        limit: int = 10
    ) -> List[Dict]:
        """
        Semantic search in collection.

        Args:
            collection_name: Collection to search
            query_vector: Query embedding
            filter_conditions: Optional filters (key: value pairs), ANDed together
            limit: Max results

        Returns:
            List of matching points with scores
        """
        query_filter = None
        if filter_conditions:
            must_conditions = [
                FieldCondition(key=k, match=MatchValue(value=v))
                for k, v in filter_conditions.items()
            ]
            query_filter = Filter(must=must_conditions)

        results = self.client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            query_filter=query_filter,
            limit=limit
        )

        return [
            {
                "id": str(r.id),
                "score": r.score,
                "payload": r.payload
            }
            for r in results
        ]

    async def get_stats(self, collection_name: str) -> Dict:
        """Get collection statistics (or {"error": ..., "name": ...} on failure)."""
        try:
            info = self.client.get_collection(collection_name)
            return {
                "name": collection_name,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value
            }
        except Exception as e:
            return {"error": str(e), "name": collection_name}


# =============================================================================
# NiBiS RAG Search (for Klausurkorrektur Module)
# =============================================================================

async def search_nibis_eh(
    query_embedding: List[float],
    year: Optional[int] = None,
    subject: Optional[str] = None,
    niveau: Optional[str] = None,
    limit: int = 5
) -> List[Dict]:
    """
    Search in NiBiS Erwartungshorizonte (public, pre-indexed data).

    Unlike search_eh(), this searches in the public NiBiS collection
    and returns plaintext (not encrypted).

    Args:
        query_embedding: Query vector
        year: Optional year filter (2016, 2017, 2024, 2025)
        subject: Optional subject filter
        niveau: Optional niveau filter (eA, gA)
        limit: Max results

    Returns:
        List of matching chunks with metadata; empty list on search error.
    """
    client = get_qdrant_client()
    collection = "bp_nibis_eh"

    # Build filter — only supplied criteria are applied.
    must_conditions = []
    if year:
        must_conditions.append(
            FieldCondition(key="year", match=MatchValue(value=year))
        )
    if subject:
        must_conditions.append(
            FieldCondition(key="subject", match=MatchValue(value=subject))
        )
    if niveau:
        must_conditions.append(
            FieldCondition(key="niveau", match=MatchValue(value=niveau))
        )

    query_filter = Filter(must=must_conditions) if must_conditions else None

    try:
        results = client.search(
            collection_name=collection,
            query_vector=query_embedding,
            query_filter=query_filter,
            limit=limit
        )

        return [
            {
                "id": str(r.id),
                "score": r.score,
                "text": r.payload.get("text", ""),
                "year": r.payload.get("year"),
                "subject": r.payload.get("subject"),
                "niveau": r.payload.get("niveau"),
                "task_number": r.payload.get("task_number"),
                "doc_type": r.payload.get("doc_type"),
                "variant": r.payload.get("variant"),
            }
            for r in results
        ]
    except Exception as e:
        print(f"NiBiS search error: {e}")
        return []


# =============================================================================
# Legal Templates RAG Search (for Document Generator)
# =============================================================================

LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
LEGAL_TEMPLATES_VECTOR_SIZE = 1024  # BGE-M3


async def init_legal_templates_collection() -> bool:
    """Initialize Qdrant collection for legal templates if not exists."""
    try:
        client = get_qdrant_client()
        collections = client.get_collections().collections
        collection_names = [c.name for c in collections]

        if LEGAL_TEMPLATES_COLLECTION not in collection_names:
            client.create_collection(
                collection_name=LEGAL_TEMPLATES_COLLECTION,
                vectors_config=VectorParams(
                    size=LEGAL_TEMPLATES_VECTOR_SIZE,
                    distance=Distance.COSINE
                )
            )
            print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}")
        else:
            print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists")

        return True
    except Exception as e:
        print(f"Failed to initialize legal templates collection: {e}")
        return False


async def search_legal_templates(
    query_embedding: List[float],
    template_type: Optional[str] = None,
    license_types: Optional[List[str]] = None,
    language: Optional[str] = None,
    jurisdiction: Optional[str] = None,
    attribution_required: Optional[bool] = None,
    limit: int = 10
) -> List[Dict]:
    """
    Search in legal templates collection for document generation.

    Args:
        query_embedding: Query vector (1024 dimensions, BGE-M3)
        template_type: Filter by template type (privacy_policy, terms_of_service, etc.)
        license_types: Filter by license types (cc0, mit, cc_by_4, etc.)
        language: Filter by language (de, en)
        jurisdiction: Filter by jurisdiction (DE, EU, US, etc.)
        attribution_required: Filter by attribution requirement
        limit: Max results

    Returns:
        List of matching template chunks with full metadata; empty list on error.
    """
    client = get_qdrant_client()

    # Build filter conditions (ANDed together).
    must_conditions = []
    if template_type:
        must_conditions.append(
            FieldCondition(key="template_type", match=MatchValue(value=template_type))
        )
    if language:
        must_conditions.append(
            FieldCondition(key="language", match=MatchValue(value=language))
        )
    if jurisdiction:
        must_conditions.append(
            FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
        )
    # Explicit None check: attribution_required=False is a valid filter value.
    if attribution_required is not None:
        must_conditions.append(
            FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
        )

    # License type filter (OR condition)
    should_conditions = []
    if license_types:
        for license_type in license_types:
            should_conditions.append(
                FieldCondition(key="license_id", match=MatchValue(value=license_type))
            )

    # Construct filter
    query_filter = None
    if must_conditions or should_conditions:
        filter_args = {}
        if must_conditions:
            filter_args["must"] = must_conditions
        if should_conditions:
            filter_args["should"] = should_conditions
        query_filter = Filter(**filter_args)

    try:
        results = client.search(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            query_vector=query_embedding,
            query_filter=query_filter,
            limit=limit
        )

        return [
            {
                "id": str(r.id),
                "score": r.score,
                "text": r.payload.get("text", ""),
                "document_title": r.payload.get("document_title"),
                "template_type": r.payload.get("template_type"),
                "clause_category": r.payload.get("clause_category"),
                "language": r.payload.get("language"),
                "jurisdiction": r.payload.get("jurisdiction"),
                "license_id": r.payload.get("license_id"),
                "license_name": r.payload.get("license_name"),
                "license_url": r.payload.get("license_url"),
                "attribution_required": r.payload.get("attribution_required"),
                "attribution_text": r.payload.get("attribution_text"),
                "source_name": r.payload.get("source_name"),
                "source_url": r.payload.get("source_url"),
                "source_repo": r.payload.get("source_repo"),
                "placeholders": r.payload.get("placeholders", []),
                "is_complete_document": r.payload.get("is_complete_document"),
                "is_modular": r.payload.get("is_modular"),
                "requires_customization": r.payload.get("requires_customization"),
                "output_allowed": r.payload.get("output_allowed"),
                "modification_allowed": r.payload.get("modification_allowed"),
                "distortion_prohibited": r.payload.get("distortion_prohibited"),
            }
            for r in results
        ]
    except Exception as e:
        print(f"Legal templates search error: {e}")
        return []


async def get_legal_templates_stats() -> Dict:
    """Get statistics for the legal templates collection."""
    try:
        client = get_qdrant_client()
        info = client.get_collection(LEGAL_TEMPLATES_COLLECTION)

        # Count by template type (only non-zero counts are reported).
        template_types = ["privacy_policy", "terms_of_service", "cookie_banner",
                          "impressum", "widerruf", "dpa", "sla", "agb"]
        type_counts = {}
        for ttype in template_types:
            result = client.count(
                collection_name=LEGAL_TEMPLATES_COLLECTION,
                count_filter=Filter(
                    must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))]
                )
            )
            if result.count > 0:
                type_counts[ttype] = result.count

        # Count by language (always reported, including zeros).
        lang_counts = {}
        for lang in ["de", "en"]:
            result = client.count(
                collection_name=LEGAL_TEMPLATES_COLLECTION,
                count_filter=Filter(
                    must=[FieldCondition(key="language", match=MatchValue(value=lang))]
                )
            )
            lang_counts[lang] = result.count

        # Count by license (only non-zero counts are reported).
        license_counts = {}
        for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]:
            result = client.count(
                collection_name=LEGAL_TEMPLATES_COLLECTION,
                count_filter=Filter(
                    must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))]
                )
            )
            if result.count > 0:
                license_counts[license_id] = result.count

        return {
            "collection": LEGAL_TEMPLATES_COLLECTION,
            "vectors_count": info.vectors_count,
            "points_count": info.points_count,
            "status": info.status.value,
            "template_types": type_counts,
            "languages": lang_counts,
            "licenses": license_counts,
        }
    except Exception as e:
        return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION}


async def delete_legal_templates_by_source(source_name: str) -> int:
    """
    Delete all legal template chunks from a specific source.

    Args:
        source_name: Name of the source to delete

    Returns:
        Number of deleted points
    """
    client = get_qdrant_client()

    # Count first
    count_result = client.count(
        collection_name=LEGAL_TEMPLATES_COLLECTION,
        count_filter=Filter(
            must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
        )
    )

    # Delete by filter
    client.delete(
        collection_name=LEGAL_TEMPLATES_COLLECTION,
        points_selector=Filter(
            must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
        )
    )

    return count_result.count