""" Embedding Service - FastAPI Application Provides REST endpoints for: - Embedding generation (local sentence-transformers or OpenAI) - Re-ranking (local CrossEncoder or Cohere) - PDF text extraction (Unstructured or pypdf) - Text chunking (semantic or recursive) This service handles all ML-heavy operations, keeping the main klausur-service lightweight. """ import os import logging from typing import List, Optional from contextlib import asynccontextmanager from fastapi import FastAPI, HTTPException, UploadFile, File from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field import config # Configure logging logging.basicConfig( level=getattr(logging, config.LOG_LEVEL.upper(), logging.INFO), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger("embedding-service") # ============================================================================= # Lazy-loaded models # ============================================================================= _embedding_model = None _reranker_model = None def get_embedding_model(): """Lazy-load the sentence-transformers embedding model.""" global _embedding_model if _embedding_model is None: from sentence_transformers import SentenceTransformer logger.info(f"Loading embedding model: {config.LOCAL_EMBEDDING_MODEL}") _embedding_model = SentenceTransformer(config.LOCAL_EMBEDDING_MODEL) logger.info(f"Model loaded (dim={_embedding_model.get_sentence_embedding_dimension()})") return _embedding_model def get_reranker_model(): """Lazy-load the CrossEncoder reranker model.""" global _reranker_model if _reranker_model is None: from sentence_transformers import CrossEncoder logger.info(f"Loading reranker model: {config.LOCAL_RERANKER_MODEL}") _reranker_model = CrossEncoder(config.LOCAL_RERANKER_MODEL) logger.info("Reranker loaded") return _reranker_model # ============================================================================= # Request/Response Models # 
# =============================================================================

class EmbedRequest(BaseModel):
    texts: List[str] = Field(..., description="List of texts to embed")


class EmbedResponse(BaseModel):
    embeddings: List[List[float]]
    model: str
    dimensions: int


class EmbedSingleRequest(BaseModel):
    text: str = Field(..., description="Single text to embed")


class EmbedSingleResponse(BaseModel):
    embedding: List[float]
    model: str
    dimensions: int


class RerankRequest(BaseModel):
    query: str = Field(..., description="Search query")
    documents: List[str] = Field(..., description="Documents to re-rank")
    top_k: int = Field(default=5, description="Number of top results to return")


class RerankResult(BaseModel):
    index: int
    score: float
    text: str


class RerankResponse(BaseModel):
    results: List[RerankResult]
    model: str


class ChunkRequest(BaseModel):
    text: str = Field(..., description="Text to chunk")
    chunk_size: int = Field(default=1000, description="Target chunk size")
    overlap: int = Field(default=200, description="Overlap between chunks")
    strategy: str = Field(default="semantic", description="Chunking strategy: semantic or recursive")


class ChunkResponse(BaseModel):
    chunks: List[str]
    count: int
    strategy: str


class ExtractPDFResponse(BaseModel):
    text: str
    backend_used: str
    pages: int
    table_count: int


class HealthResponse(BaseModel):
    status: str
    embedding_model: str
    embedding_dimensions: int
    reranker_model: str
    pdf_backends: List[str]


class ModelsResponse(BaseModel):
    embedding_backend: str
    embedding_model: str
    embedding_dimensions: int
    reranker_backend: str
    reranker_model: str
    pdf_backend: str
    available_pdf_backends: List[str]


# =============================================================================
# Embedding Functions
# =============================================================================

