Replace plain recursive chunker with legal-aware chunking that: - Detects legal section headers (§, Art., Section, Chapter, Annex) - Adds section context prefix to every chunk - Splits on paragraph boundaries then sentence boundaries - Protects DE + EN abbreviations (80+ patterns) from false splits - Supports language detection for locale-specific processing - Force-splits overlong sentences at word boundaries The old plain_recursive API option is removed — all non-semantic strategies now route through chunk_text_legal(). Includes 40 tests covering header detection, abbreviation protection, sentence splitting, and legal chunking behavior. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
932 lines
32 KiB
Python
932 lines
32 KiB
Python
"""
|
|
Embedding Service - FastAPI Application
|
|
|
|
Provides REST endpoints for:
|
|
- Embedding generation (local sentence-transformers or OpenAI)
|
|
- Re-ranking (local CrossEncoder or Cohere)
|
|
- PDF text extraction (Unstructured or pypdf)
|
|
- Text chunking (sentence-based semantic, or legal-document-aware recursive)
|
|
|
|
This service handles all ML-heavy operations, keeping the main klausur-service lightweight.
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
from typing import List, Optional
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field
|
|
|
|
import config
|
|
|
|
# Configure process-wide logging at import time. config.LOG_LEVEL is upper-cased
# and looked up as an attribute of the logging module; an unknown name falls
# back to INFO.
_log_level = getattr(logging, config.LOG_LEVEL.upper(), logging.INFO)
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("embedding-service")
|
|
|
|
# =============================================================================
|
|
# Lazy-loaded models
|
|
# =============================================================================
|
|
|
|
# Process-wide model singletons. They start empty and are filled on first use by
# get_embedding_model() / get_reranker_model() (also invoked by startup preloading).
_embedding_model, _reranker_model = None, None
|
|
|
|
|
|
def get_embedding_model():
    """Return the shared SentenceTransformer, constructing it on first call.

    The model name comes from config.LOCAL_EMBEDDING_MODEL; the instance is
    cached in the module global ``_embedding_model`` for later calls.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model

    from sentence_transformers import SentenceTransformer

    model_name = config.LOCAL_EMBEDDING_MODEL
    logger.info(f"Loading embedding model: {model_name}")
    _embedding_model = SentenceTransformer(model_name)
    logger.info(f"Model loaded (dim={_embedding_model.get_sentence_embedding_dimension()})")
    return _embedding_model
|
|
|
|
|
|
def get_reranker_model():
    """Return the shared CrossEncoder reranker, constructing it on first call.

    The model name comes from config.LOCAL_RERANKER_MODEL; the instance is
    cached in the module global ``_reranker_model`` for later calls.
    """
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model

    from sentence_transformers import CrossEncoder

    model_name = config.LOCAL_RERANKER_MODEL
    logger.info(f"Loading reranker model: {model_name}")
    _reranker_model = CrossEncoder(model_name)
    logger.info("Reranker loaded")
    return _reranker_model
|
|
|
|
|
|
# =============================================================================
|
|
# Request/Response Models
|
|
# =============================================================================
|
|
|
|
class EmbedRequest(BaseModel):
    """Request body for POST /embed."""

    texts: List[str] = Field(..., description="List of texts to embed")
|
|
|
|
|
|
class EmbedResponse(BaseModel):
    """Response of POST /embed: one embedding vector per input text."""

    embeddings: List[List[float]]
    # Name of the configured embedding model (local or OpenAI).
    model: str
    # Vector size as reported by config.get_current_dimensions().
    dimensions: int
|
|
|
|
|
|
class EmbedSingleRequest(BaseModel):
    """Request body for POST /embed-single."""

    text: str = Field(..., description="Single text to embed")
|
|
|
|
|
|
class EmbedSingleResponse(BaseModel):
    """Response of POST /embed-single."""

    # Embedding of the single input text ([] if the backend returned nothing).
    embedding: List[float]
    # Name of the configured embedding model (local or OpenAI).
    model: str
    # Vector size as reported by config.get_current_dimensions().
    dimensions: int
|
|
|
|
|
|
class RerankRequest(BaseModel):
    """Request body for POST /rerank."""

    query: str = Field(..., description="Search query")
    documents: List[str] = Field(..., description="Documents to re-rank")
    # ge=1: the local backend slices results[:top_k], so zero would return
    # nothing and a negative value would silently drop the *last* results.
    top_k: int = Field(default=5, ge=1, description="Number of top results to return")
|
|
|
|
|
|
class RerankResult(BaseModel):
    """One re-ranked document."""

    # Position of the document in the request's `documents` list.
    index: int
    # Relevance score; higher means more relevant (the local backend sorts by it, descending).
    score: float
    # The document text itself.
    text: str
|
|
|
|
|
|
class RerankResponse(BaseModel):
    """Response of POST /rerank."""

    # At most top_k results (local backend: sorted by descending score;
    # Cohere backend: in the order returned by the API).
    results: List[RerankResult]
    # Name of the reranker model that produced the scores.
    model: str
