Files
breakpilot-core/embedding-service/main.py
T
Benjamin Admin 0b0eed27b0 feat(embedding): NIST PDF text normalization + safe re-ingest script
Fix broken multi-column PDF extraction for NIST/BSI/ENISA documents:
- _normalize_pdf_text(): fixes broken section numbers (1 . 1 → 1.1),
  control IDs (AC - 1 → AC-1), ligatures, soft hyphens
- pdfplumber tolerances increased (x=3,y=4) for better column handling
- 3 new regex patterns: NIST CSF 2.0, NIST enhancements, OWASP Top 10
- reingest_nist.py: safe upload-before-delete for 4 lost NIST PDFs
- reingest_d5.py: safety fix — upload first, verify, then delete old

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-05-03 06:42:46 +02:00

1139 lines
39 KiB
Python

"""
Embedding Service - FastAPI Application
Provides REST endpoints for:
- Embedding generation (local sentence-transformers or OpenAI)
- Re-ranking (local CrossEncoder or Cohere)
- PDF text extraction (Unstructured, pdfplumber, or pypdf)
- Text chunking (sentence-based "semantic" or legal-document-aware)
This service handles all ML-heavy operations, keeping the main klausur-service lightweight.
"""
import logging
import re
import unicodedata
from typing import List, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import config
# Configure logging from config.LOG_LEVEL (case-insensitive). A level name that
# is not an attribute of the logging module falls back to INFO.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=getattr(logging, config.LOG_LEVEL.upper(), logging.INFO),
)
logger = logging.getLogger("embedding-service")
# =============================================================================
# Lazy-loaded models
# =============================================================================
# Process-wide singletons, populated on first use by the getters below.
_embedding_model = None
_reranker_model = None


def get_embedding_model():
    """Return the shared SentenceTransformer, loading it on the first call.

    The model name comes from ``config.LOCAL_EMBEDDING_MODEL``. The
    ``sentence_transformers`` import is deferred until a model is needed.
    """
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model

    from sentence_transformers import SentenceTransformer

    logger.info(f"Loading embedding model: {config.LOCAL_EMBEDDING_MODEL}")
    _embedding_model = SentenceTransformer(config.LOCAL_EMBEDDING_MODEL)
    logger.info(f"Model loaded (dim={_embedding_model.get_sentence_embedding_dimension()})")
    return _embedding_model


def get_reranker_model():
    """Return the shared CrossEncoder reranker, loading it on the first call.

    The model name comes from ``config.LOCAL_RERANKER_MODEL``.
    """
    global _reranker_model
    if _reranker_model is not None:
        return _reranker_model

    from sentence_transformers import CrossEncoder

    logger.info(f"Loading reranker model: {config.LOCAL_RERANKER_MODEL}")
    _reranker_model = CrossEncoder(config.LOCAL_RERANKER_MODEL)
    logger.info("Reranker loaded")
    return _reranker_model
# =============================================================================
# Request/Response Models
# =============================================================================
# Plain `#` comments are used here on purpose: adding class docstrings to these
# pydantic models would change the generated API schema descriptions.


# Request body of POST /embed.
class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., description="List of texts to embed")


# Response of POST /embed.
class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimensions: int


# Request body of POST /embed-single.
class EmbedSingleRequest(BaseModel):
    text: str = Field(..., description="Single text to embed")


# Response of POST /embed-single.
class EmbedSingleResponse(BaseModel):
    embedding: List[float]
    model: str
    dimensions: int


# Request body of POST /rerank.
class RerankRequest(BaseModel):
    query: str = Field(..., description="Search query")
    documents: List[str] = Field(..., description="Documents to re-rank")
    # No lower bound is enforced by this model; the value is passed to the
    # rerank helpers as-is.
    top_k: int = Field(default=5, description="Number of top results to return")


# One scored document; `index` is its position in RerankRequest.documents.
class RerankResult(BaseModel):
    index: int
    score: float
    text: str


# Response of POST /rerank.
class RerankResponse(BaseModel):
    results: List[RerankResult]
    model: str


# Request body of POST /chunk. `chunk_size` and `overlap` are character counts
# for the legal chunker; for strategy "semantic" the endpoint converts
# `overlap` into a sentence count (overlap // 100, minimum 1).
class ChunkRequest(BaseModel):
    text: str = Field(..., description="Text to chunk")
    chunk_size: int = Field(default=1000, description="Target chunk size")
    overlap: int = Field(default=200, description="Overlap between chunks")
    # Only "semantic" is special-cased by POST /chunk; any other value selects
    # the legal-document-aware chunker.
    strategy: str = Field(default="semantic", description="Chunking strategy: semantic or recursive")


# Per-chunk metadata shape. It mirrors the dicts built by
# chunk_text_legal_structured, but ChunkResponse.chunks_with_metadata is typed
# as plain dicts, so this model is not used for validation in this module.
class ChunkMetadata(BaseModel):
    text: str
    section: str = ""
    section_title: str = ""
    paragraph: str = ""
    paragraph_num: Optional[int] = None
    page: Optional[int] = None
    index: int = 0


# Response of POST /chunk. `chunks_with_metadata` is only filled by the
# legal (non-"semantic") strategy.
class ChunkResponse(BaseModel):
    chunks: List[str]
    chunks_with_metadata: Optional[List[dict]] = None
    count: int
    strategy: str


# Response of POST /extract-pdf.
class ExtractPDFResponse(BaseModel):
    text: str
    backend_used: str
    pages: int
    # Only the Unstructured backend counts tables; pdfplumber/pypdf report 0.
    table_count: int


# Response of GET /health.
class HealthResponse(BaseModel):
    status: str
    embedding_model: str
    embedding_dimensions: int
    reranker_model: str
    # Names of PDF backends whose imports currently succeed.
    pdf_backends: List[str]


# Response of GET /models.
class ModelsResponse(BaseModel):
    embedding_backend: str
    embedding_model: str
    embedding_dimensions: int
    reranker_backend: str
    reranker_model: str
    pdf_backend: str
    available_pdf_backends: List[str]
# =============================================================================
# Embedding Functions
# =============================================================================


