""" Qdrant Vector Database Service — Legal Templates RAG Search. """ from typing import List, Dict, Optional from qdrant_client.models import VectorParams, Distance, Filter, FieldCondition, MatchValue from qdrant_core import get_qdrant_client LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates" LEGAL_TEMPLATES_VECTOR_SIZE = 1024 # BGE-M3 async def init_legal_templates_collection() -> bool: """Initialize Qdrant collection for legal templates if not exists.""" try: client = get_qdrant_client() collections = client.get_collections().collections collection_names = [c.name for c in collections] if LEGAL_TEMPLATES_COLLECTION not in collection_names: client.create_collection( collection_name=LEGAL_TEMPLATES_COLLECTION, vectors_config=VectorParams( size=LEGAL_TEMPLATES_VECTOR_SIZE, distance=Distance.COSINE ) ) print(f"Created Qdrant collection: {LEGAL_TEMPLATES_COLLECTION}") else: print(f"Qdrant collection {LEGAL_TEMPLATES_COLLECTION} already exists") return True except Exception as e: print(f"Failed to initialize legal templates collection: {e}") return False async def search_legal_templates( query_embedding: List[float], template_type: Optional[str] = None, license_types: Optional[List[str]] = None, language: Optional[str] = None, jurisdiction: Optional[str] = None, attribution_required: Optional[bool] = None, limit: int = 10 ) -> List[Dict]: """ Search in legal templates collection for document generation. Args: query_embedding: Query vector (1024 dimensions, BGE-M3) template_type: Filter by template type (privacy_policy, terms_of_service, etc.) license_types: Filter by license types (cc0, mit, cc_by_4, etc.) language: Filter by language (de, en) jurisdiction: Filter by jurisdiction (DE, EU, US, etc.) attribution_required: Filter by attribution requirement limit: Max results Returns: List of matching template chunks with full metadata """ client = get_qdrant_client() # Build filter conditions must_conditions = [] if template_type: must_conditions.append( FieldCondition(key="template_type", match=MatchValue(value=template_type)) ) if language: must_conditions.append( FieldCondition(key="language", match=MatchValue(value=language)) ) if jurisdiction: must_conditions.append( FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction)) ) if attribution_required is not None: must_conditions.append( FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required)) ) # License type filter (OR condition) should_conditions = [] if license_types: for license_type in license_types: should_conditions.append( FieldCondition(key="license_id", match=MatchValue(value=license_type)) ) # Construct filter query_filter = None if must_conditions or should_conditions: filter_args = {} if must_conditions: filter_args["must"] = must_conditions if should_conditions: filter_args["should"] = should_conditions query_filter = Filter(**filter_args) try: results = client.search( collection_name=LEGAL_TEMPLATES_COLLECTION, query_vector=query_embedding, query_filter=query_filter, limit=limit ) return [ { "id": str(r.id), "score": r.score, "text": r.payload.get("text", ""), "document_title": r.payload.get("document_title"), "template_type": r.payload.get("template_type"), "clause_category": r.payload.get("clause_category"), "language": r.payload.get("language"), "jurisdiction": r.payload.get("jurisdiction"), "license_id": r.payload.get("license_id"), "license_name": r.payload.get("license_name"), "license_url": r.payload.get("license_url"), "attribution_required": r.payload.get("attribution_required"), "attribution_text": r.payload.get("attribution_text"), "source_name": r.payload.get("source_name"), "source_url": r.payload.get("source_url"), "source_repo": r.payload.get("source_repo"), "placeholders": r.payload.get("placeholders", []), "is_complete_document": r.payload.get("is_complete_document"), "is_modular": r.payload.get("is_modular"), "requires_customization": r.payload.get("requires_customization"), "output_allowed": r.payload.get("output_allowed"), "modification_allowed": r.payload.get("modification_allowed"), "distortion_prohibited": r.payload.get("distortion_prohibited"), } for r in results ] except Exception as e: print(f"Legal templates search error: {e}") return [] async def get_legal_templates_stats() -> Dict: """Get statistics for the legal templates collection.""" try: client = get_qdrant_client() info = client.get_collection(LEGAL_TEMPLATES_COLLECTION) # Count by template type template_types = ["privacy_policy", "terms_of_service", "cookie_banner", "impressum", "widerruf", "dpa", "sla", "agb"] type_counts = {} for ttype in template_types: result = client.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[FieldCondition(key="template_type", match=MatchValue(value=ttype))] ) ) if result.count > 0: type_counts[ttype] = result.count # Count by language lang_counts = {} for lang in ["de", "en"]: result = client.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[FieldCondition(key="language", match=MatchValue(value=lang))] ) ) lang_counts[lang] = result.count # Count by license license_counts = {} for license_id in ["cc0", "mit", "cc_by_4", "public_domain", "unlicense"]: result = client.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[FieldCondition(key="license_id", match=MatchValue(value=license_id))] ) ) if result.count > 0: license_counts[license_id] = result.count return { "collection": LEGAL_TEMPLATES_COLLECTION, "vectors_count": info.vectors_count, "points_count": info.points_count, "status": info.status.value, "template_types": type_counts, "languages": lang_counts, "licenses": license_counts, } except Exception as e: return {"error": str(e), "collection": LEGAL_TEMPLATES_COLLECTION} async def delete_legal_templates_by_source(source_name: str) -> int: """ Delete all legal template chunks from a specific source. Args: source_name: Name of the source to delete Returns: Number of deleted points """ client = get_qdrant_client() # Count first count_result = client.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ) ) # Delete by filter client.delete( collection_name=LEGAL_TEMPLATES_COLLECTION, points_selector=Filter( must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ) ) return count_result.count