|
|
|
|
|
|
class ChunkRequest(BaseModel):
    """Request body for POST /chunk."""

    text: str = Field(..., description="Text to chunk")
    # gt=0: a non-positive size cannot produce valid chunks; reject it with a
    # validation error instead of handing it to the chunkers.
    chunk_size: int = Field(default=1000, gt=0, description="Target chunk size")
    overlap: int = Field(default=200, description="Overlap between chunks")
    strategy: str = Field(
        default="semantic",
        description=(
            "Chunking strategy: semantic, or recursive (legal-aware; "
            "any value other than semantic uses the legal-aware chunker)"
        ),
    )
|
|
|
|
|
|
class ChunkResponse(BaseModel):
    """Response of POST /chunk."""

    chunks: List[str]
    # Always len(chunks).
    count: int
    # The strategy string from the request, echoed back unchanged.
    strategy: str
|
|
|
|
|
|
class ExtractPDFResponse(BaseModel):
    """Result of PDF text extraction."""

    # Extracted text (Unstructured output carries "## " titles and [TABELLE] markers).
    text: str
    # "unstructured" or "pypdf".
    backend_used: str
    # Number of pages seen in the document.
    pages: int
    # Number of tables found (always 0 for the pypdf backend).
    table_count: int
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response of GET /health."""

    status: str
    # Name of the configured embedding model (local or OpenAI).
    embedding_model: str
    embedding_dimensions: int
    # Local CrossEncoder name, or "cohere-rerank" for the Cohere backend.
    reranker_model: str
    # PDF backends importable in this process.
    pdf_backends: List[str]
|
|
|
|
|
|
class ModelsResponse(BaseModel):
    """Response of GET /models: configured backends and models."""

    # config.EMBEDDING_BACKEND (e.g. "local").
    embedding_backend: str
    embedding_model: str
    embedding_dimensions: int
    # config.RERANKER_BACKEND (e.g. "local").
    reranker_backend: str
    reranker_model: str
    # Configured PDF backend (config.PDF_EXTRACTION_BACKEND, may be "auto").
    pdf_backend: str
    # PDF backends importable in this process.
    available_pdf_backends: List[str]
|
|
|
|
|
|
# =============================================================================
|
|
# Embedding Functions
|
|
# =============================================================================
|
|
|
|
def generate_local_embeddings(texts: List[str]) -> List[List[float]]:
    """Encode *texts* with the local sentence-transformers model.

    Returns one list of floats per input text. An empty input short-circuits
    to [] without loading the model; a progress bar is shown for batches of
    more than 10 texts.
    """
    if not texts:
        return []
    show_progress = len(texts) > 10
    vectors = get_embedding_model().encode(texts, show_progress_bar=show_progress)
    return [vector.tolist() for vector in vectors]
|
|
|
|
|
|
async def generate_openai_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed *texts* via the OpenAI embeddings REST API.

    Returns one embedding per input text, ordered like *texts*. An empty input
    returns [] without contacting the API (mirrors generate_local_embeddings).

    Raises:
        HTTPException(500): OPENAI_API_KEY is not configured.
        HTTPException(502): the API could not be reached (connection error, timeout).
        HTTPException(<upstream status>): the API answered with a non-200 status.
    """
    if not texts:
        return []

    import httpx

    if not config.OPENAI_API_KEY:
        raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.openai.com/v1/embeddings",
                headers={
                    "Authorization": f"Bearer {config.OPENAI_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": config.OPENAI_EMBEDDING_MODEL,
                    "input": texts
                },
                timeout=60.0
            )
    except httpx.RequestError as e:
        # Transport-level failure (DNS, connect, timeout): a gateway error, not ours.
        raise HTTPException(status_code=502, detail=f"OpenAI API unreachable: {e}") from e

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"OpenAI API error: {response.text}"
        )

    data = response.json()
    # Each embedding object carries the position of its input in "index";
    # order by it instead of relying on the list order of the response.
    items = sorted(data["data"], key=lambda item: item["index"])
    return [item["embedding"] for item in items]
|
|
|
|
|
|
# =============================================================================
|
|
# Re-ranking Functions
|
|
# =============================================================================
|
|
|
|
def rerank_local(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Score (query, document) pairs with the local CrossEncoder and keep the best.

    Returns at most top_k RerankResult objects ordered by descending score;
    ``index`` is the document's position in *documents*. Documents with equal
    scores keep their input order. An empty document list returns [] without
    loading the model.
    """
    if not documents:
        return []

    scores = get_reranker_model().predict([(query, doc) for doc in documents])
    scored = [
        RerankResult(index=position, score=float(value), text=doc)
        for position, (value, doc) in enumerate(zip(scores, documents))
    ]
    return sorted(scored, key=lambda item: item.score, reverse=True)[:top_k]
|
|
|
|
|
|
async def rerank_cohere(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Re-rank *documents* against *query* with Cohere's rerank API (v2).

    Returns the results in the order the API reports them; ``index`` is the
    document's position in *documents* and ``text`` is looked up locally. An
    empty document list returns [] without contacting the API (mirrors
    rerank_local).

    Raises:
        HTTPException(500): COHERE_API_KEY is not configured.
        HTTPException(502): the API could not be reached (connection error, timeout).
        HTTPException(<upstream status>): the API answered with a non-200 status.
    """
    if not documents:
        return []

    import httpx

    if not config.COHERE_API_KEY:
        raise HTTPException(status_code=500, detail="COHERE_API_KEY not configured")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                "https://api.cohere.ai/v2/rerank",
                headers={
                    "Authorization": f"Bearer {config.COHERE_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "rerank-multilingual-v3.0",
                    "query": query,
                    "documents": documents,
                    "top_n": top_k,
                    "return_documents": False,
                },
                timeout=30.0
            )
    except httpx.RequestError as e:
        # Transport-level failure (DNS, connect, timeout): a gateway error, not ours.
        raise HTTPException(status_code=502, detail=f"Cohere API unreachable: {e}") from e

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Cohere API error: {response.text}"
        )

    data = response.json()
    return [
        RerankResult(
            index=item["index"],
            score=item["relevance_score"],
            text=documents[item["index"]]
        )
        for item in data.get("results", [])
    ]
