"""
Legal Templates Ingestion Pipeline for RAG.

Indexes legal template documents from various open-source repositories into
Qdrant for semantic search. Supports multiple license types with proper
attribution tracking.

Collection: bp_legal_templates

Usage:
    python legal_templates_ingestion.py --ingest-all
    python legal_templates_ingestion.py --ingest-source github-site-policy
    python legal_templates_ingestion.py --status
    python legal_templates_ingestion.py --search "Datenschutzerklaerung"
"""

import asyncio
import hashlib
import json
import logging
import os
import re
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

import httpx
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from template_sources import (
    LICENSES,
    TEMPLATE_SOURCES,
    TEMPLATE_TYPES,
    LicenseType,
    SourceConfig,
    get_enabled_sources,
    get_sources_by_priority,
)
from github_crawler import (
    ExtractedDocument,
    GitHubCrawler,
    RepositoryDownloader,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration - Support both QDRANT_URL and QDRANT_HOST/PORT
_qdrant_url = os.getenv("QDRANT_URL", "")
if _qdrant_url:
    _parsed = urlparse(_qdrant_url)
    QDRANT_HOST = _parsed.hostname or "localhost"
    QDRANT_PORT = _parsed.port or 6333
else:
    QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
    QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:8087")
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
VECTOR_SIZE = 1024  # BGE-M3 dimension

# Chunking configuration
CHUNK_SIZE = int(os.getenv("TEMPLATE_CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.getenv("TEMPLATE_CHUNK_OVERLAP", "200"))

# Batch processing
EMBEDDING_BATCH_SIZE = 4
MAX_RETRIES = 3
RETRY_DELAY = 3.0

# --- Sentence splitting -------------------------------------------------------
# Dots that do not end a sentence (known abbreviations, decimal points) are
# masked with a private-use Unicode character before splitting and restored
# afterwards. The placeholder must not be '.', '!' or '?' so the splitter
# ignores it, and must be non-empty so it can be restored.
_DOT_PLACEHOLDER = "\ue000"
_ABBREVIATIONS = [
    'bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw',
    'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no',
]
# Longest alternatives first so multi-part abbreviations win over prefixes.
_ABBREVIATION_RE = re.compile(
    r"\b(?:"
    + "|".join(re.escape(a) for a in sorted(_ABBREVIATIONS, key=len, reverse=True))
    + r")\.",
    re.IGNORECASE,
)
# Lookarounds so consecutive decimal points ("1.2.3") are all masked.
_DECIMAL_POINT_RE = re.compile(r"(?<=\d)\.(?=\d)")
_SENTENCE_BOUNDARY_RE = re.compile(r"(?<=[.!?])\s+")

# --- Heuristic classification keywords (matched as lowercase substrings) -------
_TEMPLATE_TYPE_INDICATORS: Dict[str, List[str]] = {
    "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
    "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
    "cookie_banner": ["cookie", "cookies", "tracking"],
    "impressum": ["impressum", "legal notice", "imprint"],
    "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
    "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
    "sla": ["service level", "availability", "uptime"],
    "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
    "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
    "acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
}

_CLAUSE_CATEGORY_INDICATORS: Dict[str, List[str]] = {
    "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
    "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
    "widerruf": ["widerruf", "cancellation", "withdrawal"],
    "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
    "kuendigung": ["kuendigung", "termination", "beendigung"],
    "zahlung": ["zahlung", "payment", "preis", "price"],
    "gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
    "aenderungen": ["aenderung", "modification", "amendment"],
    "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
}


@dataclass
class IngestionStatus:
    """Status of a source ingestion."""

    source_name: str
    status: str  # "pending", "running", "completed", "failed"
    documents_found: int = 0
    chunks_created: int = 0
    chunks_indexed: int = 0
    errors: List[str] = field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None


@dataclass
class TemplateChunk:
    """A chunk of template text ready for indexing.

    The field names double as the Qdrant payload keys (see
    ``LegalTemplatesIngestion._build_point``), so renaming a field renames the
    stored payload key.
    """

    # Content and classification
    text: str
    chunk_index: int
    document_title: str
    template_type: str
    clause_category: Optional[str]
    language: str
    jurisdiction: str
    # License information copied from the source's license_info
    license_id: str
    license_name: str
    license_url: str
    attribution_required: bool
    share_alike: bool
    no_derivatives: bool
    commercial_use: bool
    # Provenance
    source_name: str
    source_url: str
    source_repo: Optional[str]
    source_commit: Optional[str]
    source_file: str
    source_hash: str
    attribution_text: Optional[str]
    copyright_notice: Optional[str]
    # Document structure
    is_complete_document: bool
    is_modular: bool
    requires_customization: bool
    placeholders: List[str]
    # Usage permissions copied from the source's license_info
    training_allowed: bool
    output_allowed: bool
    modification_allowed: bool
    distortion_prohibited: bool


class LegalTemplatesIngestion:
    """Handles ingestion of legal templates into Qdrant."""

    def __init__(self):
        self.qdrant = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
        # Ensure the collection before opening the HTTP client so that a
        # Qdrant failure here does not leave an unclosed httpx client behind.
        self._ensure_collection()
        self._ingestion_status: Dict[str, IngestionStatus] = {}
        self.http_client = httpx.AsyncClient(timeout=120.0)

    def _ensure_collection(self):
        """Create the legal templates collection if it doesn't exist.

        A new collection gets VECTOR_SIZE-dimensional vectors with cosine
        distance. An existing collection is left as is; its vector
        configuration is not checked.
        """
        existing = {c.name for c in self.qdrant.get_collections().collections}
        if LEGAL_TEMPLATES_COLLECTION in existing:
            logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} already exists")
            return

        logger.info(f"Creating collection: {LEGAL_TEMPLATES_COLLECTION}")
        self.qdrant.create_collection(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            vectors_config=VectorParams(
                size=VECTOR_SIZE,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} created")

    @staticmethod
    def _match(key: str, value: Any) -> FieldCondition:
        """Build an exact-match condition on payload field ``key``."""
        return FieldCondition(key=key, match=MatchValue(value=value))

    def _count_where(self, key: str, value: Any) -> int:
        """Count points whose payload field ``key`` equals ``value``."""
        result = self.qdrant.count(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            count_filter=Filter(must=[self._match(key, value)]),
        )
        return result.count

    async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings via the embedding service.

        POSTs ``{"texts": texts}`` to ``{EMBEDDING_SERVICE_URL}/embed`` and
        returns the response's ``"embeddings"`` field.

        Raises:
            httpx.HTTPError, KeyError, ...: any failure is logged and re-raised.
        """
        try:
            response = await self.http_client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts},
                timeout=120.0,
            )
            response.raise_for_status()
            data = response.json()
            return data["embeddings"]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def _chunk_text(
        self,
        text: str,
        chunk_size: int = CHUNK_SIZE,
        overlap: int = CHUNK_OVERLAP,
    ) -> List[str]:
        """
        Split text into chunks of roughly ``chunk_size`` characters.

        Paragraphs (separated by blank lines) are packed together while they
        fit. A paragraph longer than ``chunk_size`` is split into sentences.
        Consecutive chunks cut from such a paragraph share trailing sentences
        totalling at most ``overlap`` characters. A single sentence longer
        than ``chunk_size`` becomes its own oversized chunk.

        Args:
            text: Document text.
            chunk_size: Target maximum chunk length in characters.
            overlap: Maximum number of characters repeated between consecutive
                sentence-split chunks (0 disables overlap).

        Returns:
            Non-empty, stripped chunks in document order.
        """
        stripped = text.strip() if text else ""
        if not stripped:
            return []
        if len(stripped) <= chunk_size:
            return [stripped]

        chunks: List[str] = []
        current_chunk: List[str] = []  # pieces joined with a blank line
        current_length = 0  # == len("\n\n".join(current_chunk))

        for para in text.split('\n\n'):
            para = para.strip()
            if not para:
                continue

            if len(para) > chunk_size:
                # Large paragraph: close the pending chunk, then pack sentences.
                if current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                sentence_chunks, tail = self._pack_sentences(
                    self._split_sentences(para), chunk_size, overlap
                )
                chunks.extend(sentence_chunks)
                # The unfinished sentence group stays pending as ONE piece (its
                # sentences remain space-joined); following paragraphs may join it.
                current_chunk = [tail] if tail else []
                current_length = len(tail)
                continue

            added = len(para) + (2 if current_chunk else 0)
            if current_chunk and current_length + added > chunk_size:
                # Paragraph would exceed chunk size: start a new chunk with it.
                chunks.append('\n\n'.join(current_chunk))
                current_chunk = []
                current_length = 0
                added = len(para)
            current_chunk.append(para)
            current_length += added

        # Add final chunk
        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))

        return [c.strip() for c in chunks if c.strip()]

    def _pack_sentences(
        self, sentences: List[str], chunk_size: int, overlap: int
    ) -> Tuple[List[str], str]:
        """
        Greedily pack sentences into space-joined groups of <= chunk_size chars.

        When a group is closed, its trailing sentences (at most ``overlap``
        characters) are carried into the next group. The carried sentences are
        dropped if they would leave no room for the next sentence.

        Returns:
            ``(closed_groups, tail)``: the completed groups, and the still-open
            last group joined with spaces ("" when ``sentences`` is empty).
        """
        groups: List[str] = []
        current: List[str] = []
        current_length = 0  # == len(" ".join(current))

        for sentence in sentences:
            if current and current_length + 1 + len(sentence) > chunk_size:
                groups.append(' '.join(current))
                current = self._overlap_suffix(current, overlap)
                current_length = len(' '.join(current))
                if current and current_length + 1 + len(sentence) > chunk_size:
                    current, current_length = [], 0
            current_length += len(sentence) + (1 if current else 0)
            current.append(sentence)

        return groups, ' '.join(current)

    @staticmethod
    def _overlap_suffix(sentences: List[str], overlap: int) -> List[str]:
        """Return the longest suffix of ``sentences`` whose space-joined length is <= overlap."""
        suffix: List[str] = []
        length = 0
        for sentence in reversed(sentences):
            needed = len(sentence) + (1 if suffix else 0)
            if length + needed > overlap:
                break
            suffix.append(sentence)
            length += needed
        suffix.reverse()
        return suffix

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences with basic abbreviation handling.

        Splits after '.', '!' or '?' followed by whitespace. Dots belonging to
        a known abbreviation (e.g. "z.B.", "bzw.", "e.g.") or a decimal number
        ("1.5") are masked first so they do not end a sentence. The original
        characters, including letter case, are restored in the output.

        Returns:
            Non-empty, stripped sentences in original order.
        """
        protected = _ABBREVIATION_RE.sub(
            lambda m: m.group(0).replace(".", _DOT_PLACEHOLDER), text
        )
        protected = _DECIMAL_POINT_RE.sub(_DOT_PLACEHOLDER, protected)

        result = []
        for sentence in _SENTENCE_BOUNDARY_RE.split(protected):
            sentence = sentence.replace(_DOT_PLACEHOLDER, ".").strip()
            if sentence:
                result.append(sentence)
        return result

    def _infer_template_type(self, doc: ExtractedDocument, source: SourceConfig) -> str:
        """Infer the template type from document content and metadata.

        The title is checked before the body: a title is the stronger signal
        (e.g. an "Impressum" page whose body mentions "Datenschutz"). Within
        each text, the first type in ``_TEMPLATE_TYPE_INDICATORS`` order wins.
        Falls back to the source's first configured template type, then to
        "clause".
        """
        for haystack in (doc.title.lower(), doc.text.lower()):
            for template_type, indicators in _TEMPLATE_TYPE_INDICATORS.items():
                if any(indicator in haystack for indicator in indicators):
                    return template_type

        # Fall back to source's first template type
        if source.template_types:
            return source.template_types[0]

        return "clause"  # Generic fallback

    def _infer_clause_category(self, text: str) -> Optional[str]:
        """Infer the clause category from text content (first keyword match wins, else None)."""
        text_lower = text.lower()
        for category, indicators in _CLAUSE_CATEGORY_INDICATORS.items():
            if any(indicator in text_lower for indicator in indicators):
                return category
        return None

    def _create_chunks(
        self,
        doc: ExtractedDocument,
        source: SourceConfig,
    ) -> List[TemplateChunk]:
        """Create template chunks from an extracted document.

        All chunks share the document/source metadata and the license flags
        from ``source.license_info``. They differ only in text, index, clause
        category and the completeness flag.
        """
        license_info = source.license_info
        template_type = self._infer_template_type(doc, source)
        source_url = doc.source_url or source.get_source_url()

        # Document-level properties, identical for every chunk.
        is_modular = len(doc.sections) > 0 or '##' in doc.text
        requires_customization = len(doc.placeholders) > 0
        attribution_text = None
        if license_info.attribution_required:
            attribution_text = license_info.get_attribution_text(source.name, source_url)

        text_chunks = self._chunk_text(doc.text)
        # A single sizeable chunk means the whole document fits in one piece;
        # otherwise the chunks are treated as clauses.
        is_complete = len(text_chunks) == 1 and len(text_chunks[0]) > 500

        return [
            TemplateChunk(
                text=chunk_text,
                chunk_index=i,
                document_title=doc.title,
                template_type=template_type,
                clause_category=self._infer_clause_category(chunk_text),
                language=doc.language,
                jurisdiction=source.jurisdiction,
                license_id=license_info.id.value,
                license_name=license_info.name,
                license_url=license_info.url,
                attribution_required=license_info.attribution_required,
                share_alike=license_info.share_alike,
                no_derivatives=license_info.no_derivatives,
                commercial_use=license_info.commercial_use,
                source_name=source.name,
                source_url=source_url,
                source_repo=source.repo_url,
                source_commit=doc.source_commit,
                source_file=doc.file_path,
                source_hash=doc.source_hash,
                attribution_text=attribution_text,
                copyright_notice=None,  # Could be extracted from doc if present
                is_complete_document=is_complete,
                is_modular=is_modular,
                requires_customization=requires_customization,
                placeholders=doc.placeholders,
                training_allowed=license_info.training_allowed,
                output_allowed=license_info.output_allowed,
                modification_allowed=license_info.modification_allowed,
                distortion_prohibited=license_info.distortion_prohibited,
            )
            for i, chunk_text in enumerate(text_chunks)
        ]

    async def _embed_with_retry(
        self, texts: List[str], batch_start: int, status: IngestionStatus
    ) -> Optional[List[List[float]]]:
        """Embed ``texts``, retrying up to MAX_RETRIES times with linear backoff.

        Returns None (and records an error on ``status``) if every attempt fails.
        """
        for retry in range(MAX_RETRIES):
            try:
                return await self._generate_embeddings(texts)
            except Exception as e:
                logger.warning(
                    f"Embedding attempt {retry+1}/{MAX_RETRIES} failed: {e}"
                )
                if retry < MAX_RETRIES - 1:
                    await asyncio.sleep(RETRY_DELAY * (retry + 1))
                else:
                    status.errors.append(f"Embedding failed for batch {batch_start}: {e}")
        return None

    @staticmethod
    def _build_point(
        source_name: str, chunk: TemplateChunk, embedding: List[float]
    ) -> PointStruct:
        """Build the Qdrant point for one chunk.

        The ID is the MD5 hex digest of "<source>-<file>-<chunk index>", so
        re-ingesting the same chunk position upserts over the previous point.
        NOTE(review): if a document shrinks, its old higher-index chunks are
        not removed by re-ingestion.
        """
        point_id = hashlib.md5(
            f"{source_name}-{chunk.source_file}-{chunk.chunk_index}".encode()
        ).hexdigest()
        # The payload is every TemplateChunk field plus the indexing timestamp.
        payload = asdict(chunk)
        payload["indexed_at"] = datetime.utcnow().isoformat()
        return PointStruct(id=point_id, vector=embedding, payload=payload)

    async def ingest_source(self, source: SourceConfig) -> IngestionStatus:
        """Ingest a single source into Qdrant.

        Crawls the source repository, chunks every document, embeds chunks in
        batches of EMBEDDING_BATCH_SIZE and upserts them. Per-batch embedding
        failures are recorded in ``status.errors`` and skipped. Any other
        exception marks the whole source as "failed". The status is also
        stored for ``get_status``.
        """
        status = IngestionStatus(
            source_name=source.name,
            status="running",
            started_at=datetime.utcnow(),
        )
        self._ingestion_status[source.name] = status

        logger.info(f"Ingesting source: {source.name}")

        try:
            # Crawl the source
            documents: List[ExtractedDocument] = []
            if source.repo_url:
                async with GitHubCrawler() as crawler:
                    async for doc in crawler.crawl_repository(source):
                        documents.append(doc)
                        status.documents_found += 1
            logger.info(f"Found {len(documents)} documents in {source.name}")

            if not documents:
                status.status = "completed"
                status.completed_at = datetime.utcnow()
                return status

            # Create chunks from all documents
            all_chunks: List[TemplateChunk] = []
            for doc in documents:
                chunks = self._create_chunks(doc, source)
                all_chunks.extend(chunks)
                status.chunks_created += len(chunks)
            logger.info(f"Created {len(all_chunks)} chunks from {source.name}")

            # Generate embeddings and index in batches
            for i in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
                batch_chunks = all_chunks[i:i + EMBEDDING_BATCH_SIZE]
                embeddings = await self._embed_with_retry(
                    [c.text for c in batch_chunks], i, status
                )
                if embeddings is None:
                    continue
                if len(embeddings) != len(batch_chunks):
                    # zip() would silently drop the unmatched chunks.
                    message = (
                        f"Embedding service returned {len(embeddings)} vectors "
                        f"for {len(batch_chunks)} texts in batch {i}"
                    )
                    logger.warning(message)
                    status.errors.append(message)
                    continue

                points = [
                    self._build_point(source.name, chunk, embedding)
                    for chunk, embedding in zip(batch_chunks, embeddings)
                ]

                # Upsert to Qdrant
                if points:
                    self.qdrant.upsert(
                        collection_name=LEGAL_TEMPLATES_COLLECTION,
                        points=points,
                    )
                    status.chunks_indexed += len(points)

                # Rate limiting
                await asyncio.sleep(1.0)

            status.status = "completed"
            status.completed_at = datetime.utcnow()
            logger.info(
                f"Completed {source.name}: {status.chunks_indexed} chunks indexed"
            )

        except Exception as e:
            status.status = "failed"
            status.errors.append(str(e))
            status.completed_at = datetime.utcnow()
            logger.error(f"Failed to ingest {source.name}: {e}")

        return status

    async def ingest_all(self, max_priority: int = 5) -> Dict[str, IngestionStatus]:
        """Ingest, one after another, the sources returned by get_sources_by_priority(max_priority)."""
        results = {}
        for source in get_sources_by_priority(max_priority):
            results[source.name] = await self.ingest_source(source)
        return results

    async def ingest_by_license(self, license_type: LicenseType) -> Dict[str, IngestionStatus]:
        """Ingest all sources of a specific license type, one after another."""
        from template_sources import get_sources_by_license

        results = {}
        for source in get_sources_by_license(license_type):
            results[source.name] = await self.ingest_source(source)
        return results

    def get_status(self) -> Dict[str, Any]:
        """Get collection status and ingestion results.

        Counts points per configured source, license type, template type
        (non-zero only) and language ("de", "en"). Any Qdrant error is
        returned as ``{"status": "error", ...}`` instead of being raised.
        """
        try:
            collection_info = self.qdrant.get_collection(LEGAL_TEMPLATES_COLLECTION)
            # points_count may be None (e.g. not yet computed).
            total_points = collection_info.points_count or 0

            # Count points per source
            source_counts = {
                source.name: self._count_where("source_name", source.name)
                for source in TEMPLATE_SOURCES
            }

            # Count by license type
            license_counts = {
                license_type.value: self._count_where("license_id", license_type.value)
                for license_type in LicenseType
            }

            # Count by template type (only types that have points)
            template_type_counts = {}
            for template_type in TEMPLATE_TYPES.keys():
                count = self._count_where("template_type", template_type)
                if count > 0:
                    template_type_counts[template_type] = count

            # Count by language
            language_counts = {
                lang: self._count_where("language", lang) for lang in ["de", "en"]
            }

            return {
                "collection": LEGAL_TEMPLATES_COLLECTION,
                "total_points": total_points,
                "vector_size": VECTOR_SIZE,
                "sources": source_counts,
                "licenses": license_counts,
                "template_types": template_type_counts,
                "languages": language_counts,
                "status": "ready" if total_points > 0 else "empty",
                "ingestion_status": {
                    name: {
                        "status": s.status,
                        "documents_found": s.documents_found,
                        "chunks_indexed": s.chunks_indexed,
                        "errors": s.errors,
                    }
                    for name, s in self._ingestion_status.items()
                },
            }
        except Exception as e:
            return {
                "collection": LEGAL_TEMPLATES_COLLECTION,
                "error": str(e),
                "status": "error",
            }

    async def search(
        self,
        query: str,
        template_type: Optional[str] = None,
        license_types: Optional[List[str]] = None,
        language: Optional[str] = None,
        jurisdiction: Optional[str] = None,
        attribution_required: Optional[bool] = None,
        top_k: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Search the legal templates collection.

        Args:
            query: Search query text
            template_type: Filter by template type (e.g., "privacy_policy")
            license_types: Filter by license types (e.g., ["cc0", "mit"]);
                a hit matches if its license is any of them
            language: Filter by language (e.g., "de")
            jurisdiction: Filter by jurisdiction (e.g., "DE")
            attribution_required: Filter by attribution requirement
            top_k: Number of results to return

        Returns:
            List of search results with full metadata
        """
        # Generate query embedding
        embeddings = await self._generate_embeddings([query])
        query_vector = embeddings[0]

        # Exact-match conditions (AND); empty/None filters are skipped.
        must_conditions = [
            self._match(key, value)
            for key, value in (
                ("template_type", template_type),
                ("language", language),
                ("jurisdiction", jurisdiction),
            )
            if value
        ]
        if attribution_required is not None:
            must_conditions.append(
                self._match("attribution_required", attribution_required)
            )

        # License type filter (OR condition)
        should_conditions = [
            self._match("license_id", license_type)
            for license_type in (license_types or [])
        ]

        search_filter = None
        if must_conditions or should_conditions:
            search_filter = Filter(
                must=must_conditions or None,
                should=should_conditions or None,
            )

        # Execute search
        results = self.qdrant.search(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            query_vector=query_vector,
            query_filter=search_filter,
            limit=top_k,
        )

        keys = (
            "text",
            "document_title",
            "template_type",
            "clause_category",
            "language",
            "jurisdiction",
            "license_id",
            "license_name",
            "attribution_required",
            "attribution_text",
            "source_name",
            "source_url",
            "placeholders",
            "is_complete_document",
            "requires_customization",
            "output_allowed",
            "modification_allowed",
        )
        hits = []
        for hit in results:
            payload = hit.payload or {}
            entry: Dict[str, Any] = {"id": hit.id, "score": hit.score}
            entry.update({key: payload.get(key) for key in keys})
            hits.append(entry)
        return hits

    def delete_source(self, source_name: str) -> int:
        """Delete all chunks from a specific source.

        Returns:
            The number of matching points counted just before deletion.
        """
        # First count how many we're deleting
        count = self._count_where("source_name", source_name)

        # Delete by filter
        self.qdrant.delete(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            points_selector=Filter(must=[self._match("source_name", source_name)]),
        )

        return count

    def reset_collection(self):
        """Delete and recreate the collection, and clear recorded ingestion status."""
        logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}")

        # Delete collection (best effort: it might not exist)
        try:
            self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION)
        except Exception as e:
            logger.warning(
                f"Could not delete collection {LEGAL_TEMPLATES_COLLECTION}: {e}"
            )

        # Recreate
        self._ensure_collection()
        self._ingestion_status.clear()

        logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset")

    async def close(self):
        """Close HTTP client."""
        await self.http_client.aclose()


def _print_ingestion_results(results: Dict[str, IngestionStatus]) -> None:
    """Print one summary line per source, followed by its recorded errors."""
    print("\nResults:")
    for name, status in results.items():
        print(f" {name}: {status.chunks_indexed} chunks ({status.status})")
        for error in status.errors:
            print(f" ERROR: {error}")


async def main():
    """CLI entry point."""
    import argparse

    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
    parser.add_argument(
        "--ingest-all",
        action="store_true",
        help="Ingest all enabled sources"
    )
    parser.add_argument(
        "--ingest-source",
        type=str,
        metavar="NAME",
        help="Ingest a specific source by name"
    )
    parser.add_argument(
        "--ingest-license",
        type=str,
        choices=["cc0", "mit", "cc_by_4", "public_domain"],
        help="Ingest all sources of a specific license type"
    )
    parser.add_argument(
        "--max-priority",
        type=int,
        default=3,
        help="Maximum priority level to ingest (1=highest, 5=lowest)"
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="Show collection status"
    )
    parser.add_argument(
        "--search",
        type=str,
        metavar="QUERY",
        help="Test search query"
    )
    parser.add_argument(
        "--template-type",
        type=str,
        help="Filter search by template type"
    )
    parser.add_argument(
        "--language",
        type=str,
        help="Filter search by language"
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Reset (delete and recreate) the collection"
    )
    parser.add_argument(
        "--delete-source",
        type=str,
        metavar="NAME",
        help="Delete all chunks from a source"
    )

    args = parser.parse_args()

    ingestion = LegalTemplatesIngestion()

    try:
        if args.reset:
            ingestion.reset_collection()
            print("Collection reset successfully")

        elif args.delete_source:
            count = ingestion.delete_source(args.delete_source)
            print(f"Deleted {count} chunks from {args.delete_source}")

        elif args.status:
            status = ingestion.get_status()
            print(json.dumps(status, indent=2, default=str))

        elif args.ingest_all:
            print(f"Ingesting all sources (max priority: {args.max_priority})...")
            results = await ingestion.ingest_all(max_priority=args.max_priority)
            _print_ingestion_results(results)
            total = sum(s.chunks_indexed for s in results.values())
            print(f"\nTotal: {total} chunks indexed")

        elif args.ingest_source:
            source = next(
                (s for s in TEMPLATE_SOURCES if s.name == args.ingest_source),
                None
            )
            if not source:
                print(f"Unknown source: {args.ingest_source}")
                print("Available sources:")
                for s in TEMPLATE_SOURCES:
                    print(f" - {s.name}")
                return

            print(f"Ingesting: {source.name}")
            status = await ingestion.ingest_source(source)
            print(f"\nResult: {status.chunks_indexed} chunks ({status.status})")
            for error in status.errors:
                print(f" ERROR: {error}")

        elif args.ingest_license:
            license_type = LicenseType(args.ingest_license)
            print(f"Ingesting all {license_type.value} sources...")
            results = await ingestion.ingest_by_license(license_type)
            _print_ingestion_results(results)

        elif args.search:
            print(f"Searching: {args.search}")
            results = await ingestion.search(
                args.search,
                template_type=args.template_type,
                language=args.language,
            )
            print(f"\nFound {len(results)} results:")
            for i, result in enumerate(results, 1):
                print(f"\n{i}. [{result['template_type']}] {result['document_title']}")
                print(f" Score: {result['score']:.3f}")
                print(f" License: {result['license_name']}")
                print(f" Source: {result['source_name']}")
                print(f" Language: {result['language']}")
                if result['attribution_required']:
                    print(f" Attribution: {result['attribution_text']}")
                # Payload text may be missing; avoid slicing None.
                text = result['text'] or ""
                print(f" Text: {text[:200]}...")

        else:
            parser.print_help()

    finally:
        await ingestion.close()


if __name__ == "__main__":
    asyncio.run(main())