def generate_local_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed ``texts`` with the lazily loaded local sentence-transformers model.

    Each vector returned by ``encode`` is converted to a plain Python list via
    ``tolist()``. An empty input returns ``[]`` without touching (or loading)
    the model. A progress bar is shown only for batches of more than 10 texts.
    """
    if not texts:
        return []
    vectors = get_embedding_model().encode(texts, show_progress_bar=len(texts) > 10)
    return [vector.tolist() for vector in vectors]
async def generate_openai_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed ``texts`` through the OpenAI ``/v1/embeddings`` endpoint.

    Uses ``config.OPENAI_EMBEDDING_MODEL`` and authenticates with
    ``config.OPENAI_API_KEY``.

    Raises:
        HTTPException: 500 when no API key is configured; OpenAI's own status
            code when the API answers with anything other than 200.
    """
    import httpx

    api_key = config.OPENAI_API_KEY
    if not api_key:
        raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")

    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {"model": config.OPENAI_EMBEDDING_MODEL, "input": texts}

    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "https://api.openai.com/v1/embeddings",
            headers=request_headers,
            json=payload,
            timeout=60.0,
        )
        if resp.status_code != 200:
            raise HTTPException(
                status_code=resp.status_code,
                detail=f"OpenAI API error: {resp.text}",
            )
        body = resp.json()
        return [entry["embedding"] for entry in body["data"]]
# =============================================================================
# Re-ranking Functions
# =============================================================================


def rerank_local(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Re-rank ``documents`` against ``query`` with the local CrossEncoder.

    Args:
        query: Search query.
        documents: Candidate texts. Each result keeps the candidate's position
            in this list as ``index``.
        top_k: Maximum number of results to return. Values <= 0 yield an
            empty list. Previously a negative value was used as a slice end
            and returned all but the last ``|top_k|`` results.

    Returns:
        Up to ``top_k`` results, highest score first.
    """
    # Nothing to return: skip loading and running the model entirely.
    if not documents or top_k <= 0:
        return []

    model = get_reranker_model()
    scores = model.predict([(query, doc) for doc in documents])
    results = [
        RerankResult(index=i, score=float(score), text=doc)
        for i, (score, doc) in enumerate(zip(scores, documents))
    ]
    results.sort(key=lambda result: result.score, reverse=True)
    return results[:top_k]
async def rerank_cohere(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Re-rank ``documents`` through the Cohere v2 rerank API.

    Uses the hard-coded "rerank-multilingual-v3.0" model and asks Cohere not to
    echo the documents back. Result texts are looked up in ``documents`` by the
    returned index.

    Args:
        query: Search query.
        documents: Candidate texts.
        top_k: Forwarded as ``top_n``. A value <= 0, or an empty ``documents``
            list, returns ``[]`` without calling Cohere instead of sending a
            request the API would reject.

    Raises:
        HTTPException: 500 when COHERE_API_KEY is not configured; Cohere's own
            status code when the API answers with anything other than 200.
    """
    import httpx

    if not config.COHERE_API_KEY:
        raise HTTPException(status_code=500, detail="COHERE_API_KEY not configured")
    if not documents or top_k <= 0:
        return []

    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.cohere.ai/v2/rerank",
            headers={
                "Authorization": f"Bearer {config.COHERE_API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": "rerank-multilingual-v3.0",
                "query": query,
                "documents": documents,
                "top_n": top_k,
                "return_documents": False,
            },
            timeout=30.0,
        )
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Cohere API error: {response.text}",
            )
        data = response.json()

    return [
        RerankResult(
            index=item["index"],
            score=item["relevance_score"],
            text=documents[item["index"]],
        )
        for item in data.get("results", [])
    ]
# =============================================================================
# Chunking Functions
# =============================================================================
# Abbreviation stems (lower-case, without the trailing dot). _protect_abbreviations
# masks the dot after each stem (matched case-insensitively, preceded by a word
# boundary) so the sentence splitters do not treat it as a sentence end.
# NOTE(review): very short or common stems such as 's', 'no', 'art', 'min' or
# 'ed' also suppress genuine sentence ends (e.g. "... is art. The ...") -- this
# looks like an accepted trade-off; confirm.