|
|
|
|
|
|
# =============================================================================
|
|
# Chunking Functions
|
|
# =============================================================================
|
|
|
|
# Abbreviations whose trailing dot does not end a sentence. Entries are
# lower-case and omit the final dot; internal dots are kept ("z.b" = "z.B.").
# _protect_abbreviations() matches them case-insensitively.

# German, including legal citation abbreviations (Abs., Rn., Ziff., ...).
GERMAN_ABBREVIATIONS = {
    'abs', 'abschn', 'anh', 'anl', 'art', 'aufl', 'bd', 'bes', 'bzgl', 'bzw',
    'ca', 'chr', 'd.h', 'dgl', 'dr', 'einschl', 'entspr', 'erg', 'erl', 'etc',
    'evtl', 'gem', 'ggf', 'grds', 'hrsg', 'inkl', 'insb', 'ivm', 'kap', 'lit',
    'max', 'min', 'mio', 'mrd', 'nachf', 'nr', 'prof', 'rdnr', 'rn', 'rz',
    's', 'sog', 'u.a', 'u.ä', 'ua', 'usw', 'uvm', 'v.a', 'vgl', 'vorst',
    'vs', 'z.b', 'z.t', 'ziff', 'zzgl',
}

# English.
ENGLISH_ABBREVIATIONS = {
    'al', 'approx', 'avg', 'cf', 'ch', 'cl', 'col', 'corp', 'cpl', 'def',
    'dept', 'dist', 'div', 'dr', 'e.g', 'ed', 'est', 'etc', 'fig', 'gen',
    'govt', 'hon', 'i.e', 'illus', 'inc', 'intl', 'jr', 'ltd', 'max', 'min',
    'mr', 'mrs', 'ms', 'natl', 'no', 'org', 'para', 'pp', 'prof', 'pt',
    'ref', 'repr', 'resp', 'rev', 'sec', 'sgt', 'sr', 'st', 'supp', 'tech',
    'temp', 'treas', 'univ', 'vol', 'vs',
}

# Both languages together; documents are protected against either list.
ALL_ABBREVIATIONS = {*GERMAN_ABBREVIATIONS, *ENGLISH_ABBREVIATIONS}
|
|
|
|
# Regex pattern for legal section headers (§, Art., Article, Section, etc.)
|
|
import re
|
|
|
|
# Legal section headers, matched at the start of a (stripped) line and
# case-insensitively. Alternatives that accept Roman numerals require a word
# boundary after the numeral, so prose such as "Part in this ..." is not taken
# for "Part I"; Arabic numbers may carry a letter suffix ("§ 5a", "Anhang 1a").
_LEGAL_SECTION_RE = re.compile(
    r'^(?:'
    r'§\s*\d+'                              # § 25, § 5a
    r'|Art(?:ikel|icle|\.)\s*\d+'           # Artikel 5, Article 12, Art. 3
    r'|Section\s+\d+'                       # Section 4.2
    r'|Abschnitt\s+(?:\d+|[IVXLC]+\b)'      # Abschnitt 2, Abschnitt III
    r'|Kapitel\s+(?:\d+|[IVXLC]+\b)'        # Kapitel 2, KAPITEL II
    r'|Chapter\s+(?:\d+|[IVXLC]+\b)'        # Chapter 3, CHAPTER II
    r'|Anhang\s+(?:\d+|[IVXLC]+\b)'         # Anhang III
    r'|Annex\s+(?:\d+|[IVXLC]+\b)'          # Annex XII
    r'|TEIL\s+(?:\d+|[IVXLC]+\b)'           # TEIL II
    r'|Part\s+(?:\d+|[IVXLC]+\b)'           # Part III
    r'|Recital\s+\d+'                       # Recital 42
    r'|Erw(?:ä|ae)gungsgrund\s+\d+'         # Erwägungsgrund 26, Erwaegungsgrund 26
    r')',
    re.IGNORECASE | re.MULTILINE
)

# Generic heading lines: Markdown headings, or ALL-CAPS lines of at least
# 6 characters (letters incl. ÄÖÜ, whitespace and hyphens only). Case-sensitive.
_HEADING_RE = re.compile(
    r'^(?:'
    r'#{1,6}\s+.+'                  # Markdown headings
    r'|[A-ZÄÖÜ][A-ZÄÖÜ\s\-]{5,}$'   # ALL-CAPS lines (>5 chars)
    r')',
    re.MULTILINE
)
|
|
|
|
|
|
def _detect_language(text: str) -> str:
    """Guess whether *text* is German ('de') or English ('en').

    Looks at the first 5000 characters (lower-cased) and counts how many
    distinct marker words of each language occur surrounded by single spaces.
    German wins only with strictly more hits; ties (including no hits) give 'en'.
    """
    head = text[:5000].lower()
    german = ('der', 'die', 'das', 'und', 'ist', 'für', 'von',
              'werden', 'nach', 'gemäß', 'sowie', 'durch')
    english = ('the', 'and', 'for', 'that', 'with', 'shall',
               'must', 'should', 'which', 'from', 'this')

    def hits(words):
        return len([w for w in words if ' ' + w + ' ' in head])

    if hits(german) > hits(english):
        return 'de'
    return 'en'
