"""Embedding client for the RAG service.

Embeddings are generated with Ollama's bge-m3 model (1024-dim), matching
the dimension of the existing Qdrant collections and solving the earlier
dimension mismatch. The separate embedding-service is still used for
chunking, reranking, and PDF text extraction.
"""
import logging
|
|
import os
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
|
|
from config import settings
|
|
|
|
# Namespaced under the service's logger hierarchy.
logger = logging.getLogger("rag-service.embedding")

# Default timeout for embedding-service calls (chunk / rerank / PDF extraction).
_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)
# Longer total timeout for embedding generation, which can be slow on large
# uploads; the connect timeout stays short to fail fast when Ollama is down.
_EMBED_TIMEOUT = httpx.Timeout(timeout=300.0, connect=10.0)

# Ollama config for embeddings (bge-m3, 1024-dim)
_OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
_OLLAMA_EMBED_MODEL = os.getenv("OLLAMA_EMBED_MODEL", "bge-m3")

# Batch size for Ollama embedding requests
_EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
|
|
|
|
|
|
class EmbeddingClient:
    """
    Hybrid client:

    - Embeddings via Ollama (bge-m3, 1024-dim) for Qdrant compatibility
    - Chunking + PDF extraction via embedding-service (port 8087)
    """

    def __init__(self) -> None:
        # Base URLs are normalized once so path joins never double a slash.
        self._embed_svc_url: str = settings.EMBEDDING_SERVICE_URL.rstrip("/")
        self._ollama_url: str = _OLLAMA_URL.rstrip("/")
        self._embed_model: str = _OLLAMA_EMBED_MODEL

    def _svc_url(self, path: str) -> str:
        """Return the full embedding-service URL for *path* (starts with '/')."""
        return f"{self._embed_svc_url}{path}"

    # ------------------------------------------------------------------
    # Embeddings (via Ollama)
    # ------------------------------------------------------------------

    async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings via Ollama's bge-m3 model.

        Processes in batches of ``_EMBED_BATCH_SIZE`` so progress can be
        logged on large uploads. A single HTTP client is reused for every
        request so the connection pool survives across batches (the previous
        implementation opened and closed a client per batch).

        Args:
            texts: Texts to embed; an empty list yields an empty result.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            ValueError: If Ollama returns an empty embedding for any text.
            httpx.HTTPStatusError: If Ollama responds with an error status.
        """
        all_embeddings: list[list[float]] = []

        # One client for the whole run: avoids tearing down and re-opening
        # the connection pool on every batch.
        async with httpx.AsyncClient(timeout=_EMBED_TIMEOUT) as client:
            for i in range(0, len(texts), _EMBED_BATCH_SIZE):
                batch = texts[i : i + _EMBED_BATCH_SIZE]
                batch_embeddings: list[list[float]] = []

                # Ollama's /api/embeddings endpoint takes one prompt per call.
                for text in batch:
                    response = await client.post(
                        f"{self._ollama_url}/api/embeddings",
                        json={
                            "model": self._embed_model,
                            "prompt": text,
                        },
                    )
                    response.raise_for_status()
                    data = response.json()
                    embedding = data.get("embedding", [])
                    if not embedding:
                        raise ValueError(
                            f"Ollama returned empty embedding for model {self._embed_model}"
                        )
                    batch_embeddings.append(embedding)

                all_embeddings.extend(batch_embeddings)

                # Only log while more batches remain; the final batch's
                # completion is implied by the method returning.
                if i + _EMBED_BATCH_SIZE < len(texts):
                    logger.info(
                        "Embedding progress: %d/%d", len(all_embeddings), len(texts)
                    )

        return all_embeddings

    async def generate_single_embedding(self, text: str) -> list[float]:
        """Convenience wrapper for a single text.

        Raises:
            ValueError: If Ollama returns no embedding for *text*.
        """
        results = await self.generate_embeddings([text])
        if not results:
            raise ValueError("Ollama returned empty result")
        return results[0]

    # ------------------------------------------------------------------
    # Reranking (via embedding-service)
    # ------------------------------------------------------------------

    async def rerank_documents(
        self,
        query: str,
        documents: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """
        Ask the embedding service to re-rank documents for a given query.

        Args:
            query: The search query to score documents against.
            documents: Candidate document texts.
            top_k: Maximum number of results to return (default 10).

        Returns:
            A list of {index, score, text} dicts; empty if the service
            returns no results.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/rerank"),
                json={
                    "query": query,
                    "documents": documents,
                    "top_k": top_k,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data.get("results", [])

    # ------------------------------------------------------------------
    # Chunking (via embedding-service)
    # ------------------------------------------------------------------

    async def chunk_text(
        self,
        text: str,
        strategy: str = "recursive",
        chunk_size: int = 512,
        overlap: int = 50,
    ) -> list[str]:
        """
        Ask the embedding service to chunk a long text.

        Args:
            text: The text to split.
            strategy: Chunking strategy name understood by the service
                (default "recursive").
            chunk_size: Target chunk size (default 512).
            overlap: Overlap between consecutive chunks (default 50).

        Returns:
            A list of chunk strings; empty if the service returns none.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/chunk"),
                json={
                    "text": text,
                    "strategy": strategy,
                    "chunk_size": chunk_size,
                    "overlap": overlap,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data.get("chunks", [])

    # ------------------------------------------------------------------
    # PDF extraction (via embedding-service)
    # ------------------------------------------------------------------

    async def extract_pdf(self, pdf_bytes: bytes) -> str:
        """
        Send raw PDF bytes to the embedding service for text extraction.

        Args:
            pdf_bytes: The raw contents of a PDF file.

        Returns:
            The extracted text, or "" if the service returns none.

        Raises:
            httpx.HTTPStatusError: If the service responds with an error.
        """
        async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
            response = await client.post(
                self._svc_url("/extract-pdf"),
                # Multipart upload; filename is cosmetic for the service.
                files={"file": ("document.pdf", pdf_bytes, "application/pdf")},
            )
            response.raise_for_status()
            data = response.json()
            return data.get("text", "")
|
|
|
|
|
|
# Module-level singleton shared by the rest of the service; constructed at
# import time from settings/env so all callers use one configured client.
embedding_client = EmbeddingClient()
|