Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website, Klausur-Service, School-Service, Voice-Service, Geo-Service, BreakPilot Drive, Agent-Core Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
943 lines
34 KiB
Python
943 lines
34 KiB
Python
"""
|
|
Legal Templates Ingestion Pipeline for RAG.
|
|
|
|
Indexes legal template documents from various open-source repositories
|
|
into Qdrant for semantic search. Supports multiple license types with
|
|
proper attribution tracking.
|
|
|
|
Collection: bp_legal_templates
|
|
|
|
Usage:
|
|
python legal_templates_ingestion.py --ingest-all
|
|
python legal_templates_ingestion.py --ingest-source github-site-policy
|
|
python legal_templates_ingestion.py --status
|
|
python legal_templates_ingestion.py --search "Datenschutzerklaerung"
|
|
"""
|
|
|
|
import asyncio
import hashlib
import json
import logging
import os
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import httpx
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from template_sources import (
    LICENSES,
    TEMPLATE_SOURCES,
    TEMPLATE_TYPES,
    LicenseType,
    SourceConfig,
    get_enabled_sources,
    get_sources_by_priority,
)
from github_crawler import (
    ExtractedDocument,
    GitHubCrawler,
    RepositoryDownloader,
)
|
|
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Qdrant connection settings.  A full QDRANT_URL takes precedence; otherwise
# fall back to the discrete QDRANT_HOST / QDRANT_PORT variables.
_qdrant_url_env = os.getenv("QDRANT_URL", "")
if _qdrant_url_env:
    _qdrant_parts = urlparse(_qdrant_url_env)
    QDRANT_HOST = _qdrant_parts.hostname or "localhost"
    QDRANT_PORT = _qdrant_parts.port or 6333
else:
    QDRANT_HOST = os.getenv("QDRANT_HOST", "localhost")
    QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))

# Embedding service endpoint and target collection.
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://localhost:8087")
LEGAL_TEMPLATES_COLLECTION = "bp_legal_templates"
VECTOR_SIZE = 1024  # BGE-M3 dimension

