import logging
import os
from typing import Optional

import httpx

from config import settings

logger = logging.getLogger("rag-service.embedding")

# Timeout for embedding-service calls (rerank / chunk / PDF extraction).
_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)
# Longer timeout for individual Ollama embedding requests.
_EMBED_TIMEOUT = httpx.Timeout(timeout=300.0, connect=10.0)

# Ollama config for embeddings (bge-m3, 1024-dim)
_OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
_OLLAMA_EMBED_MODEL = os.getenv("OLLAMA_EMBED_MODEL", "bge-m3")

# Number of texts processed between progress log lines in
# generate_embeddings. Clamped to >= 1: a value of 0 would make range()
# raise ValueError, and a negative value would make range() yield nothing,
# so generate_embeddings would silently return [] for non-empty input.
_EMBED_BATCH_SIZE = max(1, int(os.getenv("EMBED_BATCH_SIZE", "32")))


class EmbeddingClient:
    """
    Hybrid client:
      - Embeddings via Ollama (bge-m3, 1024-dim) for Qdrant compatibility
      - Reranking, chunking + PDF extraction via embedding-service (port 8087)
    """

    def __init__(self) -> None:
        # Base URLs are stored without a trailing slash so that paths
        # starting with "/" can be appended directly.
        self._embed_svc_url: str = settings.EMBEDDING_SERVICE_URL.rstrip("/")
        self._ollama_url: str = _OLLAMA_URL.rstrip("/")
        self._embed_model: str = _OLLAMA_EMBED_MODEL

    def _svc_url(self, path: str) -> str:
        """Join *path* (expected to start with "/") onto the embedding-service base URL."""
        return f"{self._embed_svc_url}{path}"

    # ------------------------------------------------------------------
    # Embeddings (via Ollama)
    # ------------------------------------------------------------------

    async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings via Ollama's embedding model (default bge-m3).

        Each text is sent as its own request to Ollama's ``/api/embeddings``
        endpoint. All requests share a single HTTP client. Texts are walked in
        groups of ``_EMBED_BATCH_SIZE`` only so that progress can be logged
        on large uploads.

        Args:
            texts: Texts to embed, in order.

        Returns:
            One embedding per input text, in the same order as ``texts``.
            An empty list when ``texts`` is empty.

        Raises:
            httpx.HTTPStatusError: Ollama answered with a non-2xx status.
            httpx.HTTPError: Transport-level failure (timeout, connection).
            ValueError: Ollama returned an empty embedding for a text.
        """
        if not texts:
            return []

        url = f"{self._ollama_url}/api/embeddings"
        total = len(texts)
        all_embeddings: list[list[float]] = []

        # One client for the whole call, so the connection is reused across
        # requests instead of being re-opened for every batch. httpx applies
        # _EMBED_TIMEOUT to each request, not to the client's lifetime.
        async with httpx.AsyncClient(timeout=_EMBED_TIMEOUT) as client:
            for start in range(0, total, _EMBED_BATCH_SIZE):
                for text in texts[start : start + _EMBED_BATCH_SIZE]:
                    response = await client.post(
                        url,
                        json={
                            "model": self._embed_model,
                            "prompt": text,
                        },
                    )
                    response.raise_for_status()
                    data = response.json()
                    embedding = data.get("embedding", [])
                    if not embedding:
                        raise ValueError(
                            f"Ollama returned empty embedding for model {self._embed_model}"
                        )
                    all_embeddings.append(embedding)

                # Report progress after every batch except the final one.
                if start + _EMBED_BATCH_SIZE < total:
                    logger.info(
                        "Embedding progress: %d/%d", len(all_embeddings), total
                    )

        return all_embeddings

    async def generate_single_embedding(self, text: str) -> list[float]:
        """
        Embed a single text; a convenience wrapper around generate_embeddings.

        Raises:
            ValueError: No embedding was produced.
            httpx.HTTPError: The underlying request failed.
        """
        results = await self.generate_embeddings([text])
        if not results:
            raise ValueError("Ollama returned empty result")
        return results[0]

    # ------------------------------------------------------------------
    # Reranking (via embedding-service)
    # ------------------------------------------------------------------

    async def rerank_documents(
        self,
        query: str,
        documents: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """
        Ask the embedding service to re-rank documents for a given query.

        Args:
            query: The search query.
            documents: Candidate document texts to rank.
            top_k: Passed through to the service as ``top_k``.

        Returns:
            The service's ``results`` list (documented as items shaped
            ``{index, score, text}``), or ``[]`` when the key is absent.

        Raises:
            httpx.HTTPStatusError: The service answered with a non-2xx status.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/rerank"),
                json={
                    "query": query,
                    "documents": documents,
                    "top_k": top_k,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data.get("results", [])

    # ------------------------------------------------------------------
    # Chunking (via embedding-service)
    # ------------------------------------------------------------------

    async def chunk_text(
        self,
        text: str,
        strategy: str = "recursive",
        chunk_size: int = 512,
        overlap: int = 50,
    ) -> list[str]:
        """
        Ask the embedding service to chunk a long text.

        Args:
            text: The text to split.
            strategy: Chunking strategy name, forwarded to the service as-is.
            chunk_size: Target chunk size, forwarded to the service (the
                unit is defined by the service).
            overlap: Overlap between consecutive chunks, forwarded to the
                service.

        Returns:
            The service's ``chunks`` list, or ``[]`` when the key is absent.

        Raises:
            httpx.HTTPStatusError: The service answered with a non-2xx status.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/chunk"),
                json={
                    "text": text,
                    "strategy": strategy,
                    "chunk_size": chunk_size,
                    "overlap": overlap,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data.get("chunks", [])

    # ------------------------------------------------------------------
    # PDF extraction (via embedding-service)
    # ------------------------------------------------------------------

    async def extract_pdf(self, pdf_bytes: bytes) -> str:
        """
        Send raw PDF bytes to the embedding service for text extraction.

        The bytes are uploaded as a multipart ``file`` field named
        ``document.pdf``.

        Returns:
            The service's ``text`` value, or ``""`` when the key is absent.

        Raises:
            httpx.HTTPStatusError: The service answered with a non-2xx status.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/extract-pdf"),
                files={"file": ("document.pdf", pdf_bytes, "application/pdf")},
            )
            response.raise_for_status()
            data = response.json()
            return data.get("text", "")


# Singleton
embedding_client = EmbeddingClient()