# German abbreviations that don't end sentences
GERMAN_ABBREVIATIONS = {
    'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl', 'max',
    'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a', 'u.ä', 'usw',
    'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl', 'abs', 'art', 'abschn',
    'anh', 'anl', 'aufl', 'bd', 'bes', 'bzgl', 'dgl', 'einschl', 'entspr',
    'erg', 'erl', 'gem', 'grds', 'hrsg', 'insb', 'ivm', 'kap', 'lit',
    'nachf', 'rdnr', 'rn', 'rz', 'ua', 'uvm', 'vorst', 'ziff'
}
# English abbreviations that don't end sentences
ENGLISH_ABBREVIATIONS = {
    'e.g', 'i.e', 'etc', 'vs', 'al', 'approx', 'avg', 'dept', 'dr', 'ed',
    'est', 'fig', 'govt', 'inc', 'jr', 'ltd', 'max', 'min', 'mr', 'mrs',
    'ms', 'no', 'prof', 'pt', 'ref', 'rev', 'sec', 'sgt', 'sr', 'st',
    'vol', 'cf', 'ch', 'cl', 'col', 'corp', 'cpl', 'def', 'dist', 'div',
    'gen', 'hon', 'illus', 'intl', 'natl', 'org', 'para', 'pp', 'repr',
    'resp', 'supp', 'tech', 'temp', 'treas', 'univ'
}
# Combined abbreviations for both languages (stems present in both sets, such
# as 'etc', 'vs' or 'dr', appear only once in the union).
ALL_ABBREVIATIONS = GERMAN_ABBREVIATIONS | ENGLISH_ABBREVIATIONS
# Regex pattern for legal/standard section headers. It is applied with .match()
# to a single stripped line, so every alternative is anchored at line start.
#
# Word-based headers (§, Artikel/Article, Section, Annex, Table, ...) are
# case-insensitive via scoped (?i:...) groups. The NIST/ENISA/OWASP numbering
# alternatives are deliberately case-SENSITIVE: their upper-case character
# classes are what separate a heading ("1.1 Purpose", "AC-1", "A01:2021") from
# ordinary text such as "1.5 million ..." or "ab-12 ...". A global
# re.IGNORECASE flag previously let those classes match lower-case letters
# too, which split documents at many false "headers".
_LEGAL_SECTION_RE = re.compile(
    r'^(?:'
    r'(?i:'
    r'§\s*\d+'                          # § 25, § 5a
    r'|Art(?:ikel|icle|\.)\s*\d+'       # Artikel 5, Article 12, Art. 3
    r'|Section\s+\d+'                   # Section 4.2
    r'|Abschnitt\s+\d+'                 # Abschnitt 3
    r'|Kapitel\s+\d+'                   # Kapitel 2
    r'|Chapter\s+\d+'                   # Chapter 3
    r'|Anhang\s+[IVXLC\d]+'             # Anhang III
    r'|Annex\s+[IVXLC\d]+'              # Annex XII
    r'|TEIL\s+[IVXLC\d]+'               # TEIL II
    r'|Part\s+[IVXLC\d]+'               # Part III
    r'|Recital\s+\d+'                   # Recital 42
    r'|Erwaegungsgrund\s+\d+'           # Erwaegungsgrund 26
    r')'
    # NIST/ENISA/standard numbering (case-sensitive on purpose)
    r'|\d+\.\d+(?:\.\d+)*\s+[A-ZÄÖÜ]'   # 1.1 Title, 2.3.1 Subtitle
    r'|[A-Z]{2,4}[-\.]\d+(?:\.\d+)*\b'  # AC-1, AU-2, PO.1, PW.1.1
    r'|[A-Z]{2}\.[A-Z]{2}-\d{2}\b'      # GV.OC-01 (NIST CSF 2.0)
    r'|[A-Z]{2,4}-\d+\(\d+\)'           # AC-1(1) (NIST enhancements)
    r'|A\d{2}(?::\d{4})?\b'             # A01:2021 (OWASP Top 10)
    r'|(?i:'
    r'Table\s+\d+'                      # Table 1
    r'|Figure\s+\d+'                    # Figure 1
    r'|Appendix\s+[A-Z\d]'              # Appendix A, Appendix 1
    r')'
    r')',
    re.MULTILINE
)
# Regex for any heading-like line: either a Markdown heading (1-6 "#", then
# whitespace and text) or an ALL-CAPS line of at least 6 characters consisting
# only of upper-case letters (incl. ÄÖÜ), whitespace and hyphens up to the end
# of the line. No IGNORECASE flag is set, so the ALL-CAPS branch really needs
# upper-case letters. A digit or other punctuation disqualifies the line from
# that branch.
_HEADING_RE = re.compile(
    r'^(?:'
    r'#{1,6}\s+.+'                   # Markdown headings
    r'|[A-ZÄÖÜ][A-ZÄÖÜ\s\-]{5,}$'    # ALL-CAPS lines (6+ chars)
    r')',
    re.MULTILINE
)
def _detect_language(text: str) -> str:
    """Guess whether ``text`` is German ('de') or English ('en').

    Only the first 5000 characters are inspected (lower-cased). For each
    language, the function counts how many distinct marker words occur with a
    single space on both sides. German wins only with strictly more hits, so
    ties (including no hits at all) resolve to 'en'.
    """
    window = text[:5000].lower()
    german_words = ('der', 'die', 'das', 'und', 'ist', 'für', 'von',
                    'werden', 'nach', 'gemäß', 'sowie', 'durch')
    english_words = ('the', 'and', 'for', 'that', 'with', 'shall',
                     'must', 'should', 'which', 'from', 'this')

    def hits(words) -> int:
        return sum(f' {word} ' in window for word in words)

    return 'en' if hits(german_words) <= hits(english_words) else 'de'
def _protect_abbreviations(text: str) -> str:
    """Mask dots that must not be read as sentence ends.

    Replacements (undone by _restore_abbreviations):

    * For every stem in ALL_ABBREVIATIONS that is matched case-insensitively,
      preceded by a word boundary and followed by ".", the inner dots become
      ``<DOT>`` and the trailing dot becomes ``<ABBR>``. The matched text keeps
      its original letter case.
    * A dot between two digits ("3.14", every dot of "1.2.3") becomes
      ``<DECIMAL>``.
    * A dot after a digit that is followed by whitespace ("1. Absatz")
      becomes ``<ORD>``. The whitespace character itself is preserved.
    """
    protected = text
    for abbrev in ALL_ABBREVIATIONS:
        pattern = re.compile(r'\b(' + re.escape(abbrev) + r')\.', re.IGNORECASE)
        # Use lambda to preserve original case of the matched abbreviation
        protected = pattern.sub(lambda m: m.group(1).replace('.', '<DOT>') + '<ABBR>', protected)
    # Lookarounds instead of consuming groups, so chained numbers such as
    # "1.2.3" get every dot masked (the consuming form skipped every second one).
    protected = re.sub(r'(?<=\d)\.(?=\d)', '<DECIMAL>', protected)
    # Do not consume the following whitespace. The old replacement always
    # re-inserted a plain space, turning "2021.\nNext" into "2021. Next".
    protected = re.sub(r'(?<=\d)\.(?=\s)', '<ORD>', protected)
    return protected
def _restore_abbreviations(text: str) -> str:
    """Turn every placeholder inserted by _protect_abbreviations back into '.'."""
    restored = text
    for marker in ('<DOT>', '<ABBR>', '<DECIMAL>', '<ORD>'):
        restored = restored.replace(marker, '.')
    return restored
def _split_sentences(text: str) -> List[str]:
    """Split ``text`` into stripped, non-empty sentences (German and English).

    Abbreviation, decimal and ordinal dots are masked before splitting and
    restored afterwards. A boundary is sentence-ending punctuation (. ! ?)
    followed either by whitespace and an upper-case letter, or by optional
    whitespace and a newline.
    """
    boundary = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])\s*\n'
    pieces = re.split(boundary, _protect_abbreviations(text))
    restored = (_restore_abbreviations(piece).strip() for piece in pieces)
    return [sentence for sentence in restored if sentence]
