Compare commits
4 Commits
0ac23089f4
...
5c8307f58a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c8307f58a | ||
|
|
92ca5b7ba5 | ||
|
|
d7cc6bfbc7 | ||
|
|
13ba1457b0 |
@@ -385,8 +385,12 @@ services:
|
|||||||
MINIO_BUCKET: ${MINIO_BUCKET:-breakpilot-rag}
|
MINIO_BUCKET: ${MINIO_BUCKET:-breakpilot-rag}
|
||||||
MINIO_SECURE: "false"
|
MINIO_SECURE: "false"
|
||||||
EMBEDDING_SERVICE_URL: http://embedding-service:8087
|
EMBEDDING_SERVICE_URL: http://embedding-service:8087
|
||||||
|
OLLAMA_URL: ${OLLAMA_URL:-http://host.docker.internal:11434}
|
||||||
|
OLLAMA_EMBED_MODEL: ${OLLAMA_EMBED_MODEL:-bge-m3}
|
||||||
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
|
JWT_SECRET: ${JWT_SECRET:-your-super-secret-jwt-key-change-in-production}
|
||||||
ENVIRONMENT: ${ENVIRONMENT:-development}
|
ENVIRONMENT: ${ENVIRONMENT:-development}
|
||||||
|
extra_hosts:
|
||||||
|
- "host.docker.internal:host-gateway"
|
||||||
depends_on:
|
depends_on:
|
||||||
qdrant:
|
qdrant:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -414,7 +418,7 @@ services:
|
|||||||
- embedding_models:/root/.cache/huggingface
|
- embedding_models:/root/.cache/huggingface
|
||||||
environment:
|
environment:
|
||||||
EMBEDDING_BACKEND: ${EMBEDDING_BACKEND:-local}
|
EMBEDDING_BACKEND: ${EMBEDDING_BACKEND:-local}
|
||||||
LOCAL_EMBEDDING_MODEL: ${LOCAL_EMBEDDING_MODEL:-sentence-transformers/all-MiniLM-L6-v2}
|
LOCAL_EMBEDDING_MODEL: ${LOCAL_EMBEDDING_MODEL:-BAAI/bge-m3}
|
||||||
LOCAL_RERANKER_MODEL: ${LOCAL_RERANKER_MODEL:-cross-encoder/ms-marco-MiniLM-L-6-v2}
|
LOCAL_RERANKER_MODEL: ${LOCAL_RERANKER_MODEL:-cross-encoder/ms-marco-MiniLM-L-6-v2}
|
||||||
PDF_EXTRACTION_BACKEND: ${PDF_EXTRACTION_BACKEND:-pymupdf}
|
PDF_EXTRACTION_BACKEND: ${PDF_EXTRACTION_BACKEND:-pymupdf}
|
||||||
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
OPENAI_API_KEY: ${OPENAI_API_KEY:-}
|
||||||
@@ -423,7 +427,7 @@ services:
|
|||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
limits:
|
limits:
|
||||||
memory: 4G
|
memory: 8G
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c", "import httpx; r=httpx.get('http://127.0.0.1:8087/health'); r.raise_for_status()"]
|
test: ["CMD", "python", "-c", "import httpx; r=httpx.get('http://127.0.0.1:8087/health'); r.raise_for_status()"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -8,44 +9,82 @@ from config import settings
|
|||||||
logger = logging.getLogger("rag-service.embedding")
|
logger = logging.getLogger("rag-service.embedding")
|
||||||
|
|
||||||
_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)
|
_TIMEOUT = httpx.Timeout(timeout=120.0, connect=10.0)
|
||||||
|
_EMBED_TIMEOUT = httpx.Timeout(timeout=300.0, connect=10.0)
|
||||||
|
|
||||||
|
# Ollama config for embeddings (bge-m3, 1024-dim)
|
||||||
|
_OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11434")
|
||||||
|
_OLLAMA_EMBED_MODEL = os.getenv("OLLAMA_EMBED_MODEL", "bge-m3")
|
||||||
|
|
||||||
|
# Batch size for Ollama embedding requests
|
||||||
|
_EMBED_BATCH_SIZE = int(os.getenv("EMBED_BATCH_SIZE", "32"))
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingClient:
|
class EmbeddingClient:
|
||||||
"""HTTP client for the embedding-service (port 8087)."""
|
"""
|
||||||
|
Hybrid client:
|
||||||
|
- Embeddings via Ollama (bge-m3, 1024-dim) for Qdrant compatibility
|
||||||
|
- Chunking + PDF extraction via embedding-service (port 8087)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._base_url: str = settings.EMBEDDING_SERVICE_URL.rstrip("/")
|
self._embed_svc_url: str = settings.EMBEDDING_SERVICE_URL.rstrip("/")
|
||||||
|
self._ollama_url: str = _OLLAMA_URL.rstrip("/")
|
||||||
|
self._embed_model: str = _OLLAMA_EMBED_MODEL
|
||||||
|
|
||||||
def _url(self, path: str) -> str:
|
def _svc_url(self, path: str) -> str:
|
||||||
return f"{self._base_url}{path}"
|
return f"{self._embed_svc_url}{path}"
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Embeddings
|
# Embeddings (via Ollama)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Send a batch of texts to the embedding service and return a list of
|
Generate embeddings via Ollama's bge-m3 model.
|
||||||
embedding vectors.
|
Processes in batches to avoid timeout on large uploads.
|
||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
all_embeddings: list[list[float]] = []
|
||||||
response = await client.post(
|
|
||||||
self._url("/api/v1/embeddings"),
|
for i in range(0, len(texts), _EMBED_BATCH_SIZE):
|
||||||
json={"texts": texts},
|
batch = texts[i : i + _EMBED_BATCH_SIZE]
|
||||||
)
|
batch_embeddings = []
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
async with httpx.AsyncClient(timeout=_EMBED_TIMEOUT) as client:
|
||||||
return data.get("embeddings", [])
|
for text in batch:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self._ollama_url}/api/embeddings",
|
||||||
|
json={
|
||||||
|
"model": self._embed_model,
|
||||||
|
"prompt": text,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
embedding = data.get("embedding", [])
|
||||||
|
if not embedding:
|
||||||
|
raise ValueError(
|
||||||
|
f"Ollama returned empty embedding for model {self._embed_model}"
|
||||||
|
)
|
||||||
|
batch_embeddings.append(embedding)
|
||||||
|
|
||||||
|
all_embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
|
if i + _EMBED_BATCH_SIZE < len(texts):
|
||||||
|
logger.info(
|
||||||
|
"Embedding progress: %d/%d", len(all_embeddings), len(texts)
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_embeddings
|
||||||
|
|
||||||
async def generate_single_embedding(self, text: str) -> list[float]:
|
async def generate_single_embedding(self, text: str) -> list[float]:
|
||||||
"""Convenience wrapper for a single text."""
|
"""Convenience wrapper for a single text."""
|
||||||
results = await self.generate_embeddings([text])
|
results = await self.generate_embeddings([text])
|
||||||
if not results:
|
if not results:
|
||||||
raise ValueError("Embedding service returned empty result")
|
raise ValueError("Ollama returned empty result")
|
||||||
return results[0]
|
return results[0]
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Reranking
|
# Reranking (via embedding-service)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def rerank_documents(
|
async def rerank_documents(
|
||||||
@@ -60,7 +99,7 @@ class EmbeddingClient:
|
|||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self._url("/api/v1/rerank"),
|
self._svc_url("/rerank"),
|
||||||
json={
|
json={
|
||||||
"query": query,
|
"query": query,
|
||||||
"documents": documents,
|
"documents": documents,
|
||||||
@@ -72,7 +111,7 @@ class EmbeddingClient:
|
|||||||
return data.get("results", [])
|
return data.get("results", [])
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Chunking
|
# Chunking (via embedding-service)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def chunk_text(
|
async def chunk_text(
|
||||||
@@ -88,7 +127,7 @@ class EmbeddingClient:
|
|||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self._url("/api/v1/chunk"),
|
self._svc_url("/chunk"),
|
||||||
json={
|
json={
|
||||||
"text": text,
|
"text": text,
|
||||||
"strategy": strategy,
|
"strategy": strategy,
|
||||||
@@ -101,7 +140,7 @@ class EmbeddingClient:
|
|||||||
return data.get("chunks", [])
|
return data.get("chunks", [])
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# PDF extraction
|
# PDF extraction (via embedding-service)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def extract_pdf(self, pdf_bytes: bytes) -> str:
|
async def extract_pdf(self, pdf_bytes: bytes) -> str:
|
||||||
@@ -111,7 +150,7 @@ class EmbeddingClient:
|
|||||||
"""
|
"""
|
||||||
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
async with httpx.AsyncClient(timeout=_TIMEOUT) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
self._url("/api/v1/extract-pdf"),
|
self._svc_url("/extract-pdf"),
|
||||||
files={"file": ("document.pdf", pdf_bytes, "application/pdf")},
|
files={"file": ("document.pdf", pdf_bytes, "application/pdf")},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|||||||
@@ -167,12 +167,13 @@ class QdrantClientWrapper:
|
|||||||
)
|
)
|
||||||
qdrant_filter = qmodels.Filter(must=must_conditions)
|
qdrant_filter = qmodels.Filter(must=must_conditions)
|
||||||
|
|
||||||
results = self.client.search(
|
results = self.client.query_points(
|
||||||
collection_name=collection,
|
collection_name=collection,
|
||||||
query_vector=query_vector,
|
query=query_vector,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
query_filter=qdrant_filter,
|
query_filter=qdrant_filter,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
with_payload=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
@@ -181,7 +182,7 @@ class QdrantClientWrapper:
|
|||||||
"score": hit.score,
|
"score": hit.score,
|
||||||
"payload": hit.payload or {},
|
"payload": hit.payload or {},
|
||||||
}
|
}
|
||||||
for hit in results
|
for hit in results.points
|
||||||
]
|
]
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user