def generate_local_embeddings(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for *texts* using the local sentence-transformers model.

    Returns one embedding vector (list of floats) per input text; empty input
    yields an empty list without touching the model.
    """
    if not texts:
        return []
    model = get_embedding_model()
    # Progress bar only for non-trivial batches to keep logs quiet.
    embeddings = model.encode(texts, show_progress_bar=len(texts) > 10)
    return [emb.tolist() for emb in embeddings]


async def generate_openai_embeddings(texts: List[str]) -> List[List[float]]:
    """Generate embeddings for *texts* via the OpenAI embeddings API.

    Raises:
        HTTPException: 500 if OPENAI_API_KEY is not configured, or the
            upstream status code if the API call fails.
    """
    import httpx
    if not config.OPENAI_API_KEY:
        raise HTTPException(status_code=500, detail="OPENAI_API_KEY not configured")
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.openai.com/v1/embeddings",
            headers={
                "Authorization": f"Bearer {config.OPENAI_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": config.OPENAI_EMBEDDING_MODEL,
                "input": texts
            },
            timeout=60.0
        )
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"OpenAI API error: {response.text}"
            )
        data = response.json()
        return [item["embedding"] for item in data["data"]]


# =============================================================================
# Re-ranking Functions
# =============================================================================

def rerank_local(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Re-rank *documents* against *query* using the local CrossEncoder.

    Returns at most *top_k* results sorted by descending relevance score;
    each result keeps the document's original index.
    """
    if not documents:
        return []
    model = get_reranker_model()
    pairs = [(query, doc) for doc in documents]
    scores = model.predict(pairs)
    results = [
        RerankResult(index=i, score=float(score), text=doc)
        for i, (score, doc) in enumerate(zip(scores, documents))
    ]
    results.sort(key=lambda x: x.score, reverse=True)
    return results[:top_k]


async def rerank_cohere(query: str, documents: List[str], top_k: int = 5) -> List[RerankResult]:
    """Re-rank *documents* against *query* via the Cohere v2 rerank API.

    Raises:
        HTTPException: 500 if COHERE_API_KEY is not configured, or the
            upstream status code if the API call fails.
    """
    import httpx
    if not config.COHERE_API_KEY:
        raise HTTPException(status_code=500, detail="COHERE_API_KEY not configured")
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "https://api.cohere.ai/v2/rerank",
            headers={
                "Authorization": f"Bearer {config.COHERE_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": "rerank-multilingual-v3.0",
                "query": query,
                "documents": documents,
                "top_n": top_k,
                "return_documents": False,
            },
            timeout=30.0
        )
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Cohere API error: {response.text}"
            )
        data = response.json()
        # Cohere returns only indices + scores; recover the text locally.
        return [
            RerankResult(
                index=item["index"],
                score=item["relevance_score"],
                text=documents[item["index"]]
            )
            for item in data.get("results", [])
        ]


# =============================================================================
# Chunking Functions
# =============================================================================

# German abbreviations that don't end sentences
GERMAN_ABBREVIATIONS = {
    'bzw', 'ca', 'chr', 'd.h', 'dr', 'etc', 'evtl', 'ggf', 'inkl',
    'max', 'min', 'mio', 'mrd', 'nr', 'prof', 's', 'sog', 'u.a',
    'u.ä', 'usw', 'v.a', 'vgl', 'vs', 'z.b', 'z.t', 'zzgl'
}

# Sentinel characters used to temporarily hide dots that do NOT end a
# sentence (abbreviations, decimals, ordinals) while splitting on sentence
# boundaries. Control characters cannot occur in normalized input text.
_ABBREV_DOT = "\x00"
_DECIMAL_DOT = "\x01"
_ORDINAL_DOT = "\x02"