# Regex for paragraph/subsection references within text. Each alternative has
# exactly one capture group: the number, or the letter for "lit.". Matching is
# case-insensitive, so e.g. "ABS. 2" and "lit. A" match as well.
# NOTE(review): the bare "(n)" branch matches any parenthesised number
# (years, list markers, ...), not only paragraph numbers.
_PARAGRAPH_RE = re.compile(
    r'(?:'
    r'Abs(?:atz|\.)\s*(\d+)'  # Abs. 1, Absatz 2
    r'|Nr\.\s*(\d+)'  # Nr. 3
    r'|Satz\s+(\d+)'  # Satz 1
    r'|lit\.\s*([a-z])'  # lit. a
    r'|\((\d+)\)'  # (1), (2)
    r')',
    re.IGNORECASE
)
# Regex to extract the section number from a header. _parse_section_metadata
# uses it with .search(), so it may match anywhere in the header. Only the forms
# listed here yield a section number. Other headers recognised by
# _LEGAL_SECTION_RE (e.g. "Chapter 3", "Abschnitt 2", "1.1 Title", "AC-1 ...")
# produce no match on their own.
_SECTION_NUMBER_RE = re.compile(
    r'(?:'
    r'§\s*(\d+[a-z]*)'  # § 25, § 312k
    r'|Art(?:ikel|icle|\.)\s*(\d+)'  # Artikel 5, Art. 3
    r'|Section\s+(\d[\d.]*)'  # Section 4.2
    r'|Kapitel\s+(\d+)'  # Kapitel 2
    r'|Anhang\s+([IVXLC\d]+)'  # Anhang III
    r'|Annex\s+([IVXLC\d]+)'  # Annex XII
    r')',
    re.IGNORECASE
)
def _extract_section_header(line: str) -> Optional[str]:
    """Return the stripped ``line`` if it looks like a section header, else None.

    A line qualifies when it starts with a legal/standard section marker
    (_LEGAL_SECTION_RE) or is a Markdown / ALL-CAPS heading (_HEADING_RE).
    """
    candidate = line.strip()
    is_header = bool(_LEGAL_SECTION_RE.match(candidate)) or bool(_HEADING_RE.match(candidate))
    return candidate if is_header else None
def _parse_section_metadata(header: str) -> dict:
    """Split a section header into its number and its title.

    Example: "§ 312k Kuendigungsbutton" ->
    {"section": "§ 312k", "section_title": "Kuendigungsbutton"}.

    ``section`` is the first _SECTION_NUMBER_RE match in ``header``, or "" when
    there is none. When a number is found, ``section_title`` is the text after
    the number's first occurrence, with leading " .-–—:" characters removed.
    Otherwise it is the whole header. Both values are whitespace-stripped.
    """
    if not header:
        return {"section": "", "section_title": ""}

    match = _SECTION_NUMBER_RE.search(header)
    section = ""
    if match and any(match.groups()):
        section = match.group(0).strip()

    title = header
    if section:
        start = header.find(section)
        if start >= 0:
            # Drop separator punctuation between number and title.
            title = header[start + len(section):].strip().lstrip(' .-–—:')
    return {"section": section, "section_title": title.strip()}
def _extract_paragraph_ref(text: str) -> dict:
    """Find the first paragraph/subsection reference in the first 200 chars.

    Example: "... Abs. 1 ..." -> {"paragraph": "Abs. 1", "paragraph_num": 1}.

    Numeric references yield their integer value. "lit. x" references map the
    letter to its alphabet position (a = 1, b = 2, ...). Without a match the
    result is {"paragraph": "", "paragraph_num": None}.
    """
    empty = {"paragraph": "", "paragraph_num": None}
    match = _PARAGRAPH_RE.search(text[:200])
    if match is None:
        return empty
    captured = next((group for group in match.groups() if group), None)
    if captured is None:
        return empty
    try:
        number = int(captured)
    except ValueError:
        # "lit." references capture a letter instead of a number.
        number = ord(captured.lower()) - ord('a') + 1
    return {"paragraph": match.group(0).strip(), "paragraph_num": number}
def _split_legal_sections(text: str) -> List[tuple]:
    """Split ``text`` into ``(header, section_text)`` pairs at header lines.

    A line starts a new section when _extract_section_header recognises it.
    The header line stays the first line of its section's text. Lines before
    the first header form a section whose header is ``None``.
    """
    sections = []
    current_header = None
    current_lines: List[str] = []
    for line in text.split('\n'):
        header = _extract_section_header(line)
        if not header:
            current_lines.append(line)
            continue
        if current_lines:
            sections.append((current_header, '\n'.join(current_lines)))
        current_header = header
        current_lines = [line]
    if current_lines:
        sections.append((current_header, '\n'.join(current_lines)))
    return sections


