""" Legal Templates Ingestion Pipeline for RAG. Indexes legal template documents from various open-source repositories into Qdrant for semantic search. Supports multiple license types with proper attribution tracking. Collection: bp_legal_templates Usage: python legal_templates_cli.py --ingest-all python legal_templates_cli.py --ingest-source github-site-policy python legal_templates_cli.py --status python legal_templates_cli.py --search "Datenschutzerklaerung" """ import asyncio import hashlib import logging import os from datetime import datetime from typing import Any, Dict, List, Optional from urllib.parse import urlparse import httpx from qdrant_client import QdrantClient from qdrant_client.models import ( Distance, FieldCondition, Filter, MatchValue, PointStruct, VectorParams, ) from template_sources import ( LICENSES, TEMPLATE_SOURCES, TEMPLATE_TYPES, LicenseType, SourceConfig, get_enabled_sources, get_sources_by_priority, ) from github_crawler import ( ExtractedDocument, GitHubCrawler, RepositoryDownloader, ) # Re-export from chunking module for backward compatibility from legal_templates_chunking import ( # noqa: F401 IngestionStatus, TemplateChunk, chunk_text, create_chunks, infer_clause_category, infer_template_type, split_sentences, ) # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Configuration - Support both QDRANT_URL and QDRANT_HOST/PORT _qdrant_url = os.getenv("QDRANT_URL", "") if _qdrant_url: _parsed = urlparse(_qdrant_url) QDRANT_HOST = _parsed.hostname or "localhost" QDRANT_PORT = _parsed.port or 6333 else: QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost") QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:8087") LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates" VECTOR_SIZE = 1024 # BGE-M3 dimension # Chunking configuration CHUNK_SIZE = int(os.getenv("TEMPLATE_CHUNK_SIZE", "1000")) CHUNK_OVERLAP = int(os.getenv("TEMPLATE_CHUNK_OVERLAP", "200")) # Batch processing EMBEDDING_BATCH_SIZE = 4 MAX_RETRIES = 3 RETRY_DELAY = 3.0 class LegalTemplatesIngestion: """Handles ingestion of legal templates into Qdrant.""" def __init__(self): self.qdrant = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) self.http_client = httpx.AsyncClient(timeout=120.0) self._ensure_collection() self._ingestion_status: Dict[str, IngestionStatus] = {} def _ensure_collection(self): """Create the legal templates collection if it doesn't exist.""" collections = self.qdrant.get_collections().collections collection_names = [c.name for c in collections] if LEGAL_TEMPLATES_COLLECTION not in collection_names: logger.info(f"Creating collection: {LEGAL_TEMPLATES_COLLECTION}") self.qdrant.create_collection( collection_name=LEGAL_TEMPLATES_COLLECTION, vectors_config=VectorParams( size=VECTOR_SIZE, distance=Distance.COSINE, ), ) logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} created") else: logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} already exists") async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]: """Generate embeddings via the embedding service.""" try: response = await self.http_client.post( f"{EMBEDDING_SERVICE_URL}/embed", json={"texts": texts}, timeout=120.0, ) response.raise_for_status() data = response.json() return data["embeddings"] except Exception as e: logger.error(f"Embedding generation failed: {e}") raise async def ingest_source(self, source: SourceConfig) -> IngestionStatus: """Ingest a single source into Qdrant.""" status = IngestionStatus( source_name=source.name, status="running", started_at=datetime.utcnow(), ) self._ingestion_status[source.name] = status logger.info(f"Ingesting source: {source.name}") try: # Crawl the source documents: List[ExtractedDocument] = [] if source.repo_url: async with GitHubCrawler() as crawler: async for doc in crawler.crawl_repository(source): documents.append(doc) status.documents_found += 1 logger.info(f"Found {len(documents)} documents in {source.name}") if not documents: status.status = "completed" status.completed_at = datetime.utcnow() return status # Create chunks from all documents all_chunks: List[TemplateChunk] = [] for doc in documents: chunks = create_chunks(doc, source, CHUNK_SIZE, CHUNK_OVERLAP) all_chunks.extend(chunks) status.chunks_created += len(chunks) logger.info(f"Created {len(all_chunks)} chunks from {source.name}") # Generate embeddings and index in batches for i in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE): batch_chunks = all_chunks[i:i + EMBEDDING_BATCH_SIZE] chunk_texts = [c.text for c in batch_chunks] # Retry logic for embeddings embeddings = None for retry in range(MAX_RETRIES): try: embeddings = await self._generate_embeddings(chunk_texts) break except Exception as e: logger.warning( f"Embedding attempt {retry+1}/{MAX_RETRIES} failed: {e}" ) if retry < MAX_RETRIES - 1: await asyncio.sleep(RETRY_DELAY * (retry + 1)) else: status.errors.append(f"Embedding failed for batch {i}: {e}") if embeddings is None: continue # Create points for Qdrant points = [] for j, (chunk, embedding) in enumerate(zip(batch_chunks, embeddings)): point_id = hashlib.md5( f"{source.name}-{chunk.source_file}-{chunk.chunk_index}".encode() ).hexdigest() payload = { "text": chunk.text, "chunk_index": chunk.chunk_index, "document_title": chunk.document_title, "template_type": chunk.template_type, "clause_category": chunk.clause_category, "language": chunk.language, "jurisdiction": chunk.jurisdiction, "license_id": chunk.license_id, "license_name": chunk.license_name, "license_url": chunk.license_url, "attribution_required": chunk.attribution_required, "share_alike": chunk.share_alike, "no_derivatives": chunk.no_derivatives, "commercial_use": chunk.commercial_use, "source_name": chunk.source_name, "source_url": chunk.source_url, "source_repo": chunk.source_repo, "source_commit": chunk.source_commit, "source_file": chunk.source_file, "source_hash": chunk.source_hash, "attribution_text": chunk.attribution_text, "copyright_notice": chunk.copyright_notice, "is_complete_document": chunk.is_complete_document, "is_modular": chunk.is_modular, "requires_customization": chunk.requires_customization, "placeholders": chunk.placeholders, "training_allowed": chunk.training_allowed, "output_allowed": chunk.output_allowed, "modification_allowed": chunk.modification_allowed, "distortion_prohibited": chunk.distortion_prohibited, "indexed_at": datetime.utcnow().isoformat(), } points.append(PointStruct( id=point_id, vector=embedding, payload=payload, )) # Upsert to Qdrant if points: self.qdrant.upsert( collection_name=LEGAL_TEMPLATES_COLLECTION, points=points, ) status.chunks_indexed += len(points) # Rate limiting await asyncio.sleep(1.0) status.status = "completed" status.completed_at = datetime.utcnow() logger.info( f"Completed {source.name}: {status.chunks_indexed} chunks indexed" ) except Exception as e: status.status = "failed" status.errors.append(str(e)) status.completed_at = datetime.utcnow() logger.error(f"Failed to ingest {source.name}: {e}") return status async def ingest_all(self, max_priority: int = 5) -> Dict[str, IngestionStatus]: """Ingest all enabled sources up to a priority level.""" sources = get_sources_by_priority(max_priority) results = {} for source in sources: result = await self.ingest_source(source) results[source.name] = result return results async def ingest_by_license(self, license_type: LicenseType) -> Dict[str, IngestionStatus]: """Ingest all sources of a specific license type.""" from template_sources import get_sources_by_license sources = get_sources_by_license(license_type) results = {} for source in sources: result = await self.ingest_source(source) results[source.name] = result return results def get_status(self) -> Dict[str, Any]: """Get collection status and ingestion results.""" try: collection_info = self.qdrant.get_collection(LEGAL_TEMPLATES_COLLECTION) # Count points per source source_counts = {} for source in TEMPLATE_SOURCES: result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[ FieldCondition( key="source_name", match=MatchValue(value=source.name), ) ] ), ) source_counts[source.name] = result.count # Count by license type license_counts = {} for license_type in LicenseType: result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[ FieldCondition( key="license_id", match=MatchValue(value=license_type.value), ) ] ), ) license_counts[license_type.value] = result.count # Count by template type template_type_counts = {} for template_type in TEMPLATE_TYPES.keys(): result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[ FieldCondition( key="template_type", match=MatchValue(value=template_type), ) ] ), ) if result.count > 0: template_type_counts[template_type] = result.count # Count by language language_counts = {} for lang in ["de", "en"]: result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[ FieldCondition( key="language", match=MatchValue(value=lang), ) ] ), ) language_counts[lang] = result.count return { "collection": LEGAL_TEMPLATES_COLLECTION, "total_points": collection_info.points_count, "vector_size": VECTOR_SIZE, "sources": source_counts, "licenses": license_counts, "template_types": template_type_counts, "languages": language_counts, "status": "ready" if collection_info.points_count > 0 else "empty", "ingestion_status": { name: { "status": s.status, "documents_found": s.documents_found, "chunks_indexed": s.chunks_indexed, "errors": s.errors, } for name, s in self._ingestion_status.items() }, } except Exception as e: return { "collection": LEGAL_TEMPLATES_COLLECTION, "error": str(e), "status": "error", } async def search( self, query: str, template_type: Optional[str] = None, license_types: Optional[List[str]] = None, language: Optional[str] = None, jurisdiction: Optional[str] = None, attribution_required: Optional[bool] = None, top_k: int = 10, ) -> List[Dict[str, Any]]: """Search the legal templates collection.""" # Generate query embedding embeddings = await self._generate_embeddings([query]) query_vector = embeddings[0] # Build filter conditions must_conditions = [] if template_type: must_conditions.append( FieldCondition(key="template_type", match=MatchValue(value=template_type)) ) if language: must_conditions.append( FieldCondition(key="language", match=MatchValue(value=language)) ) if jurisdiction: must_conditions.append( FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction)) ) if attribution_required is not None: must_conditions.append( FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required)) ) # License type filter (OR condition) should_conditions = [] if license_types: for lt in license_types: should_conditions.append( FieldCondition(key="license_id", match=MatchValue(value=lt)) ) # Construct filter search_filter = None if must_conditions or should_conditions: filter_dict = {} if must_conditions: filter_dict["must"] = must_conditions if should_conditions: filter_dict["should"] = should_conditions search_filter = Filter(**filter_dict) # Execute search results = self.qdrant.search( collection_name=LEGAL_TEMPLATES_COLLECTION, query_vector=query_vector, query_filter=search_filter, limit=top_k, ) return [ { "id": hit.id, "score": hit.score, "text": hit.payload.get("text"), "document_title": hit.payload.get("document_title"), "template_type": hit.payload.get("template_type"), "clause_category": hit.payload.get("clause_category"), "language": hit.payload.get("language"), "jurisdiction": hit.payload.get("jurisdiction"), "license_id": hit.payload.get("license_id"), "license_name": hit.payload.get("license_name"), "attribution_required": hit.payload.get("attribution_required"), "attribution_text": hit.payload.get("attribution_text"), "source_name": hit.payload.get("source_name"), "source_url": hit.payload.get("source_url"), "placeholders": hit.payload.get("placeholders"), "is_complete_document": hit.payload.get("is_complete_document"), "requires_customization": hit.payload.get("requires_customization"), "output_allowed": hit.payload.get("output_allowed"), "modification_allowed": hit.payload.get("modification_allowed"), } for hit in results ] def delete_source(self, source_name: str) -> int: """Delete all chunks from a specific source.""" count_result = self.qdrant.count( collection_name=LEGAL_TEMPLATES_COLLECTION, count_filter=Filter( must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ), ) self.qdrant.delete( collection_name=LEGAL_TEMPLATES_COLLECTION, points_selector=Filter( must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))] ), ) return count_result.count def reset_collection(self): """Delete and recreate the collection.""" logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}") try: self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION) except Exception: pass self._ensure_collection() self._ingestion_status.clear() logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset") async def close(self): """Close HTTP client.""" await self.http_client.aclose()