|
|
|
|
|
|
def _protect_abbreviations(text: str) -> str:
    """Mask dots that must not be treated as sentence ends.

    Applied in this order:
    - every known abbreviation from ALL_ABBREVIATIONS (matched
      case-insensitively, starting at a word boundary, followed by a dot) keeps
      its original casing; its trailing dot becomes "<ABBR>" and any inner dots
      ("z.B.") become "<DOT>";
    - the dot of a decimal number ("3.14") becomes "<DECIMAL>";
    - a dot directly after digits and followed by whitespace ("1. Absatz")
      becomes "<ORD>"; note that the following whitespace character is
      replaced by a single space.

    _restore_abbreviations() turns all placeholders back into dots.
    """
    protected = text
    if ALL_ABBREVIATIONS:  # an empty alternation would match every word-initial dot
        # One alternation pass instead of one regex pass per abbreviation. For
        # the current list no two abbreviation matches can overlap (each ends at
        # its own dot), so this equals applying them one by one; sorting
        # longest-first (then alphabetically) makes the pattern deterministic.
        alternatives = sorted(ALL_ABBREVIATIONS, key=lambda a: (-len(a), a))
        pattern = r'\b(' + '|'.join(re.escape(a) for a in alternatives) + r')\.'
        protected = re.sub(
            pattern,
            lambda m: m.group(1).replace('.', '<DOT>') + '<ABBR>',
            protected,
            flags=re.IGNORECASE,
        )
    # Protect decimals (3.14) and ordinals (1. Absatz)
    protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
    protected = re.sub(r'(\d+)\.\s', r'\1<ORD> ', protected)
    return protected
|
|
|
|
|
|
def _restore_abbreviations(text: str) -> str:
    """Undo _protect_abbreviations(): turn every placeholder back into '.'."""
    for placeholder in ('<DOT>', '<ABBR>', '<DECIMAL>', '<ORD>'):
        text = text.replace(placeholder, '.')
    return text
|
|
|
|
|
|
def _split_sentences(text: str) -> List[str]:
    """Split *text* into stripped, non-empty sentences (DE + EN aware).

    A boundary is '.', '!' or '?' followed either by whitespace and an
    upper-case letter (incl. ÄÖÜ and À-Ý), or by optional whitespace and a
    newline. Dots of known abbreviations, decimals and ordinals are masked
    beforehand so they never count as boundaries.
    """
    boundary = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])\s*\n'
    pieces = re.split(boundary, _protect_abbreviations(text))
    restored = (_restore_abbreviations(piece).strip() for piece in pieces)
    return [sentence for sentence in restored if sentence]
|
|
|
|
|
|
def _extract_section_header(line: str) -> Optional[str]:
    """Return the stripped *line* if it looks like a section header, else None.

    A header is either a legal section marker (_LEGAL_SECTION_RE: §, Art.,
    Section, Annex, ...) or a generic heading (_HEADING_RE: Markdown or ALL-CAPS).
    """
    candidate = line.strip()
    if _LEGAL_SECTION_RE.match(candidate) or _HEADING_RE.match(candidate):
        return candidate
    return None
|
|
|
|
|
|
def _split_into_legal_sections(text: str) -> List[tuple]:
    """Group the lines of *text* into (header, section_text) pairs.

    A new section starts at every line _extract_section_header() recognises;
    the header line itself stays in the section text. Lines before the first
    header form a section whose header is None.
    """
    sections = []
    header = None
    lines: List[str] = []
    for line in text.split('\n'):
        found = _extract_section_header(line)
        if found:
            if lines:
                sections.append((header, '\n'.join(lines)))
            header, lines = found, [line]
        else:
            lines.append(line)
    if lines:
        sections.append((header, '\n'.join(lines)))
    return sections