def _section_prefix(section_header: Optional[str], chunk_size: int) -> str:
    """Return the "[header] " context prefix for a section ("" without header).

    The header is cut to 120 characters, and further so that the prefix never
    takes more than half of ``chunk_size``. For chunk_size >= 246 this is just
    the 120-character cap. Without this bound, a prefix that filled the whole
    chunk left no room for content.
    """
    if not section_header:
        return ""
    limit = min(120, chunk_size // 2 - 3)  # 3 = "[", "]" and the trailing space
    if limit <= 0:
        return ""
    return f"[{section_header[:limit]}] "


def _pack_section(prefix: str, section_text: str, chunk_size: int) -> List[tuple]:
    """Pack one section into ``(prefix, body)`` pieces of roughly chunk_size chars.

    Paragraphs (blank-line separated) are packed greedily. A paragraph that does
    not fit into a fresh chunk is split into sentences, and a sentence that
    alone exceeds the budget is hard-split. Each piece's text is
    ``prefix + body``.
    """
    pieces: List[tuple] = []
    bare_prefix = prefix.strip()

    def flush(candidate: str) -> None:
        # Emit only chunks that carry content beyond the bare prefix.
        stripped = candidate.strip()
        if stripped and stripped != bare_prefix:
            pieces.append((prefix, candidate[len(prefix):]))

    current = prefix
    for para in re.split(r'\n\s*\n', section_text):
        para = para.strip()
        if not para:
            continue

        # Paragraph fits into the current chunk: append it.
        if len(current) + len(para) + 1 <= chunk_size:
            if current and not current.endswith(' '):
                current += '\n\n'
            current += para
            continue

        flush(current)

        # Paragraph fits into a fresh chunk on its own.
        if len(prefix) + len(para) <= chunk_size:
            current = prefix + para
            continue

        # Paragraph too long: pack it sentence by sentence.
        current = prefix
        for sentence in _split_sentences(para):
            if len(prefix) + len(sentence) > chunk_size:
                flush(current)
                # Advance by at least one character so the hard split always
                # terminates. The old loop never ended once the prefix filled
                # the whole chunk.
                step = max(1, chunk_size - len(prefix))
                for start in range(0, len(sentence), step):
                    flush(prefix + sentence[start:start + step])
                current = prefix
                continue
            if len(current) + len(sentence) + 1 > chunk_size:
                flush(current)
                current = prefix + sentence
            else:
                if current and not current.endswith(' '):
                    current += ' '
                current += sentence

    flush(current)
    return pieces


def _overlap_tail(body: str, overlap: int) -> str:
    """Return up to ``overlap`` trailing characters of ``body`` for reuse.

    When the body is longer than ``overlap``, the cut usually lands inside a
    word, so everything up to the first space is dropped. A body that fits
    completely is returned whole, since it already starts on a word boundary.
    """
    source = body.strip()
    if len(source) <= overlap:
        return source
    tail = source[-overlap:]
    space_idx = tail.find(' ')
    if space_idx > 0:
        tail = tail[space_idx + 1:]
    return tail


def _apply_overlap(pieces: List[tuple], overlap: int) -> List[str]:
    """Render ``(prefix, body)`` pieces, prepending overlap from the previous piece.

    The overlap comes from the previous piece's body only, never its
    "[header]" prefix. It is inserted right after the current piece's own
    prefix, so the prefix stays at the start for structured metadata
    extraction.
    """
    chunks = []
    for i, (prefix, body) in enumerate(pieces):
        text = (prefix + body).strip()
        if i > 0 and overlap > 0:
            tail = _overlap_tail(pieces[i - 1][1], overlap)
            if tail:
                text = f"{prefix}{tail} {body.strip()}".strip()
        if text:
            chunks.append(text)
    return chunks


def chunk_text_legal(text: str, chunk_size: int, overlap: int) -> List[str]:
    """
    Legal-document-aware chunking.

    Strategy:
    1. Split on legal section boundaries (§, Art., Section, Chapter, etc.)
    2. Within each section, split on paragraph boundaries (double newline)
    3. Within each paragraph, split on sentence boundaries
    4. Prepend section header as context prefix ("[header] ") to every chunk
    5. Add overlap from previous chunk (its content only, never its prefix)

    Works for both German (DSGVO, BGB, AI Act DE) and English (NIST, SLSA, CRA EN) texts.

    Args:
        text: Document text.
        chunk_size: Target chunk size in characters, before overlap is added.
        overlap: Number of trailing characters of the previous chunk's
            content to repeat at the start of the next chunk.

    Returns:
        Non-empty, stripped chunks. Text no longer than ``chunk_size`` is
        returned as a single stripped chunk.
    """
    if not text or len(text) <= chunk_size:
        return [text.strip()] if text and text.strip() else []

    pieces: List[tuple] = []
    for header, section_text in _split_legal_sections(text):
        prefix = _section_prefix(header, chunk_size)
        pieces.extend(_pack_section(prefix, section_text, chunk_size))

    if not pieces:
        return [text.strip()] if text.strip() else []
    return _apply_overlap(pieces, overlap)
def chunk_text_legal_structured(text: str, chunk_size: int, overlap: int) -> List[dict]:
    """Legal-aware chunking that also returns structured metadata per chunk.

    Runs chunk_text_legal and, for every chunk, returns a dict with the keys
    text, section, section_title, paragraph, paragraph_num, page (always None)
    and index (position in the chunk list).

    Section data is parsed back from the leading "[...]" prefix via a
    non-greedy bracket match. Paragraph data comes from _extract_paragraph_ref.
    NOTE(review): that search covers the first 200 characters of the full chunk
    text, i.e. including the prefix and any overlap copied from the previous
    chunk, so the reference may belong to the previous chunk -- confirm.
    """
    prefix_re = re.compile(r'^\[(.+?)\]\s*')
    results = []
    for position, chunk in enumerate(chunk_text_legal(text, chunk_size, overlap)):
        section_info = {"section": "", "section_title": ""}
        found = prefix_re.match(chunk)
        if found:
            section_info = _parse_section_metadata(found.group(1))
        para_info = _extract_paragraph_ref(chunk)
        results.append({
            "text": chunk,
            "section": section_info["section"],
            "section_title": section_info["section_title"],
            "paragraph": para_info["paragraph"],
            "paragraph_num": para_info["paragraph_num"],
            "page": None,
            "index": position,
        })
    return results
def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Recursive character-based chunking (legacy; use chunk_text_legal for legal docs).

    Splits on the coarsest separator that works (blank lines, newlines, ". ",
    spaces, then single characters) and packs the parts greedily up to
    ``chunk_size``. Each chunk after the first is prefixed with the last
    ``overlap`` characters of the previous chunk. The two are joined with a
    space unless whitespace is already present at the seam. Text no longer
    than ``chunk_size`` is returned unchanged as a single chunk.

    Not called by the /chunk endpoint in this module.
    """
    if not text or len(text) <= chunk_size:
        return [text] if text else []

    separators = ["\n\n", "\n", ". ", " ", ""]

    def split_recursive(segment: str, sep_idx: int = 0) -> List[str]:
        if len(segment) <= chunk_size:
            return [segment]
        if sep_idx >= len(separators):
            # Fixed-width fallback. The step is clamped to >= 1 so an overlap
            # >= chunk_size cannot produce a zero (ValueError) or negative
            # (silently empty) range step.
            step = max(1, chunk_size - overlap)
            return [segment[i:i + chunk_size] for i in range(0, len(segment), step)]

        sep = separators[sep_idx]
        parts = list(segment) if not sep else segment.split(sep)
        result = []
        current = ""
        for part in parts:
            test_chunk = current + sep + part if current else part
            if len(test_chunk) <= chunk_size:
                current = test_chunk
                continue
            if current:
                result.append(current)
            if len(part) > chunk_size:
                result.extend(split_recursive(part, sep_idx + 1))
                current = ""
            else:
                current = part
        if current:
            result.append(current)
        return result

    raw_chunks = split_recursive(text)

    final_chunks = []
    for i, chunk in enumerate(raw_chunks):
        if i > 0 and overlap > 0:
            prev_chunk = raw_chunks[i - 1]
            overlap_text = prev_chunk[-min(overlap, len(prev_chunk)):]
            # The separator between the two chunks was consumed by split();
            # re-insert a space so the overlap is not glued onto the next word.
            seam_has_space = overlap_text[-1:].isspace() or chunk[:1].isspace()
            chunk = overlap_text + ("" if seam_has_space else " ") + chunk
        final_chunks.append(chunk.strip())
    return [c for c in final_chunks if c]
def chunk_text_semantic(text: str, chunk_size: int, overlap_sentences: int = 1) -> List[str]:
    """Semantic sentence-aware chunking.

    Whitespace is collapsed, the text is split into sentences (abbreviation,
    decimal and ordinal dots are masked), and sentences are packed greedily up
    to ``chunk_size`` characters. Each new chunk starts with the last
    ``overlap_sentences`` sentences of the previous one (none when 0).

    A single sentence longer than ``chunk_size`` becomes its own chunk,
    preceded by the overlap from the chunk before it. It is not carried
    forward. Previously it stayed in the working buffer and was re-emitted in
    the following one or two chunks.

    Empty or whitespace-only input yields ``[]``. Text no longer than
    ``chunk_size`` is returned as one stripped chunk.
    """
    if not text or not text.strip():
        return []
    if len(text) <= chunk_size:
        return [text.strip()]

    text = re.sub(r'\s+', ' ', text).strip()
    protected = _protect_abbreviations(text)

    # Split on sentence endings
    sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜÀ-Ý])|(?<=[.!?])$'
    restored = (_restore_abbreviations(s).strip() for s in re.split(sentence_pattern, protected))
    sentences = [s for s in restored if s]

    def overlap_of(parts: List[str]) -> List[str]:
        return parts[-overlap_sentences:] if overlap_sentences > 0 else []

    chunks = []
    current_parts: List[str] = []
    current_length = 0
    for sentence in sentences:
        sentence_len = len(sentence)

        if sentence_len > chunk_size:
            lead: List[str] = []
            if current_parts:
                chunks.append(' '.join(current_parts))
                lead = overlap_of(current_parts)
            chunks.append(' '.join(lead + [sentence]))
            # Start fresh: keeping the oversized sentence as the next chunk's
            # seed would force an immediate flush that duplicates it.
            current_parts = []
            current_length = 0
            continue

        if current_length + sentence_len + 1 > chunk_size and current_parts:
            chunks.append(' '.join(current_parts))
            current_parts = list(overlap_of(current_parts))
            current_length = sum(len(s) + 1 for s in current_parts)

        current_parts.append(sentence)
        current_length += sentence_len + 1

    if current_parts:
        chunks.append(' '.join(current_parts))

    return [re.sub(r'\s+', ' ', c).strip() for c in chunks if c.strip()]
# =============================================================================
# PDF Extraction Functions
# =============================================================================


def detect_pdf_backends() -> List[str]:
    """Return the PDF backends whose imports succeed, in preference order.

    The order is always unstructured, pdfplumber, pypdf (restricted to the
    importable ones). Only ImportError counts as "not available".
    """
    def probe_unstructured():
        from unstructured.partition.pdf import partition_pdf  # noqa: F401

    def probe_pdfplumber():
        import pdfplumber  # noqa: F401

    def probe_pypdf():
        from pypdf import PdfReader  # noqa: F401

    candidates = (
        ("unstructured", probe_unstructured),
        ("pdfplumber", probe_pdfplumber),
        ("pypdf", probe_pypdf),
    )
    found = []
    for name, probe in candidates:
        try:
            probe()
        except ImportError:
            continue
        found.append(name)
    return found
def extract_pdf_unstructured(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract PDF text with Unstructured's ``partition_pdf``.

    The bytes are written to a temporary ``.pdf`` file, because partition_pdf
    is given a filename. That file is always removed afterwards, even when
    writing or parsing fails.

    Output:
    * Header and footer elements are dropped.
    * Tables are wrapped in ``[TABELLE]`` markers and counted.
    * Titles become Markdown ``## `` headings.
    * Every other element is kept as plain text, and elements are joined with
      newlines.
    * ``pages`` is the highest page number seen, at least 1.
    """
    import os
    import tempfile
    from unstructured.partition.pdf import partition_pdf
    from unstructured.documents.elements import Title, Table, Header, Footer

    # Create the file first and write it inside the try block. The old code
    # wrote it before entering try/finally, so a failed write leaked the file.
    fd, tmp_path = tempfile.mkstemp(suffix=".pdf")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(pdf_content)

        elements = partition_pdf(
            filename=tmp_path,
            strategy="auto",
            include_page_breaks=True,
            infer_table_structure=True,
            languages=["deu", "eng"],
        )

        text_parts = []
        table_count = 0
        page_count = 1
        for element in elements:
            if hasattr(element, "metadata") and hasattr(element.metadata, "page_number"):
                page_count = max(page_count, element.metadata.page_number or 1)
            if isinstance(element, (Header, Footer)):
                continue
            element_text = str(element)
            if isinstance(element, Table):
                table_count += 1
                text_parts.append(f"\n[TABELLE]\n{element_text}\n[/TABELLE]\n")
            elif isinstance(element, Title):
                text_parts.append(f"\n## {element_text}\n")
            else:
                # List items and all other element types are kept verbatim.
                text_parts.append(element_text)

        return ExtractPDFResponse(
            text="\n".join(text_parts),
            backend_used="unstructured",
            pages=page_count,
            table_count=table_count,
        )
    finally:
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
def _normalize_pdf_text(text: str) -> str:
    """Fix broken spacing from multi-column PDF extraction.

    pdfplumber/pypdf often break section numbers in multi-column NIST/BSI/ENISA
    PDFs: "1 . 1" instead of "1.1", "AC - 1" instead of "AC-1".

    Repairs, in order:
      1. Unicode NFKC (e.g. the U+FB01 "fi" ligature becomes the letters "fi").
      2. Remove soft hyphens (U+00AD) and zero-width spaces (U+200B).
      3. "1 . 1" -> "1.1", repeated until stable so "2 . 3 . 1" -> "2.3.1".
      4. "AC - 1" -> "AC-1" (2-4 upper-case letters).
      5. "GV . OC - 01" -> "GV.OC-01" (NIST CSF 2.0 compound IDs).
      6. "( 1 )" -> "(1)", then "AC-1 (1)" -> "AC-1(1)" (NIST enhancements).
      7. Collapse runs of spaces/tabs into one space. Newlines are kept.

    Steps 3-6 only bridge spaces and tabs, never a line break. This way a line
    ending in "... 2 ." is not fused with a following line that starts with
    "3 Scope".
    """
    text = unicodedata.normalize('NFKC', text)
    text = text.replace('\u00ad', '').replace('\u200b', '')

    # [^\S\n] = whitespace except newline. Loop because re.sub does not revisit
    # the digit consumed by the previous match ("2 . 3 . 1").
    prev = None
    while prev != text:
        prev = text
        text = re.sub(r'(\d+)[^\S\n]+\.[^\S\n]+(\d+)', r'\1.\2', text)

    # "AC - 1" -> "AC-1" (broken NIST control IDs)
    text = re.sub(r'\b([A-Z]{2,4})[^\S\n]+-[^\S\n]+(\d+)\b', r'\1-\2', text)

    # "GV . OC - 01" -> "GV.OC-01" (NIST CSF 2.0 compound IDs)
    text = re.sub(
        r'\b([A-Z]{2})[^\S\n]*\.[^\S\n]*([A-Z]{2})[^\S\n]*-[^\S\n]*(\d{2})\b',
        r'\1.\2-\3',
        text,
    )

    # "( 1 )" -> "(1)", then attach the enhancement to its control. Without the
    # second step "AC - 1 ( 1 )" ended up as "AC-1 (1)", not "AC-1(1)".
    text = re.sub(r'\([^\S\n]*(\d+)[^\S\n]*\)', r'(\1)', text)
    text = re.sub(r'\b([A-Z]{2,4}-\d+)[^\S\n]+(\(\d+\))', r'\1\2', text)

    # Collapse multiple horizontal spaces (keep newlines)
    text = re.sub(r'[^\S\n]{2,}', ' ', text)
    return text
def extract_pdf_pdfplumber(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract PDF text with pdfplumber (best for multi-column EU regulation PDFs).

    Each page is extracted with x_tolerance=3 / y_tolerance=4. Pages without
    text are skipped, the remaining page texts are joined with a blank line
    and passed through _normalize_pdf_text. ``pages`` counts every page.
    """
    import io
    import pdfplumber

    with pdfplumber.open(io.BytesIO(pdf_content)) as pdf:
        total_pages = len(pdf.pages)
        extracted = (page.extract_text(x_tolerance=3, y_tolerance=4) for page in pdf.pages)
        page_texts = [page_text for page_text in extracted if page_text]

    return ExtractPDFResponse(
        text=_normalize_pdf_text("\n\n".join(page_texts)),
        backend_used="pdfplumber",
        pages=total_pages,
        table_count=0,
    )
def extract_pdf_pypdf(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract PDF text with pypdf (fallback backend).

    Pages without extractable text are skipped. The remaining page texts are
    joined with a blank line and passed through _normalize_pdf_text.
    ``pages`` counts every page.
    """
    import io
    from pypdf import PdfReader

    reader = PdfReader(io.BytesIO(pdf_content))
    extracted = (page.extract_text() for page in reader.pages)
    page_texts = [page_text for page_text in extracted if page_text]

    return ExtractPDFResponse(
        text=_normalize_pdf_text("\n\n".join(page_texts)),
        backend_used="pypdf",
        pages=len(reader.pages),
        table_count=0,
    )
# =============================================================================
# Application Lifecycle
# =============================================================================


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload the local models on startup.

    Only backends configured as "local" are preloaded. A failed preload is
    logged as a warning and does not stop startup.
    """
    logger.info("Starting Embedding Service...")

    preloads = (
        (config.EMBEDDING_BACKEND, get_embedding_model,
         "Embedding model preloaded", "Failed to preload embedding model"),
        (config.RERANKER_BACKEND, get_reranker_model,
         "Reranker model preloaded", "Failed to preload reranker model"),
    )
    for backend, loader, ok_message, failure_prefix in preloads:
        if backend != "local":
            continue
        try:
            loader()
            logger.info(ok_message)
        except Exception as exc:
            logger.warning(f"{failure_prefix}: {exc}")

    logger.info("Embedding Service ready")
    yield
    logger.info("Shutting down Embedding Service")
# =============================================================================
# FastAPI Application
# =============================================================================
app = FastAPI(
    lifespan=lifespan,
    title="Embedding Service",
    version="1.0.0",
    description="ML service for embeddings, re-ranking, and PDF extraction",
)

# NOTE(review): wildcard origins are combined with allow_credentials=True.
# Check how Starlette's CORSMiddleware answers credentialed cross-origin
# requests in that configuration. This is probably acceptable only if the
# service is internal-only and uses no cookies -- confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# =============================================================================
# Endpoints
# =============================================================================


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Always reports status "healthy". It does not verify that the models can
    actually be loaded. Also lists the PDF backends whose imports currently
    succeed.
    """
    embedding_model = (
        config.LOCAL_EMBEDDING_MODEL
        if config.EMBEDDING_BACKEND == "local"
        else config.OPENAI_EMBEDDING_MODEL
    )
    # Use the same Cohere label as /models. The old "cohere-rerank" value
    # disagreed with the name reported by the other endpoints.
    reranker_model = (
        config.LOCAL_RERANKER_MODEL
        if config.RERANKER_BACKEND == "local"
        else "cohere-rerank-multilingual-v3.0"
    )
    return HealthResponse(
        status="healthy",
        embedding_model=embedding_model,
        embedding_dimensions=config.get_current_dimensions(),
        reranker_model=reranker_model,
        pdf_backends=detect_pdf_backends(),
    )
@app.get("/models", response_model=ModelsResponse)
async def get_models():
    """Describe the configured embedding, reranker and PDF backends.

    Model names follow the active backend. A "local" backend reports the local
    model name. Otherwise the OpenAI embedding model and the label
    "cohere-rerank-multilingual-v3.0" are reported.
    """
    embedding_is_local = config.EMBEDDING_BACKEND == "local"
    reranker_is_local = config.RERANKER_BACKEND == "local"
    return ModelsResponse(
        embedding_backend=config.EMBEDDING_BACKEND,
        embedding_model=(
            config.LOCAL_EMBEDDING_MODEL if embedding_is_local else config.OPENAI_EMBEDDING_MODEL
        ),
        embedding_dimensions=config.get_current_dimensions(),
        reranker_backend=config.RERANKER_BACKEND,
        reranker_model=(
            config.LOCAL_RERANKER_MODEL if reranker_is_local else "cohere-rerank-multilingual-v3.0"
        ),
        pdf_backend=config.PDF_EXTRACTION_BACKEND,
        available_pdf_backends=detect_pdf_backends(),
    )
@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(request: EmbedRequest):
    """Generate embeddings for multiple texts.

    An empty ``texts`` list returns an empty result without calling a backend.

    Raises:
        HTTPException: errors raised by the OpenAI helper keep their status
            code (e.g. 401/429 from OpenAI, or 500 for a missing key). Any
            other failure becomes a 500 with the error text as detail.
    """
    if not request.texts:
        return EmbedResponse(
            embeddings=[],
            model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    try:
        use_local = config.EMBEDDING_BACKEND == "local"
        # NOTE(review): generate_local_embeddings is synchronous and runs on
        # the event loop thread, blocking other requests while it encodes.
        if use_local:
            embeddings = generate_local_embeddings(request.texts)
        else:
            embeddings = await generate_openai_embeddings(request.texts)
        return EmbedResponse(
            embeddings=embeddings,
            model=config.LOCAL_EMBEDDING_MODEL if use_local else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Already carries the right status code; don't flatten it to 500.
        raise
    except Exception as e:
        logger.exception("Embedding error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/embed-single", response_model=EmbedSingleResponse)
async def embed_single_text(request: EmbedSingleRequest):
    """Generate the embedding for a single text.

    Raises:
        HTTPException: errors raised by the OpenAI helper keep their status
            code. Any other failure becomes a 500 with the error text as
            detail.
    """
    try:
        use_local = config.EMBEDDING_BACKEND == "local"
        if use_local:
            embeddings = generate_local_embeddings([request.text])
        else:
            embeddings = await generate_openai_embeddings([request.text])
        return EmbedSingleResponse(
            embedding=embeddings[0] if embeddings else [],
            model=config.LOCAL_EMBEDDING_MODEL if use_local else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Already carries the right status code; don't flatten it to 500.
        raise
    except Exception as e:
        logger.exception("Embedding error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/rerank", response_model=RerankResponse)
async def rerank_documents(request: RerankRequest):
    """Re-rank documents based on query relevance.

    An empty ``documents`` list returns no results without calling a backend.

    Raises:
        HTTPException: errors raised by the Cohere helper keep their status
            code. Any other failure becomes a 500 with the error text as
            detail.
    """
    use_local = config.RERANKER_BACKEND == "local"
    # One Cohere label for both paths, matching /models. The empty-input
    # branch used to report "cohere-rerank" instead.
    model_name = config.LOCAL_RERANKER_MODEL if use_local else "cohere-rerank-multilingual-v3.0"

    if not request.documents:
        return RerankResponse(results=[], model=model_name)

    try:
        if use_local:
            results = rerank_local(request.query, request.documents, request.top_k)
        else:
            results = await rerank_cohere(request.query, request.documents, request.top_k)
        return RerankResponse(results=results, model=model_name)
    except HTTPException:
        # Already carries the right status code; don't flatten it to 500.
        raise
    except Exception as e:
        logger.exception("Rerank error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/chunk", response_model=ChunkResponse)
async def chunk_text(request: ChunkRequest):
    """Chunk text into smaller pieces.

    Strategies (``request.strategy``; ChunkRequest defaults to "semantic"):

    - "semantic": sentence-aware chunking. ``overlap`` is converted to a
      sentence count as ``overlap // 100`` with a minimum of 1, so even
      overlap=0 carries one sentence over. No metadata is returned.
    - any other value (e.g. "recursive"): legal-document-aware chunking with
      §/Art./Section boundary detection, section context headers,
      paragraph-level splitting and sentence-level splitting that respects
      DE + EN abbreviations. ``chunks_with_metadata`` is filled as well.

    The old plain recursive chunker is not reachable from this endpoint.
    """
    if not request.text:
        return ChunkResponse(chunks=[], count=0, strategy=request.strategy)
    try:
        if request.strategy == "semantic":
            overlap_sentences = max(1, request.overlap // 100)
            chunks = chunk_text_semantic(request.text, request.chunk_size, overlap_sentences)
            return ChunkResponse(
                chunks=chunks,
                count=len(chunks),
                strategy=request.strategy,
            )

        # chunk_text_legal_structured already runs chunk_text_legal and keeps
        # each chunk's text. Derive the plain list from it instead of chunking
        # the whole document a second time.
        structured = chunk_text_legal_structured(request.text, request.chunk_size, request.overlap)
        chunks = [item["text"] for item in structured]
        return ChunkResponse(
            chunks=chunks,
            chunks_with_metadata=structured,
            count=len(chunks),
            strategy=request.strategy,
        )
    except Exception as e:
        logger.exception("Chunking error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/extract-pdf", response_model=ExtractPDFResponse)
async def extract_pdf(file: UploadFile = File(...)):
    """Extract text from an uploaded PDF file.

    The backend comes from ``config.PDF_EXTRACTION_BACKEND``. "auto" picks the
    first available of unstructured > pdfplumber > pypdf. If the configured
    backend is not installed, pypdf is used when it is available.

    Raises:
        HTTPException: 500 when no backend is installed, when no usable
            backend remains for the configuration, or when extraction fails
            (the detail is then the error text).
    """
    pdf_content = await file.read()
    available = detect_pdf_backends()
    if not available:
        raise HTTPException(
            status_code=500,
            detail="No PDF backend available. Install: pip install pypdf unstructured"
        )

    backend = config.PDF_EXTRACTION_BACKEND
    if backend == "auto":
        # Prefer: unstructured > pdfplumber > pypdf
        if "unstructured" in available:
            backend = "unstructured"
        elif "pdfplumber" in available:
            backend = "pdfplumber"
        else:
            backend = "pypdf"

    # NOTE(review): the extractors are synchronous and run on the event loop
    # thread, blocking other requests for the duration of the extraction.
    try:
        if backend == "unstructured" and "unstructured" in available:
            return extract_pdf_unstructured(pdf_content)
        if backend == "pdfplumber" and "pdfplumber" in available:
            return extract_pdf_pdfplumber(pdf_content)
        if "pypdf" in available:
            return extract_pdf_pypdf(pdf_content)
    except Exception as e:
        logger.exception("PDF extraction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Raised outside the try block so the generic handler does not catch it,
    # log it as an extraction error and re-wrap it into a second
    # HTTPException.
    raise HTTPException(status_code=500, detail=f"Backend {backend} not available")
# =============================================================================
# Main
# =============================================================================
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # using the port from config.SERVICE_PORT.
    import uvicorn

    uvicorn.run(app, port=config.SERVICE_PORT, host="0.0.0.0")