# Chunking configuration (overridable via environment).
CHUNK_SIZE = int(os.getenv("TEMPLATE_CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.getenv("TEMPLATE_CHUNK_OVERLAP", "200"))

# Batch processing / retry knobs for the embedding service.
EMBEDDING_BATCH_SIZE = 4
MAX_RETRIES = 3
RETRY_DELAY = 3.0
|
|
|
|
|
|
@dataclass
class IngestionStatus:
    """Status of a source ingestion."""

    source_name: str
    status: str  # "pending", "running", "completed", "failed"
    documents_found: int = 0   # documents discovered by the crawler
    chunks_created: int = 0    # chunks produced from those documents
    chunks_indexed: int = 0    # chunks actually upserted into Qdrant
    errors: List[str] = field(default_factory=list)  # human-readable error messages
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
|
|
|
|
|
|
@dataclass
class TemplateChunk:
    """A chunk of template text ready for indexing."""

    # Chunk content and its position within the source document.
    text: str
    chunk_index: int
    # Document-level metadata.
    document_title: str
    template_type: str
    clause_category: Optional[str]
    language: str
    jurisdiction: str
    # License metadata, denormalized onto every chunk for query-time filtering.
    license_id: str
    license_name: str
    license_url: str
    attribution_required: bool
    share_alike: bool
    no_derivatives: bool
    commercial_use: bool
    # Provenance: where the chunk came from.
    source_name: str
    source_url: str
    source_repo: Optional[str]
    source_commit: Optional[str]
    source_file: str
    source_hash: str
    # Attribution strings for display alongside search results.
    attribution_text: Optional[str]
    copyright_notice: Optional[str]
    # Structural hints for downstream consumers.
    is_complete_document: bool
    is_modular: bool
    requires_customization: bool
    placeholders: List[str]
    # Usage permissions derived from the license.
    training_allowed: bool
    output_allowed: bool
    modification_allowed: bool
    distortion_prohibited: bool
|
|
|
|
|
class LegalTemplatesIngestion:
    """Ingests legal template documents into the Qdrant vector store.

    Responsibilities:
      * ensure the ``bp_legal_templates`` collection exists,
      * crawl configured sources (currently GitHub repositories),
      * chunk extracted documents and enrich them with license metadata,
      * embed chunks via the external embedding service and upsert them,
      * expose status, search, and maintenance (delete/reset) operations.
    """

    def __init__(self):
        self.qdrant = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
        self.http_client = httpx.AsyncClient(timeout=120.0)
        self._ensure_collection()
        # Per-source ingestion results from this process, keyed by source name.
        self._ingestion_status: Dict[str, IngestionStatus] = {}

    def _ensure_collection(self):
        """Create the legal templates collection if it doesn't exist."""
        collections = self.qdrant.get_collections().collections
        collection_names = [c.name for c in collections]

        if LEGAL_TEMPLATES_COLLECTION not in collection_names:
            logger.info(f"Creating collection: {LEGAL_TEMPLATES_COLLECTION}")
            self.qdrant.create_collection(
                collection_name=LEGAL_TEMPLATES_COLLECTION,
                vectors_config=VectorParams(
                    size=VECTOR_SIZE,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} created")
        else:
            logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} already exists")

    async def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for *texts* via the embedding service.

        Raises:
            Exception: any HTTP/transport error is logged and re-raised;
                the caller (``ingest_source``) implements retry logic.
        """
        try:
            response = await self.http_client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts},
                timeout=120.0,
            )
            response.raise_for_status()
            data = response.json()
            return data["embeddings"]
        except Exception as e:
            logger.error(f"Embedding generation failed: {e}")
            raise

    def _chunk_text(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
        """
        Split text into overlapping chunks.

        Respects paragraph and sentence boundaries where possible.

        Fix: ``overlap`` was previously accepted but ignored (a hard-coded
        one-third-of-sentences carry-over was used instead).  It now bounds
        the number of trailing characters carried into the next chunk when an
        oversized paragraph is split by sentence.
        """
        if not text:
            return []

        if len(text) <= chunk_size:
            return [text.strip()]

        # Split into paragraphs first; they are the preferred boundary.
        paragraphs = text.split('\n\n')
        chunks: List[str] = []
        current_chunk: List[str] = []
        current_length = 0

        for para in paragraphs:
            para = para.strip()
            if not para:
                continue

            para_length = len(para)

            if para_length > chunk_size:
                # Oversized paragraph: flush what we have, then split it by
                # sentences.
                if current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                    current_chunk = []
                    current_length = 0

                for sentence in self._split_sentences(para):
                    if current_length + len(sentence) + 1 > chunk_size:
                        if current_chunk:
                            chunks.append(' '.join(current_chunk))
                            # Carry trailing sentences (at least one, up to
                            # `overlap` characters) into the next chunk so
                            # context is preserved across the boundary.
                            kept: List[str] = []
                            kept_len = 0
                            for s in reversed(current_chunk):
                                if kept and kept_len + len(s) + 1 > overlap:
                                    break
                                kept.insert(0, s)
                                kept_len += len(s) + 1
                            current_chunk = kept
                            current_length = kept_len
                    current_chunk.append(sentence)
                    current_length += len(sentence) + 1

            elif current_length + para_length + 2 > chunk_size:
                # Paragraph would exceed the chunk size: start a new chunk.
                if current_chunk:
                    chunks.append('\n\n'.join(current_chunk))
                current_chunk = [para]
                current_length = para_length

            else:
                current_chunk.append(para)
                current_length += para_length + 2

        # Flush the final chunk.
        if current_chunk:
            chunks.append('\n\n'.join(current_chunk))

        return [c.strip() for c in chunks if c.strip()]

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences with basic abbreviation handling."""
        import re

        # Protect common (German/English) abbreviations from being treated
        # as sentence boundaries.
        abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no']
        protected = text
        for abbr in abbreviations:
            pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE)
            # Fix: use a callable replacement so the matched text keeps its
            # original casing.  The previous literal replacement lowercased
            # matches (e.g. "No." became "no.") because of IGNORECASE.
            protected = pattern.sub(
                lambda m: m.group(0)[:-1].replace('.', '<DOT>') + '<ABBR>',
                protected,
            )

        # Protect decimal numbers (e.g. "3.5") from splitting.
        protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)

        # Split on sentence-ending punctuation followed by whitespace.
        sentences = re.split(r'(?<=[.!?])\s+', protected)

        # Restore the protected characters.
        result = []
        for s in sentences:
            s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.')
            s = s.strip()
            if s:
                result.append(s)

        return result

    def _infer_template_type(self, doc: ExtractedDocument, source: SourceConfig) -> str:
        """Infer the template type from document content and metadata.

        Checks keyword indicators against the document text and title; falls
        back to the source's first declared template type, then to "clause".
        """
        text_lower = doc.text.lower()
        title_lower = doc.title.lower()

        # Keyword indicators per template type; first match wins, in dict order.
        type_indicators = {
            "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
            "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
            "cookie_banner": ["cookie", "cookies", "tracking"],
            "impressum": ["impressum", "legal notice", "imprint"],
            "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
            "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
            "sla": ["service level", "availability", "uptime"],
            "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
            "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
            "acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
        }

        for template_type, indicators in type_indicators.items():
            if any(ind in text_lower or ind in title_lower for ind in indicators):
                return template_type

        # Fall back to the source's first declared template type.
        if source.template_types:
            return source.template_types[0]

        return "clause"  # Generic fallback

    def _infer_clause_category(self, text: str) -> Optional[str]:
        """Infer the clause category from text content, or None if unknown."""
        text_lower = text.lower()

        # Keyword indicators per clause category; first match wins.
        categories = {
            "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
            "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
            "widerruf": ["widerruf", "cancellation", "withdrawal"],
            "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
            "kuendigung": ["kuendigung", "termination", "beendigung"],
            "zahlung": ["zahlung", "payment", "preis", "price"],
            "gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
            "aenderungen": ["aenderung", "modification", "amendment"],
            "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
        }

        for category, indicators in categories.items():
            if any(ind in text_lower for ind in indicators):
                return category

        return None

    def _create_chunks(
        self,
        doc: ExtractedDocument,
        source: SourceConfig,
    ) -> List[TemplateChunk]:
        """Create fully-annotated template chunks from an extracted document."""
        license_info = source.license_info
        template_type = self._infer_template_type(doc, source)

        # Chunk the document text.
        text_chunks = self._chunk_text(doc.text)

        chunks = []
        for i, chunk_text in enumerate(text_chunks):
            # Heuristics: a single sizeable chunk is treated as a complete
            # document; markdown headings or extracted sections imply a
            # modular template; placeholders imply customization is needed.
            is_complete = len(text_chunks) == 1 and len(chunk_text) > 500
            is_modular = len(doc.sections) > 0 or '##' in doc.text
            requires_customization = len(doc.placeholders) > 0

            # Generate attribution text when the license demands it.
            attribution_text = None
            if license_info.attribution_required:
                attribution_text = license_info.get_attribution_text(
                    source.name,
                    doc.source_url or source.get_source_url()
                )

            chunk = TemplateChunk(
                text=chunk_text,
                chunk_index=i,
                document_title=doc.title,
                template_type=template_type,
                clause_category=self._infer_clause_category(chunk_text),
                language=doc.language,
                jurisdiction=source.jurisdiction,
                license_id=license_info.id.value,
                license_name=license_info.name,
                license_url=license_info.url,
                attribution_required=license_info.attribution_required,
                share_alike=license_info.share_alike,
                no_derivatives=license_info.no_derivatives,
                commercial_use=license_info.commercial_use,
                source_name=source.name,
                source_url=doc.source_url or source.get_source_url(),
                source_repo=source.repo_url,
                source_commit=doc.source_commit,
                source_file=doc.file_path,
                source_hash=doc.source_hash,
                attribution_text=attribution_text,
                copyright_notice=None,  # Could be extracted from doc if present
                is_complete_document=is_complete,
                is_modular=is_modular,
                requires_customization=requires_customization,
                placeholders=doc.placeholders,
                training_allowed=license_info.training_allowed,
                output_allowed=license_info.output_allowed,
                modification_allowed=license_info.modification_allowed,
                distortion_prohibited=license_info.distortion_prohibited,
            )
            chunks.append(chunk)

        return chunks

    async def ingest_source(self, source: SourceConfig) -> IngestionStatus:
        """Ingest a single source: crawl, chunk, embed, and index into Qdrant.

        Never raises; failures are recorded on the returned IngestionStatus.
        """
        status = IngestionStatus(
            source_name=source.name,
            status="running",
            started_at=datetime.now(timezone.utc),
        )
        self._ingestion_status[source.name] = status

        logger.info(f"Ingesting source: {source.name}")

        try:
            # Crawl the source (only GitHub repositories are supported here).
            documents: List[ExtractedDocument] = []
            if source.repo_url:
                async with GitHubCrawler() as crawler:
                    async for doc in crawler.crawl_repository(source):
                        documents.append(doc)
                        status.documents_found += 1

            logger.info(f"Found {len(documents)} documents in {source.name}")

            if not documents:
                status.status = "completed"
                status.completed_at = datetime.now(timezone.utc)
                return status

            # Create chunks from all documents.
            all_chunks: List[TemplateChunk] = []
            for doc in documents:
                chunks = self._create_chunks(doc, source)
                all_chunks.extend(chunks)
                status.chunks_created += len(chunks)

            logger.info(f"Created {len(all_chunks)} chunks from {source.name}")

            # Generate embeddings and index in small batches.
            for i in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
                batch_chunks = all_chunks[i:i + EMBEDDING_BATCH_SIZE]
                chunk_texts = [c.text for c in batch_chunks]

                # Retry transient embedding failures with linear backoff.
                embeddings = None
                for retry in range(MAX_RETRIES):
                    try:
                        embeddings = await self._generate_embeddings(chunk_texts)
                        break
                    except Exception as e:
                        logger.warning(
                            f"Embedding attempt {retry+1}/{MAX_RETRIES} failed: {e}"
                        )
                        if retry < MAX_RETRIES - 1:
                            await asyncio.sleep(RETRY_DELAY * (retry + 1))
                        else:
                            status.errors.append(f"Embedding failed for batch {i}: {e}")

                if embeddings is None:
                    # Batch permanently failed; skip it but keep ingesting.
                    continue

                # Create points for Qdrant.
                points = []
                for chunk, embedding in zip(batch_chunks, embeddings):
                    # Deterministic ID so re-ingestion upserts instead of
                    # duplicating.  Fix: Qdrant only accepts unsigned ints or
                    # UUID strings as point IDs, so derive a UUIDv5 from the
                    # chunk key (the previous raw md5 hexdigest is rejected
                    # by Qdrant as an invalid point ID).
                    point_id = str(uuid.uuid5(
                        uuid.NAMESPACE_URL,
                        f"{source.name}-{chunk.source_file}-{chunk.chunk_index}",
                    ))

                    payload = {
                        "text": chunk.text,
                        "chunk_index": chunk.chunk_index,
                        "document_title": chunk.document_title,
                        "template_type": chunk.template_type,
                        "clause_category": chunk.clause_category,
                        "language": chunk.language,
                        "jurisdiction": chunk.jurisdiction,
                        "license_id": chunk.license_id,
                        "license_name": chunk.license_name,
                        "license_url": chunk.license_url,
                        "attribution_required": chunk.attribution_required,
                        "share_alike": chunk.share_alike,
                        "no_derivatives": chunk.no_derivatives,
                        "commercial_use": chunk.commercial_use,
                        "source_name": chunk.source_name,
                        "source_url": chunk.source_url,
                        "source_repo": chunk.source_repo,
                        "source_commit": chunk.source_commit,
                        "source_file": chunk.source_file,
                        "source_hash": chunk.source_hash,
                        "attribution_text": chunk.attribution_text,
                        "copyright_notice": chunk.copyright_notice,
                        "is_complete_document": chunk.is_complete_document,
                        "is_modular": chunk.is_modular,
                        "requires_customization": chunk.requires_customization,
                        "placeholders": chunk.placeholders,
                        "training_allowed": chunk.training_allowed,
                        "output_allowed": chunk.output_allowed,
                        "modification_allowed": chunk.modification_allowed,
                        "distortion_prohibited": chunk.distortion_prohibited,
                        "indexed_at": datetime.now(timezone.utc).isoformat(),
                    }

                    points.append(PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload,
                    ))

                # Upsert to Qdrant.
                if points:
                    self.qdrant.upsert(
                        collection_name=LEGAL_TEMPLATES_COLLECTION,
                        points=points,
                    )
                    status.chunks_indexed += len(points)

                # Crude rate limiting between batches.
                await asyncio.sleep(1.0)

            status.status = "completed"
            status.completed_at = datetime.now(timezone.utc)
            logger.info(
                f"Completed {source.name}: {status.chunks_indexed} chunks indexed"
            )

        except Exception as e:
            status.status = "failed"
            status.errors.append(str(e))
            status.completed_at = datetime.now(timezone.utc)
            logger.error(f"Failed to ingest {source.name}: {e}")

        return status

    async def ingest_all(self, max_priority: int = 5) -> Dict[str, IngestionStatus]:
        """Ingest all enabled sources up to a priority level, sequentially."""
        sources = get_sources_by_priority(max_priority)
        results = {}

        for source in sources:
            result = await self.ingest_source(source)
            results[source.name] = result

        return results

    async def ingest_by_license(self, license_type: LicenseType) -> Dict[str, IngestionStatus]:
        """Ingest all sources of a specific license type, sequentially."""
        from template_sources import get_sources_by_license

        sources = get_sources_by_license(license_type)
        results = {}

        for source in sources:
            result = await self.ingest_source(source)
            results[source.name] = result

        return results

    def _count_points(self, key: str, value: Any) -> int:
        """Count points in the collection whose payload *key* equals *value*."""
        result = self.qdrant.count(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            count_filter=Filter(
                must=[
                    FieldCondition(
                        key=key,
                        match=MatchValue(value=value),
                    )
                ]
            ),
        )
        return result.count

    def get_status(self) -> Dict[str, Any]:
        """Get collection status, per-facet counts, and ingestion results.

        Returns a JSON-serializable dict; on any Qdrant error, returns a
        dict with ``"status": "error"`` instead of raising.
        """
        try:
            collection_info = self.qdrant.get_collection(LEGAL_TEMPLATES_COLLECTION)

            # Count points per source.
            source_counts = {
                source.name: self._count_points("source_name", source.name)
                for source in TEMPLATE_SOURCES
            }

            # Count by license type.
            license_counts = {
                license_type.value: self._count_points("license_id", license_type.value)
                for license_type in LicenseType
            }

            # Count by template type; only report types with indexed chunks.
            template_type_counts = {}
            for template_type in TEMPLATE_TYPES.keys():
                count = self._count_points("template_type", template_type)
                if count > 0:
                    template_type_counts[template_type] = count

            # Count by language.
            language_counts = {
                lang: self._count_points("language", lang)
                for lang in ["de", "en"]
            }

            return {
                "collection": LEGAL_TEMPLATES_COLLECTION,
                "total_points": collection_info.points_count,
                "vector_size": VECTOR_SIZE,
                "sources": source_counts,
                "licenses": license_counts,
                "template_types": template_type_counts,
                "languages": language_counts,
                "status": "ready" if collection_info.points_count > 0 else "empty",
                "ingestion_status": {
                    name: {
                        "status": s.status,
                        "documents_found": s.documents_found,
                        "chunks_indexed": s.chunks_indexed,
                        "errors": s.errors,
                    }
                    for name, s in self._ingestion_status.items()
                },
            }

        except Exception as e:
            return {
                "collection": LEGAL_TEMPLATES_COLLECTION,
                "error": str(e),
                "status": "error",
            }

    async def search(
        self,
        query: str,
        template_type: Optional[str] = None,
        license_types: Optional[List[str]] = None,
        language: Optional[str] = None,
        jurisdiction: Optional[str] = None,
        attribution_required: Optional[bool] = None,
        top_k: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Search the legal templates collection.

        Args:
            query: Search query text
            template_type: Filter by template type (e.g., "privacy_policy")
            license_types: Filter by license types (e.g., ["cc0", "mit"]);
                applied as an OR ("should") condition
            language: Filter by language (e.g., "de")
            jurisdiction: Filter by jurisdiction (e.g., "DE")
            attribution_required: Filter by attribution requirement
            top_k: Number of results to return

        Returns:
            List of search results with full metadata
        """
        # Embed the query with the same model used at index time.
        embeddings = await self._generate_embeddings([query])
        query_vector = embeddings[0]

        # Build equality ("must") conditions for each provided filter.
        must_conditions = []
        for key, value in (
            ("template_type", template_type),
            ("language", language),
            ("jurisdiction", jurisdiction),
        ):
            if value:
                must_conditions.append(
                    FieldCondition(key=key, match=MatchValue(value=value))
                )

        # attribution_required is a bool: False is a valid filter value.
        if attribution_required is not None:
            must_conditions.append(
                FieldCondition(
                    key="attribution_required",
                    match=MatchValue(value=attribution_required),
                )
            )

        # License type filter (OR condition).
        should_conditions = [
            FieldCondition(key="license_id", match=MatchValue(value=lt))
            for lt in (license_types or [])
        ]

        # Construct the combined filter, if any condition was given.
        search_filter = None
        if must_conditions or should_conditions:
            filter_kwargs: Dict[str, Any] = {}
            if must_conditions:
                filter_kwargs["must"] = must_conditions
            if should_conditions:
                filter_kwargs["should"] = should_conditions
            search_filter = Filter(**filter_kwargs)

        # Execute the vector search.
        results = self.qdrant.search(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            query_vector=query_vector,
            query_filter=search_filter,
            limit=top_k,
        )

        # Project payloads into plain dicts for CLI/API consumers.
        payload_keys = [
            "text", "document_title", "template_type", "clause_category",
            "language", "jurisdiction", "license_id", "license_name",
            "attribution_required", "attribution_text", "source_name",
            "source_url", "placeholders", "is_complete_document",
            "requires_customization", "output_allowed", "modification_allowed",
        ]
        return [
            {
                "id": hit.id,
                "score": hit.score,
                **{k: hit.payload.get(k) for k in payload_keys},
            }
            for hit in results
        ]

    def delete_source(self, source_name: str) -> int:
        """Delete all chunks from a specific source.

        Returns:
            The number of chunks that matched (counted before deletion).
        """
        # Count first so we can report how many points are being removed.
        count = self._count_points("source_name", source_name)

        # Delete by filter.
        self.qdrant.delete(
            collection_name=LEGAL_TEMPLATES_COLLECTION,
            points_selector=Filter(
                must=[
                    FieldCondition(
                        key="source_name",
                        match=MatchValue(value=source_name),
                    )
                ]
            ),
        )

        return count

    def reset_collection(self):
        """Delete and recreate the collection, clearing local status."""
        logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}")

        try:
            self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION)
        except Exception:
            pass  # Collection might not exist

        # Recreate empty collection and forget in-process ingestion results.
        self._ensure_collection()
        self._ingestion_status.clear()

        logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset")

    async def close(self):
        """Close the HTTP client used for embedding requests."""
        await self.http_client.aclose()
|
|
|
|
|
|
async def main():
    """Command-line entry point.

    Parses CLI flags, executes exactly one action (reset, delete, status,
    ingest, or search), and always closes the ingestion client afterwards.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
    parser.add_argument(
        "--ingest-all", action="store_true",
        help="Ingest all enabled sources",
    )
    parser.add_argument(
        "--ingest-source", type=str, metavar="NAME",
        help="Ingest a specific source by name",
    )
    parser.add_argument(
        "--ingest-license", type=str,
        choices=["cc0", "mit", "cc_by_4", "public_domain"],
        help="Ingest all sources of a specific license type",
    )
    parser.add_argument(
        "--max-priority", type=int, default=3,
        help="Maximum priority level to ingest (1=highest, 5=lowest)",
    )
    parser.add_argument("--status", action="store_true", help="Show collection status")
    parser.add_argument("--search", type=str, metavar="QUERY", help="Test search query")
    parser.add_argument("--template-type", type=str, help="Filter search by template type")
    parser.add_argument("--language", type=str, help="Filter search by language")
    parser.add_argument(
        "--reset", action="store_true",
        help="Reset (delete and recreate) the collection",
    )
    parser.add_argument(
        "--delete-source", type=str, metavar="NAME",
        help="Delete all chunks from a source",
    )

    cli = parser.parse_args()
    pipeline = LegalTemplatesIngestion()

    try:
        if cli.reset:
            pipeline.reset_collection()
            print("Collection reset successfully")

        elif cli.delete_source:
            removed = pipeline.delete_source(cli.delete_source)
            print(f"Deleted {removed} chunks from {cli.delete_source}")

        elif cli.status:
            print(json.dumps(pipeline.get_status(), indent=2, default=str))

        elif cli.ingest_all:
            print(f"Ingesting all sources (max priority: {cli.max_priority})...")
            outcomes = await pipeline.ingest_all(max_priority=cli.max_priority)
            print("\nResults:")
            for name, outcome in outcomes.items():
                print(f" {name}: {outcome.chunks_indexed} chunks ({outcome.status})")
                for error in outcome.errors:
                    print(f" ERROR: {error}")
            grand_total = sum(o.chunks_indexed for o in outcomes.values())
            print(f"\nTotal: {grand_total} chunks indexed")

        elif cli.ingest_source:
            # Resolve the named source; list the valid names on a miss.
            matches = [s for s in TEMPLATE_SOURCES if s.name == cli.ingest_source]
            if not matches:
                print(f"Unknown source: {cli.ingest_source}")
                print("Available sources:")
                for s in TEMPLATE_SOURCES:
                    print(f" - {s.name}")
                return

            source = matches[0]
            print(f"Ingesting: {source.name}")
            outcome = await pipeline.ingest_source(source)
            print(f"\nResult: {outcome.chunks_indexed} chunks ({outcome.status})")
            for error in outcome.errors:
                print(f" ERROR: {error}")

        elif cli.ingest_license:
            license_type = LicenseType(cli.ingest_license)
            print(f"Ingesting all {license_type.value} sources...")
            outcomes = await pipeline.ingest_by_license(license_type)
            print("\nResults:")
            for name, outcome in outcomes.items():
                print(f" {name}: {outcome.chunks_indexed} chunks ({outcome.status})")

        elif cli.search:
            print(f"Searching: {cli.search}")
            hits = await pipeline.search(
                cli.search,
                template_type=cli.template_type,
                language=cli.language,
            )
            print(f"\nFound {len(hits)} results:")
            for rank, hit in enumerate(hits, 1):
                print(f"\n{rank}. [{hit['template_type']}] {hit['document_title']}")
                print(f" Score: {hit['score']:.3f}")
                print(f" License: {hit['license_name']}")
                print(f" Source: {hit['source_name']}")
                print(f" Language: {hit['language']}")
                if hit['attribution_required']:
                    print(f" Attribution: {hit['attribution_text']}")
                print(f" Text: {hit['text'][:200]}...")

        else:
            parser.print_help()

    finally:
        await pipeline.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
asyncio.run(main())
|