def chunk_text_recursive(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Recursive character-based chunking.

    Splits on progressively finer separators (paragraph, line, sentence,
    word, character) until every chunk fits *chunk_size*, then prefixes each
    chunk (except the first) with the last *overlap* characters of its
    predecessor.
    """
    if not text or len(text) <= chunk_size:
        return [text] if text else []

    separators = ["\n\n", "\n", ". ", " ", ""]
    # Guard: if overlap >= chunk_size the character-level fallback step would
    # be zero or negative (range() error / infinite loop).
    step = max(1, chunk_size - overlap)

    def split_recursive(text: str, sep_idx: int = 0) -> List[str]:
        if len(text) <= chunk_size:
            return [text]
        if sep_idx >= len(separators):
            # No separator worked; hard-split by characters.
            return [text[i:i + chunk_size] for i in range(0, len(text), step)]
        sep = separators[sep_idx]
        if not sep:
            parts = list(text)
        else:
            parts = text.split(sep)
        result = []
        current = ""
        for part in parts:
            test_chunk = current + sep + part if current else part
            if len(test_chunk) <= chunk_size:
                current = test_chunk
            else:
                if current:
                    result.append(current)
                if len(part) > chunk_size:
                    # Single part too large: recurse with a finer separator.
                    result.extend(split_recursive(part, sep_idx + 1))
                    current = ""
                else:
                    current = part
        if current:
            result.append(current)
        return result

    raw_chunks = split_recursive(text)

    # Add overlap from the tail of the previous chunk.
    final_chunks = []
    for i, chunk in enumerate(raw_chunks):
        if i > 0 and overlap > 0:
            prev_chunk = raw_chunks[i - 1]
            overlap_text = prev_chunk[-min(overlap, len(prev_chunk)):]
            chunk = overlap_text + chunk
        final_chunks.append(chunk.strip())
    return [c for c in final_chunks if c]


def chunk_text_semantic(text: str, chunk_size: int, overlap_sentences: int = 1) -> List[str]:
    """Semantic sentence-aware chunking.

    Splits *text* into sentences (German-aware: abbreviations, decimals and
    ordinal numbers do not end a sentence), then packs sentences into chunks
    of at most *chunk_size* characters, carrying *overlap_sentences* trailing
    sentences into the next chunk for context.
    """
    import re
    if not text:
        return []
    if len(text) <= chunk_size:
        return [text.strip()]

    # Normalize whitespace (also guarantees no control characters survive
    # that could collide with the sentinels below).
    text = re.sub(r'\s+', ' ', text).strip()

    # Protect abbreviation dots so they are not treated as sentence endings.
    # A function replacement preserves the original casing of the match
    # ("Vgl." stays "Vgl"), unlike substituting the lowercase abbreviation.
    protected = text
    for abbrev in GERMAN_ABBREVIATIONS:
        pattern = re.compile(r'\b' + re.escape(abbrev) + r'\.', re.IGNORECASE)
        protected = pattern.sub(
            lambda m: m.group(0)[:-1].replace('.', _ABBREV_DOT) + _ABBREV_DOT,
            protected
        )
    # Protect decimals (3.14) and ordinals (1. Januar).
    protected = re.sub(r'(\d)\.(\d)', r'\1' + _DECIMAL_DOT + r'\2', protected)
    protected = re.sub(r'(\d+)\.(\s)', r'\1' + _ORDINAL_DOT + r'\2', protected)

    # Split on sentence endings followed by an uppercase letter (or end).
    sentence_pattern = r'(?<=[.!?])\s+(?=[A-ZÄÖÜ])|(?<=[.!?])$'
    raw_sentences = re.split(sentence_pattern, protected)

    # Restore the protected dots.
    sentences = []
    for s in raw_sentences:
        s = (s.replace(_ABBREV_DOT, '.')
              .replace(_DECIMAL_DOT, '.')
              .replace(_ORDINAL_DOT, '.'))
        s = s.strip()
        if s:
            sentences.append(s)

    # Pack sentences into chunks.
    chunks = []
    current_parts = []
    current_length = 0
    overlap_buffer = []
    for sentence in sentences:
        sentence_len = len(sentence)
        if sentence_len > chunk_size:
            # Oversized sentence becomes its own chunk (plus any overlap).
            if current_parts:
                chunks.append(' '.join(current_parts))
                overlap_buffer = current_parts[-overlap_sentences:] if overlap_sentences > 0 else []
                current_parts = list(overlap_buffer)
                current_length = sum(len(s) + 1 for s in current_parts)
            if overlap_buffer:
                chunks.append(' '.join(overlap_buffer) + ' ' + sentence)
            else:
                chunks.append(sentence)
            overlap_buffer = [sentence]
            current_parts = list(overlap_buffer)
            current_length = len(sentence) + 1
            continue
        if current_length + sentence_len + 1 > chunk_size and current_parts:
            # Flush the current chunk and seed the next with the overlap.
            chunks.append(' '.join(current_parts))
            overlap_buffer = current_parts[-overlap_sentences:] if overlap_sentences > 0 else []
            current_parts = list(overlap_buffer)
            current_length = sum(len(s) + 1 for s in current_parts)
        current_parts.append(sentence)
        current_length += sentence_len + 1
    if current_parts:
        chunks.append(' '.join(current_parts))
    return [re.sub(r'\s+', ' ', c).strip() for c in chunks if c.strip()]


# =============================================================================
# PDF Extraction Functions
# =============================================================================

def detect_pdf_backends() -> List[str]:
    """Return the names of installed PDF extraction backends."""
    available = []
    try:
        from unstructured.partition.pdf import partition_pdf
        available.append("unstructured")
    except ImportError:
        pass
    try:
        from pypdf import PdfReader
        available.append("pypdf")
    except ImportError:
        pass
    return available


def extract_pdf_unstructured(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract text from *pdf_content* using Unstructured.

    Tables are wrapped in [TABELLE]...[/TABELLE] markers, titles become
    markdown headings and list items get bullet prefixes; headers/footers
    are dropped.
    """
    import tempfile
    from unstructured.partition.pdf import partition_pdf
    from unstructured.documents.elements import Title, ListItem, Table, Header, Footer

    # Unstructured needs a real file path; write the upload to a temp file.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(pdf_content)
        tmp_path = tmp.name
    try:
        elements = partition_pdf(
            filename=tmp_path,
            strategy="auto",
            include_page_breaks=True,
            infer_table_structure=True,
            languages=["deu", "eng"],
        )
        text_parts = []
        tables = []
        page_count = 1
        for element in elements:
            if hasattr(element, "metadata") and hasattr(element.metadata, "page_number"):
                page_count = max(page_count, element.metadata.page_number or 1)
            if isinstance(element, (Header, Footer)):
                continue
            element_text = str(element)
            if isinstance(element, Table):
                tables.append(element_text)
                text_parts.append(f"\n[TABELLE]\n{element_text}\n[/TABELLE]\n")
            elif isinstance(element, Title):
                text_parts.append(f"\n## {element_text}\n")
            elif isinstance(element, ListItem):
                text_parts.append(f"• {element_text}")
            else:
                text_parts.append(element_text)
        return ExtractPDFResponse(
            text="\n".join(text_parts),
            backend_used="unstructured",
            pages=page_count,
            table_count=len(tables)
        )
    finally:
        # Best-effort cleanup of the temp file; ignore races where it is
        # already gone (but do not swallow unrelated exceptions).
        try:
            os.unlink(tmp_path)
        except OSError:
            pass


def extract_pdf_pypdf(pdf_content: bytes) -> ExtractPDFResponse:
    """Extract plain text from *pdf_content* using pypdf (no table detection)."""
    import io
    from pypdf import PdfReader
    pdf_file = io.BytesIO(pdf_content)
    reader = PdfReader(pdf_file)
    text_parts = []
    for page in reader.pages:
        text = page.extract_text()
        if text:
            text_parts.append(text)
    return ExtractPDFResponse(
        text="\n\n".join(text_parts),
        backend_used="pypdf",
        pages=len(reader.pages),
        table_count=0
    )


# =============================================================================
# Application Lifecycle
# =============================================================================

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload configured local models on startup; log on shutdown."""
    logger.info("Starting Embedding Service...")
    if config.EMBEDDING_BACKEND == "local":
        try:
            get_embedding_model()
            logger.info("Embedding model preloaded")
        except Exception as e:
            # Preloading is an optimization; the service still starts and
            # will retry lazily on first request.
            logger.warning(f"Failed to preload embedding model: {e}")
    if config.RERANKER_BACKEND == "local":
        try:
            get_reranker_model()
            logger.info("Reranker model preloaded")
        except Exception as e:
            logger.warning(f"Failed to preload reranker model: {e}")
    logger.info("Embedding Service ready")
    yield
    logger.info("Shutting down Embedding Service")


# =============================================================================
# FastAPI Application
# =============================================================================

app = FastAPI(
    title="Embedding Service",
    description="ML service for embeddings, re-ranking, and PDF extraction",
    version="1.0.0",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# =============================================================================
# Endpoints
# =============================================================================

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        embedding_model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
        embedding_dimensions=config.get_current_dimensions(),
        reranker_model=config.LOCAL_RERANKER_MODEL if config.RERANKER_BACKEND == "local" else "cohere-rerank",
        pdf_backends=detect_pdf_backends()
    )


@app.get("/models", response_model=ModelsResponse)
async def get_models():
    """Get information about configured models."""
    return ModelsResponse(
        embedding_backend=config.EMBEDDING_BACKEND,
        embedding_model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
        embedding_dimensions=config.get_current_dimensions(),
        reranker_backend=config.RERANKER_BACKEND,
        reranker_model=config.LOCAL_RERANKER_MODEL if config.RERANKER_BACKEND == "local" else "cohere-rerank-multilingual-v3.0",
        pdf_backend=config.PDF_EXTRACTION_BACKEND,
        available_pdf_backends=detect_pdf_backends()
    )


@app.post("/embed", response_model=EmbedResponse)
async def embed_texts(request: EmbedRequest):
    """Generate embeddings for multiple texts."""
    if not request.texts:
        return EmbedResponse(
            embeddings=[],
            model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    try:
        if config.EMBEDDING_BACKEND == "local":
            embeddings = generate_local_embeddings(request.texts)
        else:
            embeddings = await generate_openai_embeddings(request.texts)
        return EmbedResponse(
            embeddings=embeddings,
            model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Preserve status codes raised by the backends (e.g. upstream API
        # errors) instead of collapsing them into a generic 500.
        raise
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embed-single", response_model=EmbedSingleResponse)
async def embed_single_text(request: EmbedSingleRequest):
    """Generate embedding for a single text."""
    try:
        if config.EMBEDDING_BACKEND == "local":
            embeddings = generate_local_embeddings([request.text])
        else:
            embeddings = await generate_openai_embeddings([request.text])
        return EmbedSingleResponse(
            embedding=embeddings[0] if embeddings else [],
            model=config.LOCAL_EMBEDDING_MODEL if config.EMBEDDING_BACKEND == "local" else config.OPENAI_EMBEDDING_MODEL,
            dimensions=config.get_current_dimensions()
        )
    except HTTPException:
        # Preserve backend status codes.
        raise
    except Exception as e:
        logger.error(f"Embedding error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/rerank", response_model=RerankResponse)
async def rerank_documents(request: RerankRequest):
    """Re-rank documents based on query relevance."""
    if not request.documents:
        return RerankResponse(
            results=[],
            model=config.LOCAL_RERANKER_MODEL if config.RERANKER_BACKEND == "local" else "cohere-rerank"
        )
    try:
        if config.RERANKER_BACKEND == "local":
            results = rerank_local(request.query, request.documents, request.top_k)
        else:
            results = await rerank_cohere(request.query, request.documents, request.top_k)
        return RerankResponse(
            results=results,
            model=config.LOCAL_RERANKER_MODEL if config.RERANKER_BACKEND == "local" else "cohere-rerank-multilingual-v3.0"
        )
    except HTTPException:
        # Preserve backend status codes.
        raise
    except Exception as e:
        logger.error(f"Rerank error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chunk", response_model=ChunkResponse)
async def chunk_text(request: ChunkRequest):
    """Chunk text into smaller pieces."""
    if not request.text:
        return ChunkResponse(chunks=[], count=0, strategy=request.strategy)
    try:
        if request.strategy == "semantic":
            # Map the character-overlap parameter onto sentence counts
            # (~100 characters per overlapped sentence).
            overlap_sentences = max(1, request.overlap // 100)
            chunks = chunk_text_semantic(request.text, request.chunk_size, overlap_sentences)
        else:
            chunks = chunk_text_recursive(request.text, request.chunk_size, request.overlap)
        return ChunkResponse(
            chunks=chunks,
            count=len(chunks),
            strategy=request.strategy
        )
    except Exception as e:
        logger.error(f"Chunking error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/extract-pdf", response_model=ExtractPDFResponse)
async def extract_pdf(file: UploadFile = File(...)):
    """Extract text from PDF file."""
    pdf_content = await file.read()
    available = detect_pdf_backends()
    if not available:
        raise HTTPException(
            status_code=500,
            detail="No PDF backend available. Install: pip install pypdf unstructured"
        )
    backend = config.PDF_EXTRACTION_BACKEND
    if backend == "auto":
        # Prefer the richer Unstructured backend when installed.
        backend = "unstructured" if "unstructured" in available else "pypdf"
    try:
        if backend == "unstructured" and "unstructured" in available:
            return extract_pdf_unstructured(pdf_content)
        elif "pypdf" in available:
            return extract_pdf_pypdf(pdf_content)
        else:
            raise HTTPException(status_code=500, detail=f"Backend {backend} not available")
    except HTTPException:
        # Don't re-wrap our own HTTPException into another 500.
        raise
    except Exception as e:
        logger.error(f"PDF extraction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# =============================================================================
# Main
# =============================================================================

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=config.SERVICE_PORT)