def _legal_chunk_prefix(header: Optional[str], chunk_size: int) -> str:
    """Return the "[header] " context prefix for the chunks of one section.

    The header text is capped at 120 characters and additionally so that the
    whole prefix takes at most chunk_size // 2 characters, which always leaves
    room for content. Returns "" when there is no header or no room for one.
    """
    if not header:
        return ""
    budget = min(120, chunk_size // 2 - 3)  # 3 = len("[") + len("] ")
    if budget <= 0:
        return ""
    return f"[{header[:budget]}] "


def _split_at_word_boundaries(text: str, limit: int) -> List[str]:
    """Split *text* into stripped pieces of at most *limit* characters.

    Cuts at the last whitespace inside each window so words stay intact; a
    single word longer than *limit* is hard-cut.
    """
    limit = max(1, limit)  # guarantees progress on every iteration
    pieces: List[str] = []
    remaining = text.strip()
    while remaining:
        if len(remaining) <= limit:
            pieces.append(remaining)
            break
        window = remaining[:limit + 1]
        cut = max(window.rfind(ws) for ws in (' ', '\n', '\t', '\r'))
        if cut <= 0:
            cut = limit  # no usable whitespace: hard-cut the over-long word
        pieces.append(remaining[:cut].rstrip())
        remaining = remaining[cut:].lstrip()
    return pieces


def _pack_legal_section(section_text: str, room: int) -> List[str]:
    """Greedily pack one section's text into bodies of at most *room* chars.

    Paragraphs (blank-line separated) are kept whole and joined with a blank
    line while they fit. A paragraph that does not fit even into an empty body
    is split into sentences (joined with single spaces); a sentence that is
    itself longer than *room* is split at word boundaries.
    """
    bodies: List[str] = []
    body = ""
    for para in re.split(r'\n\s*\n', section_text):
        para = para.strip()
        if not para:
            continue

        separator = "\n\n" if body else ""
        if len(body) + len(separator) + len(para) <= room:
            body += separator + para
            continue

        # Paragraph doesn't fit: close the current body first.
        if body:
            bodies.append(body)
        if len(para) <= room:
            body = para
            continue

        # Paragraph too long on its own: fall back to sentences.
        body = ""
        for sentence in _split_sentences(para):
            if len(sentence) > room:
                if body:
                    bodies.append(body)
                    body = ""
                bodies.extend(_split_at_word_boundaries(sentence, room))
                continue
            separator = " " if body else ""
            if len(body) + len(separator) + len(sentence) <= room:
                body += separator + sentence
            else:
                if body:
                    bodies.append(body)
                body = sentence

    if body:
        bodies.append(body)
    return bodies


def _overlap_tail(previous: str, overlap: int) -> str:
    """Return up to *overlap* trailing characters of *previous* as context.

    If the cut lands inside a word, that partial word is dropped (unless the
    tail contains no whitespace at all). The result is stripped.
    """
    if overlap <= 0 or not previous:
        return ""
    if overlap >= len(previous):
        return previous.strip()
    tail = previous[-overlap:]
    if not previous[-overlap - 1].isspace() and not tail[0].isspace():
        parts = tail.split(None, 1)
        if len(parts) == 2:
            tail = parts[1]
    return tail.strip()


def chunk_text_legal(text: str, chunk_size: int, overlap: int) -> List[str]:
    """
    Legal-document-aware chunking.

    Strategy:
    1. Split on legal section boundaries (§, Art., Section, Chapter, etc.)
    2. Within each section, split on paragraph boundaries (double newline)
    3. Within each paragraph, split on sentence boundaries; sentences longer
       than a chunk are split at word boundaries
    4. Prepend the section header as "[header] " context prefix to every chunk
       (capped at 120 chars and at half of chunk_size)
    5. Prepend up to `overlap` characters taken from the end of the previous
       chunk's content (never from its prefix)

    Before overlap is added, every chunk is at most chunk_size characters long.
    Text that already fits in chunk_size is returned stripped as one chunk.

    Works for both German (DSGVO, BGB, AI Act DE) and English (NIST, SLSA, CRA EN) texts.

    Raises:
        ValueError: chunk_size < 1 for text that does not already fit.
    """
    if not text or len(text) <= chunk_size:
        return [text.strip()] if text and text.strip() else []
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # --- Phases 1 + 2: sections -> (prefix, body) pieces ---
    pieces = []
    for header, section_text in _split_into_legal_sections(text):
        prefix = _legal_chunk_prefix(header, chunk_size)
        room = chunk_size - len(prefix)  # >= 1: the prefix takes at most half
        pieces.extend((prefix, body) for body in _pack_legal_section(section_text, room))

    if not pieces:
        return [text.strip()] if text.strip() else []

    # --- Phase 3: add overlap from the previous chunk's content ---
    final_chunks = []
    previous_body = None
    for prefix, body in pieces:
        chunk = prefix + body
        if previous_body is not None and overlap > 0:
            tail = _overlap_tail(previous_body, overlap)
            if tail:
                chunk = tail + ' ' + chunk
        final_chunks.append(chunk.strip())
        previous_body = body

    return [c for c in final_chunks if c]
|
|
|
|
|
|
def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Legacy recursive character-based chunking.

    The /chunk endpoint no longer calls this function: every non-"semantic"
    strategy (including "recursive" and "legal_recursive") is routed to
    chunk_text_legal(), which is the one to use for legal documents.

    Splits on the coarsest separator that works (blank line, newline, ". ",
    space, and finally single characters), greedily re-joining parts while
    they fit in chunk_size. The separator between two parts that end up in
    different chunks is dropped.

    Args:
        text: Input text. Returned unchanged (not stripped) as the only chunk
            when it already fits.
        chunk_size: Maximum length of a chunk before overlap is added.
        overlap: Number of trailing characters of the previous raw chunk that
            are prepended (without a separator) to each following chunk, so
            final chunks may be up to chunk_size + overlap long.

    Returns:
        The stripped, non-empty chunks.
    """
    if not text or len(text) <= chunk_size:
        return [text] if text else []

    # Tried in order; "" means "split into single characters".
    separators = ["\n\n", "\n", ". ", " ", ""]

    def split_recursive(text: str, sep_idx: int = 0) -> List[str]:
        if len(text) <= chunk_size:
            return [text]

        # Fixed-window fallback. Only reachable when even single characters
        # exceed chunk_size (chunk_size < 1); range() raises ValueError if
        # chunk_size == overlap.
        if sep_idx >= len(separators):
            return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size - overlap)]

        sep = separators[sep_idx]
        if not sep:
            parts = list(text)
        else:
            parts = text.split(sep)

        result = []
        current = ""

        for part in parts:
            # Re-join with the separator that was removed by split().
            test_chunk = current + sep + part if current else part

            if len(test_chunk) <= chunk_size:
                current = test_chunk
            else:
                if current:
                    result.append(current)
                if len(part) > chunk_size:
                    # Part alone is too big: recurse with the next, finer separator.
                    result.extend(split_recursive(part, sep_idx + 1))
                    current = ""
                else:
                    current = part

        if current:
            result.append(current)

        return result

    raw_chunks = split_recursive(text)

    # Add overlap: raw trailing characters of the previous raw chunk (may start mid-word).
    final_chunks = []
    for i, chunk in enumerate(raw_chunks):
        if i > 0 and overlap > 0:
            prev_chunk = raw_chunks[i-1]
            overlap_text = prev_chunk[-min(overlap, len(prev_chunk)):]
            chunk = overlap_text + chunk
        final_chunks.append(chunk.strip())

    return [c for c in final_chunks if c]
|
|
|
|
|
|
def chunk_text_semantic(text: str, chunk_size: int, overlap_sentences: int = 1) -> List[str]:
    """Sentence-aware chunking with an overlap counted in sentences.

    Whitespace is collapsed, the text is split into sentences (dots of known
    abbreviations, decimals and ordinals are protected), and sentences are
    packed greedily into chunks of at most ``chunk_size`` characters. Each new
    chunk starts with up to ``overlap_sentences`` trailing sentences of the
    previous chunk; carried sentences are dropped from the front when they
    would leave no room for the next sentence, so overlap never pushes a chunk
    over ``chunk_size`` and never repeats a previous chunk verbatim.

    A single sentence longer than ``chunk_size`` is emitted as a chunk of its
    own (it is not split) exactly once.

    Returns [] for empty or whitespace-only input; text that already fits is
    returned stripped as a single chunk.
    """
    if not text or not text.strip():
        return []

    if len(text) <= chunk_size:
        return [text.strip()]

    text = re.sub(r'\s+', ' ', text).strip()

    protected = _protect_abbreviations(text)

    # Split on sentence endings
    sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])$'
    raw_sentences = re.split(sentence_pattern, protected)
    sentences = [s for s in (_restore_abbreviations(r).strip() for r in raw_sentences) if s]

    def carry_over(parts: List[str], incoming_len: int) -> List[str]:
        """Trailing sentences to repeat at the start of the next chunk."""
        if overlap_sentences <= 0:
            return []
        kept = parts[-overlap_sentences:]
        # Drop carried sentences until the incoming sentence still fits.
        while kept and sum(len(s) + 1 for s in kept) + incoming_len + 1 > chunk_size:
            kept = kept[1:]
        return kept

    # Build chunks. current_length counts each sentence plus one separator.
    chunks = []
    current_parts: List[str] = []
    current_length = 0

    for sentence in sentences:
        sentence_len = len(sentence)
        if current_parts and current_length + sentence_len + 1 > chunk_size:
            chunks.append(' '.join(current_parts))
            current_parts = carry_over(current_parts, sentence_len)
            current_length = sum(len(s) + 1 for s in current_parts)
        current_parts.append(sentence)
        current_length += sentence_len + 1

    if current_parts:
        chunks.append(' '.join(current_parts))

    return [re.sub(r'\s+', ' ', c).strip() for c in chunks if c.strip()]
|
|
|
|
|
|
# =============================================================================
|
|
# PDF Extraction Functions
|
|
# =============================================================================
|
|
|
|
def detect_pdf_backends() -> List[str]:
    """List the PDF extraction backends importable in this process.

    The result preserves preference order: "unstructured" (if importable)
    comes before "pypdf".
    """
    backends = []

    try:
        from unstructured.partition.pdf import partition_pdf  # noqa: F401
    except ImportError:
        pass
    else:
        backends.append("unstructured")

    try:
        from pypdf import PdfReader  # noqa: F401
    except ImportError:
        pass
    else:
        backends.append("pypdf")

    return backends
|
|
|
|
|
|
def extract_pdf_unstructured(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract text from a PDF with Unstructured's partition_pdf.

    The bytes are written to a temporary .pdf file (partition_pdf is called
    with filename=...); the file is removed afterwards in every case.
    Element handling:
    - Header / Footer elements are dropped;
    - tables are wrapped in "[TABELLE]" ... "[/TABELLE]" markers and counted;
    - titles become "## " Markdown headings (picked up later as chunk headers);
    - list items get a "• " bullet; everything else is kept as-is.
    Parts are joined with newlines. The page count is the highest
    metadata.page_number seen (at least 1).
    """
    import tempfile
    from unstructured.partition.pdf import partition_pdf
    from unstructured.documents.elements import Title, ListItem, Table, Header, Footer

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(pdf_content)

        elements = partition_pdf(
            filename=tmp_path,
            strategy="auto",
            include_page_breaks=True,
            infer_table_structure=True,
            languages=["deu", "eng"],
        )

        text_parts = []
        tables = []
        page_count = 1

        for element in elements:
            if hasattr(element, "metadata") and hasattr(element.metadata, "page_number"):
                page_count = max(page_count, element.metadata.page_number or 1)

            if isinstance(element, (Header, Footer)):
                continue

            element_text = str(element)

            if isinstance(element, Table):
                tables.append(element_text)
                text_parts.append(f"\n[TABELLE]\n{element_text}\n[/TABELLE]\n")
            elif isinstance(element, Title):
                text_parts.append(f"\n## {element_text}\n")
            elif isinstance(element, ListItem):
                text_parts.append(f"• {element_text}")
            else:
                text_parts.append(element_text)

        return ExtractPDFResponse(
            text="\n".join(text_parts),
            backend_used="unstructured",
            pages=page_count,
            table_count=len(tables)
        )
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as e:
                # Best effort: a leftover temp file must not turn a successful
                # extraction into an error, but it should not go unnoticed either.
                logger.warning(f"Could not remove temporary PDF {tmp_path}: {e}")
|
|
|
|
|
|
def extract_pdf_pypdf(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract plain text from a PDF with pypdf.

    Page texts are joined with a blank line; pages that yield no text are
    skipped. pypdf does no table detection, so table_count is always 0.
    """
    import io
    from pypdf import PdfReader

    reader = PdfReader(io.BytesIO(pdf_content))
    pages = reader.pages
    page_texts = [page.extract_text() for page in pages]

    return ExtractPDFResponse(
        text="\n\n".join(t for t in page_texts if t),
        backend_used="pypdf",
        pages=len(pages),
        table_count=0,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Application Lifecycle
|
|
# =============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: warm up local models before serving requests.

    Only backends configured as "local" are preloaded. A failed preload is
    logged as a warning and does not abort startup; the model is then loaded
    lazily on first use.
    """
    logger.info("Starting Embedding Service...")

    warmups = (
        ("embedding", config.EMBEDDING_BACKEND, get_embedding_model, "Embedding model preloaded"),
        ("reranker", config.RERANKER_BACKEND, get_reranker_model, "Reranker model preloaded"),
    )
    for kind, backend, loader, ok_message in warmups:
        if backend != "local":
            continue
        try:
            loader()
            logger.info(ok_message)
        except Exception as e:
            logger.warning(f"Failed to preload {kind} model: {e}")

    logger.info("Embedding Service ready")
    yield
    logger.info("Shutting down Embedding Service")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
app = FastAPI(
    title="Embedding Service",
    description="ML service for embeddings, re-ranking, and PDF extraction",
    version="1.0.0",
    lifespan=lifespan,
)

# Permissive CORS: any origin, method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True admits
# credentialed cross-origin requests from any site (browsers reject a literal
# "*" with credentials, so the middleware may reflect the request Origin
# instead) — confirm this is intended.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
|
# =============================================================================
|
|
# Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check: always reports "healthy" plus the configured models and PDF backends."""
    local_embeddings = config.EMBEDDING_BACKEND == "local"
    local_reranker = config.RERANKER_BACKEND == "local"
    return HealthResponse(
        status="healthy",
        embedding_model=config.LOCAL_EMBEDDING_MODEL if local_embeddings else config.OPENAI_EMBEDDING_MODEL,
        embedding_dimensions=config.get_current_dimensions(),
        reranker_model=config.LOCAL_RERANKER_MODEL if local_reranker else "cohere-rerank",
        pdf_backends=detect_pdf_backends(),
    )
|
|
|
|
|
|
@app.get("/models", response_model=ModelsResponse)
async def get_models():
    """Describe the configured embedding, reranking and PDF backends."""
    if config.EMBEDDING_BACKEND == "local":
        embedding_model = config.LOCAL_EMBEDDING_MODEL
    else:
        embedding_model = config.OPENAI_EMBEDDING_MODEL

    if config.RERANKER_BACKEND == "local":
        reranker_model = config.LOCAL_RERANKER_MODEL
    else:
        reranker_model = "cohere-rerank-multilingual-v3.0"

    return ModelsResponse(
        embedding_backend=config.EMBEDDING_BACKEND,
        embedding_model=embedding_model,
        embedding_dimensions=config.get_current_dimensions(),
        reranker_backend=config.RERANKER_BACKEND,
        reranker_model=reranker_model,
        pdf_backend=config.PDF_EXTRACTION_BACKEND,
        available_pdf_backends=detect_pdf_backends(),
    )
|
|
|
|
|
|
@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(request: EmbedRequest):
    """Generate embeddings for multiple texts.

    An empty text list returns an empty result without touching a backend.

    Raises:
        HTTPException: passed through unchanged from the backend (e.g. missing
            API key, upstream error status); any other failure becomes a 500.
    """
    model_name = (
        config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local"
        else config.OPENAI_EMBEDDING_MODEL
    )
    if not request.texts:
        return EmbedResponse(
            embeddings=[],
            model=model_name,
            dimensions=config.get_current_dimensions()
        )

    try:
        if config.EMBEDDING_BACKEND == "local":
            # NOTE(review): local inference runs synchronously inside this async
            # handler and blocks the event loop while it runs; consider
            # offloading once the model's thread-safety is confirmed.
            embeddings = generate_local_embeddings(request.texts)
        else:
            embeddings = await generate_openai_embeddings(request.texts)

        return EmbedResponse(
            embeddings=embeddings,
            model=model_name,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Already carries the right status and detail; re-wrapping it would turn
        # e.g. an upstream 429 into a 500 with a mangled "429: ..." detail.
        raise
    except Exception as e:
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/embed-single", response_model=EmbedSingleResponse)
async def embed_single_text(request: EmbedSingleRequest):
    """Generate the embedding for a single text ([] if the backend returned none).

    Raises:
        HTTPException: passed through unchanged from the backend (e.g. missing
            API key, upstream error status); any other failure becomes a 500.
    """
    try:
        if config.EMBEDDING_BACKEND == "local":
            # NOTE(review): synchronous local inference blocks the event loop.
            embeddings = generate_local_embeddings([request.text])
        else:
            embeddings = await generate_openai_embeddings([request.text])

        return EmbedSingleResponse(
            embedding=embeddings[0] if embeddings else [],
            model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Keep the backend's status code and detail instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.exception(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/rerank", response_model=RerankResponse)
async def rerank_documents(request: RerankRequest):
    """Re-rank documents by relevance to the query.

    The reported model name is the same for empty and non-empty requests (the
    empty path used to report "cohere-rerank" instead of the actual Cohere model).

    Raises:
        HTTPException: passed through unchanged from the backend (e.g. missing
            API key, upstream error status); any other failure becomes a 500.
    """
    model_name = (
        config.LOCAL_RERANKER_MODEL if config.RERANKER_BACKEND == "local"
        else "cohere-rerank-multilingual-v3.0"
    )
    if not request.documents:
        return RerankResponse(results=[], model=model_name)

    try:
        if config.RERANKER_BACKEND == "local":
            # NOTE(review): synchronous local inference blocks the event loop.
            results = rerank_local(request.query, request.documents, request.top_k)
        else:
            results = await rerank_cohere(request.query, request.documents, request.top_k)

        return RerankResponse(results=results, model=model_name)
    except HTTPException:
        # Keep the backend's status code and detail instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.exception(f"Rerank error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@app.post("/chunk", response_model=ChunkResponse)
async def chunk_text(request: ChunkRequest):
    """Chunk text into smaller pieces.

    Strategies:
    - "semantic" (the ChunkRequest default): sentence-aware chunking with an
      overlap of max(1, overlap // 100) sentences.
    - any other value, e.g. "recursive" or "legal_recursive": legal-document-aware
      chunking with §/Art./Section boundary detection, section context headers,
      paragraph-level splitting, and sentence-level splitting respecting
      DE + EN abbreviations; overlap is in characters.

    The old plain recursive chunker has been retired and is no longer available.
    The response echoes the requested strategy string unchanged.

    Raises:
        HTTPException(400): the chunker rejected the parameters (ValueError).
        HTTPException(500): any other chunking failure.
    """
    if not request.text:
        return ChunkResponse(chunks=[], count=0, strategy=request.strategy)

    try:
        if request.strategy == "semantic":
            overlap_sentences = max(1, request.overlap // 100)
            chunks = chunk_text_semantic(request.text, request.chunk_size, overlap_sentences)
        else:
            # All strategies (recursive, legal_recursive, etc.) use the legal-aware chunker.
            # The old plain recursive chunker is no longer exposed via the API.
            chunks = chunk_text_legal(request.text, request.chunk_size, request.overlap)
    except ValueError as e:
        # Invalid parameters (e.g. a non-positive chunk_size) are a client error.
        logger.warning(f"Invalid chunking parameters: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Chunking error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return ChunkResponse(
        chunks=chunks,
        count=len(chunks),
        strategy=request.strategy
    )
|
|
|
|
|
|
@app.post("/extract-pdf", response_model=ExtractPDFResponse)
async def extract_pdf(file: UploadFile = File(...)):
    """Extract text from an uploaded PDF file.

    Backend choice follows config.PDF_EXTRACTION_BACKEND: "auto" prefers
    Unstructured; "unstructured" falls back to pypdf when Unstructured is not
    installed.

    Raises:
        HTTPException(500): no usable backend is installed, or extraction failed.
    """
    pdf_content = await file.read()
    available = detect_pdf_backends()

    if not available:
        raise HTTPException(
            status_code=500,
            detail="No PDF backend available. Install: pip install pypdf unstructured"
        )

    backend = config.PDF_EXTRACTION_BACKEND
    if backend == "auto":
        backend = "unstructured" if "unstructured" in available else "pypdf"

    # Resolve the extractor outside the try block so the "not available" error
    # reaches the client as-is instead of being re-wrapped as "500: ...".
    if backend == "unstructured" and "unstructured" in available:
        extractor = extract_pdf_unstructured
    elif "pypdf" in available:
        extractor = extract_pdf_pypdf
    else:
        raise HTTPException(status_code=500, detail=f"Backend {backend} not available")

    try:
        return extractor(pdf_content)
    except Exception as e:
        logger.exception(f"PDF extraction error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# Main
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Serve on all network interfaces at the configured port.
    uvicorn.run(app, port=config.SERVICE_PORT, host="0.0.0.0")
|