feat(klausur-service): Add Tesseract OCR, DSFA RAG, TrOCR, grid detection and vocab session store

New modules:
- tesseract_vocab_extractor.py: Bounding-box OCR with multi-PSM pipeline
- grid_detection_service.py: CV-based grid/table detection for worksheets
- vocab_session_store.py: PostgreSQL persistence for vocab sessions
- trocr_api.py: TrOCR handwriting recognition endpoint
- dsfa_rag_api.py + dsfa_corpus_ingestion.py: DSFA RAG corpus search

Changes:
- Dockerfile: Install tesseract-ocr + deu/eng language packs
- requirements.txt: Add PyMuPDF, pytesseract, Pillow
- main.py: Register new routers, init DB pools + Qdrant collections

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
BreakPilot Dev
2026-02-10 00:00:19 +01:00
parent 46cb873190
commit 53219e3eaf
9 changed files with 3829 additions and 4 deletions

View File

@@ -13,9 +13,12 @@ FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
# Install system dependencies (incl. Tesseract OCR for bounding-box extraction)
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
tesseract-ocr \
tesseract-ocr-deu \
tesseract-ocr-eng \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,715 @@
"""
DSFA RAG API Endpoints.
Provides REST API for searching DSFA corpus with full source attribution.
Endpoints:
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
- GET /api/v1/dsfa-rag/sources - List all registered sources
- POST /api/v1/dsfa-rag/sources/{code}/ingest - Trigger source ingestion
- GET /api/v1/dsfa-rag/chunks/{id} - Get single chunk with attribution
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
"""
import os
import uuid
import logging
from typing import List, Optional
from dataclasses import dataclass, asdict
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# Embedding service configuration
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
# Import from ingestion module
from dsfa_corpus_ingestion import (
DSFACorpusStore,
DSFAQdrantService,
DSFASearchResult,
LICENSE_REGISTRY,
DSFA_SOURCES,
generate_attribution_notice,
get_license_label,
DSFA_COLLECTION,
chunk_document
)
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
# =============================================================================
# Pydantic Models
# =============================================================================
class DSFASourceResponse(BaseModel):
    """Response model for a registered DSFA source.

    Mirrors a row from the source registry; license_name/license_url are
    resolved from LICENSE_REGISTRY at response time, not stored per source.
    """
    id: str  # registry row id, stringified
    source_code: str  # short unique code, used in /sources/{source_code} paths
    name: str
    full_name: Optional[str] = None
    organization: Optional[str] = None
    source_url: Optional[str] = None
    license_code: str
    license_name: str  # resolved from LICENSE_REGISTRY; falls back to license_code
    license_url: Optional[str] = None
    attribution_required: bool
    attribution_text: str
    document_type: Optional[str] = None
    language: str = "de"  # corpus is German-first
class DSFAChunkResponse(BaseModel):
    """Response model for a single chunk with full attribution attached."""
    chunk_id: str
    content: str
    section_title: Optional[str] = None
    page_number: Optional[int] = None
    category: Optional[str] = None
    # Document info
    document_id: str
    document_title: Optional[str] = None
    # Attribution (always included, so callers can render compliant notices)
    source_id: str
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str  # resolved from LICENSE_REGISTRY; falls back to license_code
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None
    document_type: Optional[str] = None
class DSFASearchResultResponse(BaseModel):
    """Response model for one semantic-search hit."""
    chunk_id: str
    content: str
    score: float  # similarity score returned by Qdrant
    # Attribution (denormalized into the Qdrant payload at ingest time)
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str  # resolved from LICENSE_REGISTRY; falls back to license_code
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None
    # Metadata
    document_type: Optional[str] = None
    category: Optional[str] = None
    section_title: Optional[str] = None
    page_number: Optional[int] = None
class DSFASearchResponse(BaseModel):
    """Response model for the /search endpoint."""
    query: str
    results: List[DSFASearchResultResponse]
    total_results: int
    # Aggregated licenses for a compliance footer
    licenses_used: List[str]
    attribution_notice: str  # pre-rendered notice; empty when include_attribution=False
class DSFASourceStatsResponse(BaseModel):
    """Per-source statistics row (used inside the /stats response)."""
    source_id: str
    source_code: str
    name: str
    organization: Optional[str] = None
    license_code: str
    document_type: Optional[str] = None
    document_count: int
    chunk_count: int
    last_indexed_at: Optional[str] = None  # ISO-8601 string; None if never indexed
class DSFACorpusStatsResponse(BaseModel):
    """Response model for whole-corpus statistics (dashboard)."""
    sources: List[DSFASourceStatsResponse]
    total_sources: int
    total_documents: int
    total_chunks: int
    # Qdrant collection health/size
    qdrant_collection: str
    qdrant_points_count: int
    qdrant_status: str
class IngestRequest(BaseModel):
    """Request model for ingestion.

    At least one of document_url / document_text must be supplied; the
    endpoint rejects requests providing neither. When both are set,
    document_text takes precedence and the URL is ignored.
    """
    document_url: Optional[str] = None
    document_text: Optional[str] = None
    title: Optional[str] = None  # defaults to "Document for {source_code}"
class IngestResponse(BaseModel):
    """Response model for ingestion."""
    source_code: str
    document_id: Optional[str] = None
    chunks_created: int  # number of chunks actually indexed in Qdrant
    message: str
class LicenseInfo(BaseModel):
    """License terms as recorded in LICENSE_REGISTRY."""
    code: str
    name: str
    url: Optional[str] = None
    attribution_required: bool
    modification_allowed: bool
    commercial_use: bool
# =============================================================================
# Dependency Injection
# =============================================================================
# Database pool (will be set from main.py)
# Module-level asyncpg pool shared by all request handlers; stays None
# until main.py injects it during application startup.
_db_pool = None

def set_db_pool(pool):
    """Set the database pool for API endpoints (called from main.py startup)."""
    global _db_pool
    _db_pool = pool
async def get_store() -> DSFACorpusStore:
    """FastAPI dependency: a DSFACorpusStore bound to the module pool.

    Responds with HTTP 503 for as long as main.py has not injected the
    pool via set_db_pool().
    """
    pool = _db_pool
    if pool is None:
        raise HTTPException(status_code=503, detail="Database not initialized")
    return DSFACorpusStore(pool)
async def get_qdrant() -> DSFAQdrantService:
    """FastAPI dependency: a fresh Qdrant service instance per request."""
    service = DSFAQdrantService()
    return service
# =============================================================================
# Embedding Service Integration
# =============================================================================
async def get_embedding(text: str) -> List[float]:
    """
    Get an embedding for *text* from the embedding-service.

    Uses the BGE-M3 model which produces 1024-dimensional vectors.
    Falls back to a deterministic hash-based pseudo-embedding when the
    service is unreachable OR returns a malformed payload, so callers
    never see an exception from here.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed-single",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("embedding", [])
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers json.JSONDecodeError raised by response.json():
            # previously a 200 response with a broken body escaped this handler
            # and surfaced as a 500 to the API caller instead of falling back.
            logger.error(f"Embedding service error: {e}")
            # Fallback to hash-based pseudo-embedding for development
            return _generate_fallback_embedding(text)
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """
    Get embeddings for multiple texts in a single embedding-service call.

    Falls back to deterministic hash-based pseudo-embeddings (one per
    input text) when the service is unreachable or returns a malformed
    payload, mirroring get_embedding(), so batch ingestion never fails
    on embedding errors.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("embeddings", [])
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers json.JSONDecodeError from response.json();
            # previously a 200 with a broken body escaped this handler.
            logger.error(f"Embedding service batch error: {e}")
            # Fallback
            return [_generate_fallback_embedding(t) for t in texts]
async def extract_text_from_url(url: str) -> str:
    """
    Extract plain text from a document URL (PDF, HTML, etc.).

    Strategy:
      1. Ask the embedding-service's /extract-pdf endpoint to fetch and
         extract the document.
      2. On any HTTP failure, fall back to fetching the URL directly and,
         if the response is HTML, strip scripts/styles/tags with regexes.

    Returns "" when neither path yields text — callers treat the empty
    string as "extraction failed".
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            # First try to use the embedding-service's extract-pdf endpoint
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/extract-pdf",
                json={"url": url}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("text", "")
        except httpx.HTTPError as e:
            logger.error(f"PDF extraction error for {url}: {e}")
            # Fallback: try to fetch HTML content directly
            try:
                response = await client.get(url, follow_redirects=True)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                if "html" in content_type:
                    # Simple HTML text extraction (regex-based, no parser dependency)
                    import re
                    html = response.text
                    # Remove scripts and styles before stripping tags
                    html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
                    html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
                    # Remove tags
                    text = re.sub(r'<[^>]+>', ' ', html)
                    # Clean whitespace
                    text = re.sub(r'\s+', ' ', text).strip()
                    return text
                else:
                    # Non-HTML fallback content (e.g. binary) is not handled here.
                    return ""
            except Exception as fetch_err:
                logger.error(f"Fallback fetch error for {url}: {fetch_err}")
                return ""
def _generate_fallback_embedding(text: str) -> List[float]:
"""
Generate deterministic pseudo-embedding from text hash.
Used as fallback when embedding service is unavailable.
"""
import hashlib
import struct
hash_bytes = hashlib.sha256(text.encode()).digest()
embedding = []
for i in range(0, min(len(hash_bytes), 128), 4):
val = struct.unpack('f', hash_bytes[i:i+4])[0]
embedding.append(val % 1.0)
# Pad to 1024 dimensions
while len(embedding) < 1024:
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
return embedding[:1024]
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
    query: str = Query(..., min_length=3, description="Search query"),
    source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
    document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
    categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    include_attribution: bool = Query(True, description="Include attribution in results"),
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Search DSFA corpus with full attribution.

    Embeds the query via the embedding-service, runs a filtered vector
    search in Qdrant, and returns each hit together with its source and
    license metadata so callers can render compliant attributions.

    NOTE(review): `store` is injected but unused in this handler; the
    dependency still gates the endpoint on an initialized database pool.
    """
    # Get query embedding
    query_embedding = await get_embedding(query)
    # Search Qdrant
    raw_results = await qdrant.search(
        query_embedding=query_embedding,
        source_codes=source_codes,
        document_types=document_types,
        categories=categories,
        limit=limit
    )
    # Transform raw payloads into response models, collecting distinct
    # license codes for the response footer.
    results = []
    licenses_used = set()
    for r in raw_results:
        license_code = r.get("license_code", "")
        # License display name/URL are resolved locally, not stored per chunk.
        license_info = LICENSE_REGISTRY.get(license_code, {})
        result = DSFASearchResultResponse(
            chunk_id=r.get("chunk_id", ""),
            content=r.get("content", ""),
            score=r.get("score", 0.0),
            source_code=r.get("source_code", ""),
            source_name=r.get("source_name", ""),
            attribution_text=r.get("attribution_text", ""),
            license_code=license_code,
            license_name=license_info.get("name", license_code),
            license_url=license_info.get("url"),
            attribution_required=r.get("attribution_required", True),
            source_url=r.get("source_url"),
            document_type=r.get("document_type"),
            category=r.get("category"),
            section_title=r.get("section_title"),
            page_number=r.get("page_number")
        )
        results.append(result)
        licenses_used.add(license_code)
    # generate_attribution_notice() expects DSFASearchResult dataclasses,
    # so the response models are mapped back into that shape here.
    search_results = [
        DSFASearchResult(
            chunk_id=r.chunk_id,
            content=r.content,
            score=r.score,
            source_code=r.source_code,
            source_name=r.source_name,
            attribution_text=r.attribution_text,
            license_code=r.license_code,
            license_url=r.license_url,
            attribution_required=r.attribution_required,
            source_url=r.source_url,
            document_type=r.document_type or "",
            category=r.category or "",
            section_title=r.section_title,
            page_number=r.page_number
        )
        for r in results
    ]
    attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
    return DSFASearchResponse(
        query=query,
        results=results,
        total_results=len(results),
        licenses_used=list(licenses_used),
        attribution_notice=attribution_notice
    )
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
    document_type: Optional[str] = Query(None, description="Filter by document type"),
    license_code: Optional[str] = Query(None, description="Filter by license"),
    store: DSFACorpusStore = Depends(get_store)
):
    """List all registered DSFA sources, with license info resolved from
    LICENSE_REGISTRY. Both query filters are optional and combined with AND.
    """
    sources = await store.list_sources()
    responses: List[DSFASourceResponse] = []
    for src in sources:
        # Apply the optional filters (a falsy filter value means "no filter").
        type_mismatch = bool(document_type) and src.get("document_type") != document_type
        license_mismatch = bool(license_code) and src.get("license_code") != license_code
        if type_mismatch or license_mismatch:
            continue
        lic = LICENSE_REGISTRY.get(src.get("license_code", ""), {})
        responses.append(DSFASourceResponse(
            id=str(src["id"]),
            source_code=src["source_code"],
            name=src["name"],
            full_name=src.get("full_name"),
            organization=src.get("organization"),
            source_url=src.get("source_url"),
            license_code=src.get("license_code", ""),
            license_name=lic.get("name", src.get("license_code", "")),
            license_url=lic.get("url"),
            attribution_required=src.get("attribution_required", True),
            attribution_text=src.get("attribution_text", ""),
            document_type=src.get("document_type"),
            language=src.get("language", "de")
        ))
    return responses
@router.get("/sources/available")
async def list_available_sources():
    """List every source definition shipped with the service (the static
    DSFA_SOURCES catalogue), independent of what is registered in the DB."""
    catalogue = []
    for entry in DSFA_SOURCES:
        catalogue.append({
            "source_code": entry["source_code"],
            "name": entry["name"],
            "organization": entry.get("organization"),
            "license_code": entry["license_code"],
            "document_type": entry.get("document_type")
        })
    return catalogue
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get details for a specific source.

    Raises 404 for unknown source codes. License name/URL are resolved
    from LICENSE_REGISTRY rather than stored on the source row.
    """
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
    license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
    return DSFASourceResponse(
        id=str(source["id"]),
        source_code=source["source_code"],
        name=source["name"],
        full_name=source.get("full_name"),
        organization=source.get("organization"),
        source_url=source.get("source_url"),
        license_code=source.get("license_code", ""),
        license_name=license_info.get("name", source.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=source.get("attribution_required", True),
        attribution_text=source.get("attribution_text", ""),
        document_type=source.get("document_type"),
        language=source.get("language", "de")
    )
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Trigger ingestion for a specific source.

    Pipeline: resolve source -> obtain text (direct or URL extraction) ->
    create document row -> chunk -> batch-embed -> persist chunks in
    PostgreSQL -> index vectors in Qdrant -> mark document as indexed.

    Raises 404 for unknown sources and 400 when no usable text (>= 50
    chars) can be obtained.
    """
    # Get source
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
    # Need either URL or text
    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )
    # Ensure Qdrant collection exists
    await qdrant.ensure_collection()
    # Get text content; direct text takes precedence over the URL.
    text_content = request.document_text
    if request.document_url and not text_content:
        # Download and extract text from URL
        logger.info(f"Extracting text from URL: {request.document_url}")
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )
    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
    # Create document record
    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=str(source["id"]),
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )
    # Chunk the document
    chunks = chunk_document(text_content, source_code)
    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )
    # Generate embeddings in batch for efficiency
    chunk_texts = [chunk["content"] for chunk in chunks]
    logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
    embeddings = await get_embeddings_batch(chunk_texts)
    # Create chunk records in PostgreSQL and prepare payloads for Qdrant.
    # NOTE(review): the per-iteration `embedding` variable is unused here;
    # vectors are handed to qdrant.index_chunks() as the whole list below.
    chunk_records = []
    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
        # Create chunk in PostgreSQL (source of truth for chunk ids)
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=str(source["id"]),
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )
        # Attribution fields are denormalized into each Qdrant payload so
        # search results carry license info without a DB join.
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": str(source["id"]),
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source["attribution_text"],
            "license_code": source["license_code"],
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })
    # Index in Qdrant
    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
    # Update document record
    await store.update_document_indexed(document_id, len(chunks))
    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get single chunk with full source attribution.

    Raises 404 when the chunk id is unknown. License name/URL are
    resolved from LICENSE_REGISTRY, not stored per chunk.
    """
    chunk = await store.get_chunk_with_attribution(chunk_id)
    if not chunk:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
    license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
    return DSFAChunkResponse(
        chunk_id=str(chunk["chunk_id"]),
        content=chunk.get("content", ""),
        section_title=chunk.get("section_title"),
        page_number=chunk.get("page_number"),
        category=chunk.get("category"),
        document_id=str(chunk.get("document_id", "")),
        document_title=chunk.get("document_title"),
        source_id=str(chunk.get("source_id", "")),
        source_code=chunk.get("source_code", ""),
        source_name=chunk.get("source_name", ""),
        attribution_text=chunk.get("attribution_text", ""),
        license_code=chunk.get("license_code", ""),
        license_name=license_info.get("name", chunk.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=chunk.get("attribution_required", True),
        source_url=chunk.get("source_url"),
        document_type=chunk.get("document_type")
    )
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Get corpus statistics for dashboard.

    Combines per-source document/chunk counts from PostgreSQL with the
    collection size and status reported by Qdrant.
    """
    # Get PostgreSQL stats
    source_stats = await store.get_source_stats()
    total_docs = 0
    total_chunks = 0
    stats_response = []
    for s in source_stats:
        # Coerce possible None counts (e.g. from SQL aggregates) to 0.
        doc_count = s.get("document_count", 0) or 0
        chunk_count = s.get("chunk_count", 0) or 0
        total_docs += doc_count
        total_chunks += chunk_count
        last_indexed = s.get("last_indexed_at")
        stats_response.append(DSFASourceStatsResponse(
            source_id=str(s.get("source_id", "")),
            source_code=s.get("source_code", ""),
            name=s.get("name", ""),
            organization=s.get("organization"),
            license_code=s.get("license_code", ""),
            document_type=s.get("document_type"),
            document_count=doc_count,
            chunk_count=chunk_count,
            # Serialize datetime to ISO-8601; None when never indexed.
            last_indexed_at=last_indexed.isoformat() if last_indexed else None
        ))
    # Get Qdrant stats
    qdrant_stats = await qdrant.get_stats()
    return DSFACorpusStatsResponse(
        sources=stats_response,
        total_sources=len(source_stats),
        total_documents=total_docs,
        total_chunks=total_chunks,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=qdrant_stats.get("points_count", 0),
        qdrant_status=qdrant_stats.get("status", "unknown")
    )
@router.get("/licenses")
async def list_licenses():
    """List all supported licenses with their terms (from LICENSE_REGISTRY)."""
    licenses = []
    for code, info in LICENSE_REGISTRY.items():
        licenses.append(LicenseInfo(
            code=code,
            name=info.get("name", code),
            url=info.get("url"),
            attribution_required=info.get("attribution_required", True),
            modification_allowed=info.get("modification_allowed", True),
            commercial_use=info.get("commercial_use", True)
        ))
    return licenses
@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Creates the Qdrant collection.
    - Registers all predefined sources from DSFA_SOURCES; a failure on
      one source is logged and does not abort the remaining ones.

    Returns a summary with the collection status and how many sources
    were registered out of the total.
    """
    # Ensure Qdrant collection exists
    qdrant_ok = await qdrant.ensure_collection()
    # Register all sources
    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
            registered += 1
        except Exception:
            # Fix: was print(); use the module logger (with traceback)
            # like the rest of this module.
            logger.exception(f"Error registering source {source['source_code']}")
    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }

View File

@@ -20,6 +20,7 @@ This is the main entry point. All functionality is organized in modular packages
import os
from contextlib import asynccontextmanager
import asyncpg
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
@@ -36,7 +37,19 @@ from admin_api import router as admin_router
from zeugnis_api import router as zeugnis_router
from training_api import router as training_router
from mail.api import router as mail_router
try:
from trocr_api import router as trocr_router
except ImportError:
trocr_router = None
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
try:
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
except ImportError:
dsfa_rag_router = None
set_dsfa_db_pool = None
DSFAQdrantService = None
DSFA_DATABASE_URL = None
# BYOEH Qdrant initialization
from qdrant_service import init_qdrant_collection
@@ -51,12 +64,42 @@ async def lifespan(app: FastAPI):
"""Application lifespan manager for startup and shutdown events."""
print("Klausur-Service starting...")
# Initialize database pool for Vocab Sessions
vocab_db_pool = None
try:
vocab_db_pool = await asyncpg.create_pool(VOCAB_DATABASE_URL, min_size=2, max_size=5)
set_vocab_db_pool(vocab_db_pool)
await _init_vocab_table()
await _load_all_sessions()
print(f"Vocab sessions database initialized")
except Exception as e:
print(f"Warning: Vocab sessions database initialization failed: {e}")
# Initialize database pool for DSFA RAG
dsfa_db_pool = None
if DSFA_DATABASE_URL and set_dsfa_db_pool:
try:
dsfa_db_pool = await asyncpg.create_pool(DSFA_DATABASE_URL, min_size=2, max_size=10)
set_dsfa_db_pool(dsfa_db_pool)
print(f"DSFA database pool initialized: {DSFA_DATABASE_URL}")
except Exception as e:
print(f"Warning: DSFA database pool initialization failed: {e}")
# Initialize Qdrant collection for BYOEH
try:
await init_qdrant_collection()
print("Qdrant BYOEH collection initialized")
except Exception as e:
print(f"Warning: Qdrant initialization failed: {e}")
print(f"Warning: Qdrant BYOEH initialization failed: {e}")
# Initialize Qdrant collection for DSFA RAG
if DSFAQdrantService:
try:
dsfa_qdrant = DSFAQdrantService()
await dsfa_qdrant.ensure_collection()
print("Qdrant DSFA corpus collection initialized")
except Exception as e:
print(f"Warning: Qdrant DSFA initialization failed: {e}")
# Ensure EH upload directory exists
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
@@ -65,6 +108,16 @@ async def lifespan(app: FastAPI):
print("Klausur-Service shutting down...")
# Close Vocab sessions database pool
if vocab_db_pool:
await vocab_db_pool.close()
print("Vocab sessions database pool closed")
# Close DSFA database pool
if dsfa_db_pool:
await dsfa_db_pool.close()
print("DSFA database pool closed")
app = FastAPI(
title="Klausur-Service",
@@ -94,7 +147,11 @@ app.include_router(admin_router) # NiBiS Ingestion
app.include_router(zeugnis_router) # Zeugnis Rights-Aware Crawler
app.include_router(training_router) # Training Management
app.include_router(mail_router) # Unified Inbox Mail
if trocr_router:
app.include_router(trocr_router) # TrOCR Handwriting OCR
app.include_router(vocab_router) # Vocabulary Worksheet Generator
if dsfa_rag_router:
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search
# =============================================

View File

@@ -9,6 +9,7 @@ python-dotenv>=1.0.0
qdrant-client>=1.7.0
cryptography>=41.0.0
PyPDF2>=3.0.0
PyMuPDF>=1.24.0
# PyTorch CPU-only (smaller, no CUDA needed for Docker on Mac)
--extra-index-url https://download.pytorch.org/whl/cpu
@@ -23,6 +24,10 @@ minio>=7.2.0
# OpenCV for handwriting detection (headless = no GUI, smaller for CI)
opencv-python-headless>=4.8.0
# Tesseract OCR Python binding (requires system tesseract-ocr package)
pytesseract>=0.3.10
Pillow>=10.0.0
# PostgreSQL (for metrics storage)
psycopg2-binary>=2.9.0
asyncpg>=0.29.0

View File

@@ -0,0 +1,509 @@
"""
Grid Detection Service v4
Detects table/grid structure from OCR bounding-box data.
Converts pixel coordinates to percentage and mm coordinates (A4 format).
Supports deskew correction, column detection, and multi-line cell handling.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import math
import logging
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
logger = logging.getLogger(__name__)
# A4 dimensions
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0
# Column margin (1mm)
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
class CellStatus(str, Enum):
    """Recognition status of a grid cell (the value is serialized in to_dict)."""
    EMPTY = "empty"
    RECOGNIZED = "recognized"
    PROBLEMATIC = "problematic"
    MANUAL = "manual"  # presumably set when a user corrects a cell by hand — confirm against callers
class ColumnType(str, Enum):
    """Semantic meaning of a worksheet column (english / german / example)."""
    ENGLISH = "english"
    GERMAN = "german"
    EXAMPLE = "example"
    UNKNOWN = "unknown"  # fallback for columns beyond the three known ones
@dataclass
class OCRRegion:
    """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100)."""
    text: str
    confidence: float  # OCR confidence score for this region
    x: float  # X position as percentage of page width
    y: float  # Y position as percentage of page height
    width: float  # Width as percentage of page width
    height: float  # Height as percentage of page height

    # Millimetre conversions assume an A4 page (210 x 297 mm).
    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # NOTE(review): rounded to 2 decimals, unlike the 1-decimal
        # x/y/width conversions — confirm the asymmetry is intentional.
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    # Derived edges/centers, still in percent space.
    @property
    def center_x(self) -> float:
        return self.x + self.width / 2

    @property
    def center_y(self) -> float:
        return self.y + self.height / 2

    @property
    def right(self) -> float:
        return self.x + self.width

    @property
    def bottom(self) -> float:
        return self.y + self.height
@dataclass
class GridCell:
    """A cell in the detected grid with coordinates in percentage (0-100)."""
    row: int  # physical row index in the detected grid
    col: int  # physical column index
    x: float
    y: float
    width: float
    height: float
    text: str = ""
    confidence: float = 0.0
    status: CellStatus = CellStatus.EMPTY
    column_type: ColumnType = ColumnType.UNKNOWN
    # Logical position for multi-line cell handling (see module docstring);
    # is_continuation marks a physical cell that continues the one above it.
    logical_row: int = 0
    logical_col: int = 0
    is_continuation: bool = False

    # Millimetre conversions assume an A4 page (210 x 297 mm).
    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # NOTE(review): 2 decimals here vs 1 for x/y/width — confirm intentional.
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    def to_dict(self) -> dict:
        """Serialize the cell to a JSON-friendly dict (percent + mm coordinates)."""
        return {
            "row": self.row,
            "col": self.col,
            "x": round(self.x, 2),
            "y": round(self.y, 2),
            "width": round(self.width, 2),
            "height": round(self.height, 2),
            "x_mm": self.x_mm,
            "y_mm": self.y_mm,
            "width_mm": self.width_mm,
            "height_mm": self.height_mm,
            "text": self.text,
            "confidence": self.confidence,
            "status": self.status.value,
            "column_type": self.column_type.value,
            "logical_row": self.logical_row,
            "logical_col": self.logical_col,
            "is_continuation": self.is_continuation,
        }
@dataclass
class GridResult:
    """Result of grid detection."""
    rows: int = 0
    columns: int = 0
    cells: List[List[GridCell]] = field(default_factory=list)  # row-major
    column_types: List[str] = field(default_factory=list)  # ColumnType values per column
    column_boundaries: List[float] = field(default_factory=list)  # column start X, percent
    row_boundaries: List[float] = field(default_factory=list)  # row start Y, percent
    deskew_angle: float = 0.0  # degrees, from calculate_deskew_angle
    stats: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict, including A4 page dimensions."""
        cells_dicts = []
        for row_cells in self.cells:
            cells_dicts.append([c.to_dict() for c in row_cells])
        return {
            "rows": self.rows,
            "columns": self.columns,
            "cells": cells_dicts,
            "column_types": self.column_types,
            "column_boundaries": [round(b, 2) for b in self.column_boundaries],
            "row_boundaries": [round(b, 2) for b in self.row_boundaries],
            "deskew_angle": round(self.deskew_angle, 2),
            "stats": self.stats,
            "page_dimensions": {
                "width_mm": A4_WIDTH_MM,
                "height_mm": A4_HEIGHT_MM,
                "format": "A4",
            },
        }
class GridDetectionService:
"""Detect grid/table structure from OCR bounding-box regions."""
    def __init__(self, y_tolerance_pct: float = 1.5, padding_pct: float = 0.3,
                 column_margin_mm: float = COLUMN_MARGIN_MM):
        # y_tolerance_pct: max vertical center distance (% of page height)
        #   for two regions to be grouped into the same row.
        # padding_pct: padding (% of page width) subtracted to the left of
        #   each detected column boundary.
        # column_margin_mm: margin between columns in millimetres.
        self.y_tolerance_pct = y_tolerance_pct
        self.padding_pct = padding_pct
        self.column_margin_mm = column_margin_mm
def calculate_deskew_angle(self, regions: List[OCRRegion]) -> float:
"""Calculate page skew angle from OCR region positions.
Uses left-edge alignment of regions to detect consistent tilt.
Returns angle in degrees, clamped to ±5°.
"""
if len(regions) < 3:
return 0.0
# Group by similar X position (same column)
sorted_by_x = sorted(regions, key=lambda r: r.x)
# Find regions that are vertically aligned (similar X)
x_tolerance = 3.0 # percent
aligned_groups: List[List[OCRRegion]] = []
current_group = [sorted_by_x[0]]
for r in sorted_by_x[1:]:
if abs(r.x - current_group[0].x) <= x_tolerance:
current_group.append(r)
else:
if len(current_group) >= 3:
aligned_groups.append(current_group)
current_group = [r]
if len(current_group) >= 3:
aligned_groups.append(current_group)
if not aligned_groups:
return 0.0
# Use the largest aligned group to calculate skew
best_group = max(aligned_groups, key=len)
best_group.sort(key=lambda r: r.y)
# Linear regression: X as function of Y
n = len(best_group)
sum_y = sum(r.y for r in best_group)
sum_x = sum(r.x for r in best_group)
sum_xy = sum(r.x * r.y for r in best_group)
sum_y2 = sum(r.y ** 2 for r in best_group)
denom = n * sum_y2 - sum_y ** 2
if denom == 0:
return 0.0
slope = (n * sum_xy - sum_y * sum_x) / denom
# Convert slope to angle (slope is dx/dy in percent space)
# Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
aspect = A4_WIDTH_MM / A4_HEIGHT_MM
angle_rad = math.atan(slope * aspect)
angle_deg = math.degrees(angle_rad)
# Clamp to ±5°
return max(-5.0, min(5.0, round(angle_deg, 2)))
def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:
"""Apply deskew correction to region coordinates.
Rotates all coordinates around the page center by -angle.
"""
if abs(angle) < 0.01:
return regions
angle_rad = math.radians(-angle)
cos_a = math.cos(angle_rad)
sin_a = math.sin(angle_rad)
# Page center
cx, cy = 50.0, 50.0
result = []
for r in regions:
# Rotate center of region around page center
rx = r.center_x - cx
ry = r.center_y - cy
new_cx = rx * cos_a - ry * sin_a + cx
new_cy = rx * sin_a + ry * cos_a + cy
new_x = new_cx - r.width / 2
new_y = new_cy - r.height / 2
result.append(OCRRegion(
text=r.text,
confidence=r.confidence,
x=round(new_x, 2),
y=round(new_y, 2),
width=r.width,
height=r.height,
))
return result
def _group_regions_into_rows(self, regions: List["OCRRegion"]) -> List[List["OCRRegion"]]:
    """Cluster regions into horizontal rows by vertical proximity.

    A region joins the current row when its centre-Y lies within
    ``self.y_tolerance_pct`` of the row's first region; each finished
    row is sorted left-to-right before being emitted.
    """
    if not regions:
        return []
    ordered = sorted(regions, key=lambda reg: reg.y)
    rows: List[List["OCRRegion"]] = []
    row: List["OCRRegion"] = [ordered[0]]
    anchor_y = ordered[0].center_y
    for reg in ordered[1:]:
        if abs(reg.center_y - anchor_y) <= self.y_tolerance_pct:
            row.append(reg)
            continue
        row.sort(key=lambda item: item.x)
        rows.append(row)
        row = [reg]
        anchor_y = reg.center_y
    # The trailing row always holds at least one region.
    row.sort(key=lambda item: item.x)
    rows.append(row)
    return rows
def _detect_column_boundaries(self, rows: List[List["OCRRegion"]]) -> List[float]:
    """Derive column start positions (percent) from region X values.

    X positions are clustered with a gap heuristic: a jump of more than
    ``min_gap`` percent between consecutive sorted X values starts a new
    column. Each boundary is the cluster's left edge minus padding; the
    page's right edge (100%) closes the final column.
    """
    if not rows:
        return []
    xs = sorted(region.x for row in rows for region in row)
    if not xs:
        return []
    min_gap = 5.0  # percent of page width that separates columns
    clusters: List[List[float]] = [[xs[0]]]
    for x in xs[1:]:
        if x - clusters[-1][-1] > min_gap:
            clusters.append([x])
        else:
            clusters[-1].append(x)
    boundaries = [min(cluster) - self.padding_pct for cluster in clusters]
    boundaries.append(100.0)
    return boundaries
def _assign_column_types(self, boundaries: List[float]) -> List[str]:
    """Map column index to semantic type by position.

    Left-to-right the first three columns are English, German and
    example sentence; any further columns are marked unknown.
    """
    num_cols = max(0, len(boundaries) - 1)
    known = [ColumnType.ENGLISH, ColumnType.GERMAN, ColumnType.EXAMPLE]
    return [
        (known[idx] if idx < len(known) else ColumnType.UNKNOWN).value
        for idx in range(num_cols)
    ]
def detect_grid(self, regions: List[OCRRegion]) -> GridResult:
    """Detect grid structure from OCR regions.

    Pipeline: deskew correction -> row grouping -> column boundary
    detection -> per-cell assembly with status statistics.

    Args:
        regions: List of OCR regions with percentage-based coordinates.

    Returns:
        GridResult with detected rows, columns, and cells.
    """
    if not regions:
        # No input: empty grid with zeroed statistics.
        return GridResult(stats={"recognized": 0, "problematic": 0, "empty": 0, "manual": 0, "total": 0, "coverage": 0.0})
    # Step 1: Calculate and apply deskew
    deskew_angle = self.calculate_deskew_angle(regions)
    corrected_regions = self.apply_deskew_to_regions(regions, deskew_angle)
    # Step 2: Group into rows
    rows = self._group_regions_into_rows(corrected_regions)
    # Step 3: Detect column boundaries
    col_boundaries = self._detect_column_boundaries(rows)
    column_types = self._assign_column_types(col_boundaries)
    # Boundaries delimit columns, so N boundaries -> N-1 columns (min 1).
    num_cols = max(1, len(col_boundaries) - 1)
    # Step 4: Build cell grid
    num_rows = len(rows)
    row_boundaries = []
    cells = []
    recognized = 0
    problematic = 0
    empty = 0
    for row_idx, row_regions in enumerate(rows):
        # Row Y boundary: tight box around the row's regions plus padding.
        if row_regions:
            row_y = min(r.y for r in row_regions) - self.padding_pct
            row_bottom = max(r.bottom for r in row_regions) + self.padding_pct
        else:
            # Fallback: divide the page evenly when a row has no regions.
            row_y = row_idx / num_rows * 100
            row_bottom = (row_idx + 1) / num_rows * 100
        row_boundaries.append(row_y)
        row_height = row_bottom - row_y
        row_cells = []
        for col_idx in range(num_cols):
            col_x = col_boundaries[col_idx]
            col_right = col_boundaries[col_idx + 1] if col_idx + 1 < len(col_boundaries) else 100.0
            col_width = col_right - col_x
            # Find regions in this cell (assignment by horizontal centre).
            cell_regions = []
            for r in row_regions:
                r_center = r.center_x
                if col_x <= r_center < col_right:
                    cell_regions.append(r)
            if cell_regions:
                text = " ".join(r.text for r in cell_regions)
                avg_conf = sum(r.confidence for r in cell_regions) / len(cell_regions)
                # Below 0.5 mean confidence the cell is flagged for review.
                status = CellStatus.RECOGNIZED if avg_conf >= 0.5 else CellStatus.PROBLEMATIC
                # Use actual bounding box from regions, not the grid lines.
                actual_x = min(r.x for r in cell_regions)
                actual_y = min(r.y for r in cell_regions)
                actual_right = max(r.right for r in cell_regions)
                actual_bottom = max(r.bottom for r in cell_regions)
                cell = GridCell(
                    row=row_idx,
                    col=col_idx,
                    x=actual_x,
                    y=actual_y,
                    width=actual_right - actual_x,
                    height=actual_bottom - actual_y,
                    text=text,
                    confidence=round(avg_conf, 3),
                    status=status,
                    column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                    logical_row=row_idx,
                    logical_col=col_idx,
                )
                if status == CellStatus.RECOGNIZED:
                    recognized += 1
                else:
                    problematic += 1
            else:
                # Empty cell: geometry comes from the computed grid lines.
                cell = GridCell(
                    row=row_idx,
                    col=col_idx,
                    x=col_x,
                    y=row_y,
                    width=col_width,
                    height=row_height,
                    status=CellStatus.EMPTY,
                    column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                    logical_row=row_idx,
                    logical_col=col_idx,
                )
                empty += 1
            row_cells.append(cell)
        cells.append(row_cells)
    # Add final row boundary (bottom edge of the last row, or page edge).
    if rows and rows[-1]:
        row_boundaries.append(max(r.bottom for r in rows[-1]) + self.padding_pct)
    else:
        row_boundaries.append(100.0)
    total = num_rows * num_cols
    # Coverage = fraction of cells that contain any text at all.
    coverage = (recognized + problematic) / max(total, 1)
    return GridResult(
        rows=num_rows,
        columns=num_cols,
        cells=cells,
        column_types=column_types,
        column_boundaries=col_boundaries,
        row_boundaries=row_boundaries,
        deskew_angle=deskew_angle,
        stats={
            "recognized": recognized,
            "problematic": problematic,
            "empty": empty,
            "manual": 0,
            "total": total,
            "coverage": round(coverage, 3),
        },
    )
def convert_tesseract_regions(self, tess_words: List[dict],
                              image_width: int, image_height: int) -> List["OCRRegion"]:
    """Translate Tesseract word boxes (pixels) into percent-space OCRRegions.

    Args:
        tess_words: Word list from tesseract_vocab_extractor.extract_bounding_boxes.
        image_width: Source image width in pixels.
        image_height: Source image height in pixels.

    Returns:
        OCRRegions with percentage-based coordinates; empty when there are
        no words or either image dimension is zero.
    """
    if not tess_words or image_width == 0 or image_height == 0:
        return []
    return [
        OCRRegion(
            text=word["text"],
            # Tesseract confidences are 0-100; OCRRegion expects 0.0-1.0.
            confidence=word.get("conf", 50) / 100.0,
            x=word["left"] / image_width * 100,
            y=word["top"] / image_height * 100,
            width=word["width"] / image_width * 100,
            height=word["height"] / image_height * 100,
        )
        for word in tess_words
    ]

View File

@@ -0,0 +1,346 @@
"""
Tesseract-based OCR extraction with word-level bounding boxes.
Uses Tesseract for spatial information (WHERE text is) while
the Vision LLM handles semantic understanding (WHAT the text means).
Tesseract runs natively on ARM64 via Debian's apt package.
Lizenz: Apache 2.0 (kommerziell nutzbar)
"""
import io
import logging
from typing import List, Dict, Any, Optional
from difflib import SequenceMatcher
logger = logging.getLogger(__name__)
try:
import pytesseract
from PIL import Image
TESSERACT_AVAILABLE = True
except ImportError:
TESSERACT_AVAILABLE = False
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Run Tesseract OCR and return word-level bounding boxes.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string (e.g. "eng+deu").

    Returns:
        Dict with a 'words' list (pixel coordinates, confidence, and
        block/paragraph/line/word indices) plus 'image_width'/'image_height'.
        When Tesseract is unavailable, an 'error' key is set instead.
    """
    if not TESSERACT_AVAILABLE:
        return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}
    # Close the PIL image deterministically once the OCR data is extracted.
    with Image.open(io.BytesIO(image_bytes)) as image:
        data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)
        image_width, image_height = image.width, image.height
    words = []
    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        # Depending on the pytesseract/Tesseract version, 'conf' entries may
        # be ints, floats, or numeric strings such as '95.0'; int() alone
        # raises ValueError on the latter, so normalize through float().
        conf = int(float(data['conf'][i]))
        # Skip whitespace-only tokens and low-confidence noise (conf < 20;
        # Tesseract reports -1 for non-word structural entries).
        if not text or conf < 20:
            continue
        words.append({
            "text": text,
            "left": data['left'][i],
            "top": data['top'][i],
            "width": data['width'][i],
            "height": data['height'][i],
            "conf": conf,
            "block_num": data['block_num'][i],
            "par_num": data['par_num'][i],
            "line_num": data['line_num'][i],
            "word_num": data['word_num'][i],
        })
    return {
        "words": words,
        "image_width": image_width,
        "image_height": image_height,
    }
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
    """Cluster OCR words into text lines by vertical position.

    Args:
        words: Word dicts from extract_bounding_boxes.
        y_tolerance_px: Maximum 'top' distance in pixels for two words to
            share a line (measured against the line's first word).

    Returns:
        Lines ordered top-to-bottom, each a list of words sorted by X.
    """
    if not words:
        return []
    ordered = sorted(words, key=lambda w: (w['top'], w['left']))
    lines: List[List[dict]] = []
    line: List[dict] = [ordered[0]]
    anchor_top = ordered[0]['top']
    for word in ordered[1:]:
        if abs(word['top'] - anchor_top) <= y_tolerance_px:
            line.append(word)
            continue
        lines.append(sorted(line, key=lambda w: w['left']))
        line = [word]
        anchor_top = word['top']
    # The trailing line always holds at least one word.
    lines.append(sorted(line, key=lambda w: w['left']))
    return lines
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
    """Infer column boundaries from word X positions via gap clustering.

    Typical vocab table layout: left = English, middle = German,
    right = example sentences.

    Args:
        lines: Word lines from group_words_into_lines.
        image_width: Image width in pixels.

    Returns:
        Dict with 'columns' (start pixel/percent + word count) and
        'column_types' assigned left-to-right.
    """
    if not lines or image_width == 0:
        return {"columns": [], "column_types": []}
    xs = sorted(word['left'] for line in lines for word in line)
    if not xs:
        return {"columns": [], "column_types": []}
    # A horizontal jump of more than 8% of the page width starts a column.
    gap_threshold = image_width * 0.08
    clusters: List[List[int]] = [[xs[0]]]
    for x in xs[1:]:
        if x - clusters[-1][-1] > gap_threshold:
            clusters.append([x])
        else:
            clusters[-1].append(x)
    columns = []
    for cluster in clusters:
        start = min(cluster)
        columns.append({
            "x_start": start,
            "x_start_pct": start / image_width * 100,
            "word_count": len(cluster),
        })
    # Assign semantic types left-to-right: EN, DE, example, then unknown.
    labels = ["english", "german", "example"]
    column_types = [
        labels[idx] if idx < len(labels) else "unknown"
        for idx in range(len(columns))
    ]
    return {
        "columns": columns,
        "column_types": column_types,
    }
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
                           column_types: List[str], image_width: int,
                           image_height: int) -> List[dict]:
    """Assemble vocabulary entries from line-grouped words and column spans.

    Args:
        lines: Grouped word lines from group_words_into_lines.
        columns: Column boundaries from detect_columns.
        column_types: Column type assignments.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        Vocabulary entry dicts with english/german/example fields plus a
        per-column percent bounding box; lines without EN or DE content
        are dropped.
    """
    if not columns or not lines:
        return []
    # Build (start, end, type) spans; the last column runs to the image edge.
    spans = []
    for idx, col in enumerate(columns):
        start = col['x_start']
        end = columns[idx + 1]['x_start'] if idx + 1 < len(columns) else image_width
        col_type = column_types[idx] if idx < len(column_types) else "unknown"
        spans.append((start, end, col_type))
    entries = []
    for line in lines:
        texts: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
        boxes: Dict[str, Optional[dict]] = {}
        for word in line:
            # Assign each word to a column by its horizontal centre point.
            center = word['left'] + word['width'] / 2
            col_type = "unknown"
            for start, end, span_type in spans:
                if start <= center < end:
                    col_type = span_type
                    break
            if col_type not in texts:
                continue
            texts[col_type].append(word['text'])
            # Grow the column's bounding box to cover this word.
            box = boxes.get(col_type)
            if box is None:
                boxes[col_type] = {
                    "left": word['left'],
                    "top": word['top'],
                    "right": word['left'] + word['width'],
                    "bottom": word['top'] + word['height'],
                }
            else:
                box['left'] = min(box['left'], word['left'])
                box['top'] = min(box['top'], word['top'])
                box['right'] = max(box['right'], word['left'] + word['width'])
                box['bottom'] = max(box['bottom'], word['top'] + word['height'])
        entry = {"english": "", "german": "", "example": ""}
        for col_type in ("english", "german", "example"):
            if not texts[col_type]:
                continue
            entry[col_type] = " ".join(texts[col_type])
            box = boxes.get(col_type)
            if box:
                entry[f"{col_type}_bbox"] = {
                    "x_pct": box['left'] / image_width * 100,
                    "y_pct": box['top'] / image_height * 100,
                    "w_pct": (box['right'] - box['left']) / image_width * 100,
                    "h_pct": (box['bottom'] - box['top']) / image_height * 100,
                }
        # Only keep lines that carry actual vocabulary content.
        if entry["english"] or entry["german"]:
            entries.append(entry)
    return entries
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
                             image_w: int, image_h: int,
                             threshold: float = 0.6) -> List[dict]:
    """Match Tesseract bounding boxes to LLM vocabulary entries.

    For each LLM vocab entry, find the best fuzzy-matching Tesseract word
    (English field first, then German) and attach its bounding box as
    percentage coordinates.

    Args:
        tess_words: Word list from Tesseract with pixel coordinates.
        llm_vocab: Vocabulary list from the Vision LLM (annotated in place).
        image_w: Image width in pixels.
        image_h: Image height in pixels.
        threshold: Minimum similarity ratio for a match.

    Returns:
        The llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct
        added on matched entries.
    """
    if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
        return llm_vocab
    for entry in llm_vocab:
        # Skip entries with neither field; nothing to search for.
        if not entry.get("english", "").strip() and not entry.get("german", "").strip():
            continue
        # Try to match the English word first, then German.
        for field in ("english", "german"):
            search_text = entry.get(field, "").lower().strip()
            if not search_text:
                continue
            # SequenceMatcher caches statistics for seq2, so keep the fixed
            # search text as seq2 and only swap seq1 per candidate word
            # (ratio() is symmetric, so the score is unchanged).
            matcher = SequenceMatcher(None, "", search_text)
            best_word = None
            best_ratio = 0.0
            for word in tess_words:
                matcher.set_seq1(word['text'].lower())
                ratio = matcher.ratio()
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_word = word
            if best_word and best_ratio >= threshold:
                # (Fixed: keys were written with no-op f-string prefixes.)
                entry["bbox_x_pct"] = best_word['left'] / image_w * 100
                entry["bbox_y_pct"] = best_word['top'] / image_h * 100
                entry["bbox_w_pct"] = best_word['width'] / image_w * 100
                entry["bbox_h_pct"] = best_word['height'] / image_h * 100
                entry["bbox_match_field"] = field
                entry["bbox_match_ratio"] = round(best_ratio, 3)
                break  # Found a match, no need to try the other field
    return llm_vocab
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Full Tesseract pipeline: words -> lines -> columns -> vocab entries.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string.

    Returns:
        Dict with 'vocabulary', 'words', 'lines_count', 'columns',
        'column_types', image dimensions and 'word_count'; when extraction
        fails, the error dict from extract_bounding_boxes is returned as-is.
    """
    bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
    if bbox_data.get("error"):
        return bbox_data
    words = bbox_data["words"]
    width = bbox_data["image_width"]
    height = bbox_data["image_height"]
    lines = group_words_into_lines(words)
    col_info = detect_columns(lines, width)
    vocabulary = words_to_vocab_entries(
        lines,
        col_info["columns"],
        col_info["column_types"],
        width,
        height,
    )
    return {
        "vocabulary": vocabulary,
        "words": words,
        "lines_count": len(lines),
        "columns": col_info["columns"],
        "column_types": col_info["column_types"],
        "image_width": width,
        "image_height": height,
        "word_count": len(words),
    }

View File

@@ -0,0 +1,261 @@
"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- /ocr/trocr - Single image OCR
- /ocr/trocr/batch - Batch image processing
- /ocr/trocr/status - Model status
- /ocr/trocr/cache - Cache statistics
"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import logging
from services.trocr_service import (
run_trocr_ocr_enhanced,
run_trocr_batch,
run_trocr_batch_stream,
get_model_status,
get_cache_stats,
preload_trocr_model,
OCRResult,
BatchOCRResult
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
    """Response model for single image OCR."""
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    # word_count is derived in the endpoints via text.split(), not by the model.
    word_count: int = Field(0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR."""
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    # cached_count/error_count are copied from the service's BatchOCRResult.
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
class ModelStatusResponse(BaseModel):
    """Response model for model status (see trocr_service.get_model_status)."""
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
    """Response model for OCR result-cache statistics."""
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Get TrOCR model status.

    Thin delegation to services.trocr_service.get_model_status(); reports
    whether the model is installed and currently loaded in memory.
    """
    return get_model_status()
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Get TrOCR cache statistics.

    Thin delegation to services.trocr_service.get_cache_stats(); returns
    size, capacity, and TTL of the OCR result cache.
    """
    return get_cache_stats()
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Preload TrOCR model into memory.

    This speeds up the first OCR request by loading the model ahead of time.
    Responds 500 when the service reports that loading failed.
    """
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")
    return {"status": "success", "message": "Model preloaded successfully"}
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.

    Supports PNG, JPG, and other common image formats; rejects uploads
    whose content type is not image/*.
    """
    # Validate the declared content type before touching the payload.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")
    try:
        payload = await file.read()
        ocr = await run_trocr_ocr_enhanced(
            payload,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )
        detected_words = len(ocr.text.split()) if ocr.text else 0
        return TrOCRResponse(
            text=ocr.text,
            confidence=ocr.confidence,
            processing_time_ms=ocr.processing_time_ms,
            model=ocr.model,
            has_lora_adapter=ocr.has_lora_adapter,
            from_cache=ocr.from_cache,
            image_hash=ocr.image_hash,
            word_count=detected_words
        )
    except Exception as e:
        logger.error(f"TrOCR API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.

    Validates and reads all uploads (max 50), processes them sequentially
    via the batch service, and returns all results.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
    try:
        payloads = []
        for upload in files:
            if not upload.content_type or not upload.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {upload.filename} is not an image")
            payloads.append(await upload.read())
        batch = await run_trocr_batch(
            payloads,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )
        responses = [
            TrOCRResponse(
                text=item.text,
                confidence=item.confidence,
                processing_time_ms=item.processing_time_ms,
                model=item.model,
                has_lora_adapter=item.has_lora_adapter,
                from_cache=item.from_cache,
                image_hash=item.image_hash,
                word_count=len(item.text.split()) if item.text else 0
            )
            for item in batch.results
        ]
        return BatchOCRResponse(
            results=responses,
            total_time_ms=batch.total_time_ms,
            processed_count=batch.processed_count,
            cached_count=batch.cached_count,
            error_count=batch.error_count
        )
    except HTTPException:
        # Re-raise validation errors untouched (don't convert to 500).
        raise
    except Exception as e:
        logger.error(f"TrOCR batch API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.

    All uploads are validated and read up front; afterwards each progress
    update from the batch service is forwarded as one SSE 'data:' event.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
    try:
        payloads = []
        for upload in files:
            if not upload.content_type or not upload.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {upload.filename} is not an image")
            payloads.append(await upload.read())

        async def event_generator():
            # One SSE frame per progress update, JSON-encoded.
            async for update in run_trocr_batch_stream(
                payloads,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TrOCR stream API error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,428 @@
"""
Vocabulary Session Store - PostgreSQL persistence for vocab extraction sessions.
Replaces in-memory storage with database persistence.
See migrations/001_vocab_sessions.sql for schema.
"""
import os
import uuid
import logging
import json
from typing import Optional, List, Dict, Any
from datetime import datetime
import asyncpg
logger = logging.getLogger(__name__)
# Database configuration
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)
# Connection pool (initialized lazily)
_pool: Optional[asyncpg.Pool] = None
async def get_pool() -> asyncpg.Pool:
    """Get or create the database connection pool.

    Lazily creates a shared asyncpg pool (2-10 connections) on first use
    and returns the same pool thereafter.
    NOTE(review): the None-check is not guarded against concurrent first
    calls - two coroutines racing here could each create a pool; confirm
    startup invokes this once before serving traffic.
    """
    global _pool
    if _pool is None:
        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
async def init_vocab_tables():
    """Create the vocab tables at startup if they are missing.

    Runs migrations/001_vocab_sessions.sql when the 'vocab_sessions'
    table does not exist; otherwise this is a no-op.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'vocab_sessions'
            )
        """)
        if tables_exist:
            logger.debug("Vocab tables already exist")
            return
        logger.info("Creating vocab tables...")
        # Read and execute migration
        migration_path = os.path.join(
            os.path.dirname(__file__),
            "migrations/001_vocab_sessions.sql"
        )
        if not os.path.exists(migration_path):
            logger.warning(f"Migration file not found: {migration_path}")
            return
        with open(migration_path, "r") as f:
            sql = f.read()
        await conn.execute(sql)
        logger.info("Vocab tables created successfully")
# =============================================================================
# SESSION OPERATIONS
# =============================================================================
async def create_session_db(
    session_id: str,
    name: str,
    description: str = "",
    source_language: str = "en",
    target_language: str = "de"
) -> Dict[str, Any]:
    """Insert a new vocabulary session (status 'pending', zero entries).

    Returns the freshly created row converted to a dict.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        inserted = await conn.fetchrow("""
            INSERT INTO vocab_sessions (
                id, name, description, source_language, target_language,
                status, vocabulary_count
            ) VALUES ($1, $2, $3, $4, $5, 'pending', 0)
            RETURNING *
        """, uuid.UUID(session_id), name, description, source_language, target_language)
    return _row_to_dict(inserted)
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single session by ID, or None when it does not exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            SELECT * FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    return _row_to_dict(row) if row else None
async def list_sessions_db(
    limit: int = 50,
    offset: int = 0,
    status: Optional[str] = None
) -> List[Dict[str, Any]]:
    """List sessions newest-first, optionally filtered by status."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        if status:
            rows = await conn.fetch("""
                SELECT * FROM vocab_sessions
                WHERE status = $1
                ORDER BY created_at DESC
                LIMIT $2 OFFSET $3
            """, status, limit, offset)
        else:
            rows = await conn.fetch("""
                SELECT * FROM vocab_sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
            """, limit, offset)
    return [_row_to_dict(record) for record in rows]
async def update_session_db(
    session_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update a session with the given whitelisted fields.

    Unknown keyword arguments are silently ignored. JSONB columns accept
    Python dicts/lists and are serialized here.

    Returns:
        The updated row as a dict; the unchanged row when no updatable
        field was supplied; None when the session does not exist.
    """
    pool = await get_pool()
    # Whitelist prevents SQL injection via field names interpolated into
    # the f-string query below; values are always bound parameters.
    allowed_fields = [
        'name', 'description', 'status', 'vocabulary_count',
        'extraction_confidence', 'image_path', 'pdf_path', 'pdf_page_count',
        'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'
    ]
    json_fields = {'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'}
    fields = []
    values = []
    param_idx = 1
    for key, value in kwargs.items():
        if key not in allowed_fields:
            continue
        fields.append(f"{key} = ${param_idx}")
        if key in json_fields:
            # BUG FIX: the previous `if value` coerced empty lists/dicts to
            # SQL NULL; only a real None should map to NULL so that [] and
            # {} round-trip intact through the JSONB columns.
            value = json.dumps(value) if value is not None else None
        values.append(value)
        param_idx += 1
    if not fields:
        return await get_session_db(session_id)
    values.append(uuid.UUID(session_id))
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE vocab_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING *
        """, *values)
    return _row_to_dict(row) if row else None
async def delete_session_db(session_id: str) -> bool:
    """Delete a session (related rows cascade); True when a row was removed."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # asyncpg returns a command tag such as "DELETE 1".
    return status == "DELETE 1"
# =============================================================================
# VOCABULARY OPERATIONS
# =============================================================================
async def add_vocabulary_db(
    session_id: str,
    vocab_list: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Insert vocabulary entries for a session and refresh its count.

    All inserts plus the counter update run inside one transaction, so a
    failing entry no longer leaves a partially-written batch (or a stale
    vocabulary_count) behind.

    Returns:
        The inserted rows as dicts; empty input yields an empty list.
    """
    if not vocab_list:
        return []
    pool = await get_pool()
    sid = uuid.UUID(session_id)
    results = []
    async with pool.acquire() as conn:
        async with conn.transaction():
            for vocab in vocab_list:
                row = await conn.fetchrow("""
                    INSERT INTO vocab_entries (
                        id, session_id, english, german, example_sentence,
                        example_sentence_gap, word_type, source_page
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    RETURNING *
                """,
                    uuid.uuid4(),
                    sid,
                    vocab.get('english', ''),
                    vocab.get('german', ''),
                    vocab.get('example_sentence'),
                    vocab.get('example_sentence_gap'),
                    vocab.get('word_type'),
                    vocab.get('source_page')
                )
                results.append(_row_to_dict(row))
            # Keep the denormalized counter in sync with the entries table.
            await conn.execute("""
                UPDATE vocab_sessions
                SET vocabulary_count = (
                    SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
                )
                WHERE id = $1
            """, sid)
    return results
async def get_vocabulary_db(
    session_id: str,
    source_page: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Fetch vocab entries for a session, optionally limited to one page."""
    pool = await get_pool()
    sid = uuid.UUID(session_id)
    async with pool.acquire() as conn:
        if source_page is not None:
            rows = await conn.fetch("""
                SELECT * FROM vocab_entries
                WHERE session_id = $1 AND source_page = $2
                ORDER BY created_at
            """, sid, source_page)
        else:
            rows = await conn.fetch("""
                SELECT * FROM vocab_entries
                WHERE session_id = $1
                ORDER BY source_page NULLS LAST, created_at
            """, sid)
    return [_row_to_dict(record) for record in rows]
async def update_vocabulary_db(
    entry_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update whitelisted fields of a single vocab entry.

    Returns the updated row, or None when nothing updatable was supplied
    or the entry does not exist.
    """
    pool = await get_pool()
    # Field-name whitelist keeps the f-string query injection-safe.
    allowed_fields = [
        'english', 'german', 'example_sentence', 'example_sentence_gap',
        'word_type', 'source_page'
    ]
    assignments = []
    values = []
    for key, value in kwargs.items():
        if key not in allowed_fields:
            continue
        values.append(value)
        assignments.append(f"{key} = ${len(values)}")
    if not assignments:
        return None
    values.append(uuid.UUID(entry_id))
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE vocab_entries
            SET {', '.join(assignments)}
            WHERE id = ${len(values)}
            RETURNING *
        """, *values)
    return _row_to_dict(row) if row else None
async def clear_page_vocabulary_db(session_id: str, page: int) -> int:
    """Delete all vocab entries of one page and refresh the session counter.

    Returns the number of deleted rows.
    """
    pool = await get_pool()
    sid = uuid.UUID(session_id)
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM vocab_entries
            WHERE session_id = $1 AND source_page = $2
        """, sid, page)
        # Update vocabulary count
        await conn.execute("""
            UPDATE vocab_sessions
            SET vocabulary_count = (
                SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
            )
            WHERE id = $1
        """, sid)
    # asyncpg returns a command tag like "DELETE 3"; trailing number = count.
    return int(status.split()[-1]) if status else 0
# =============================================================================
# WORKSHEET OPERATIONS
# =============================================================================
async def create_worksheet_db(
    session_id: str,
    worksheet_types: List[str],
    pdf_path: Optional[str] = None,
    solution_path: Optional[str] = None
) -> Dict[str, Any]:
    """Insert a worksheet record for a session and return it as a dict."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO vocab_worksheets (
                id, session_id, worksheet_types, pdf_path, solution_path
            ) VALUES ($1, $2, $3, $4, $5)
            RETURNING *
        """,
            uuid.uuid4(),
            uuid.UUID(session_id),
            json.dumps(worksheet_types),  # JSONB column
            pdf_path,
            solution_path
        )
    return _row_to_dict(row)
async def get_worksheet_db(worksheet_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a worksheet by ID, or None when it does not exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            SELECT * FROM vocab_worksheets WHERE id = $1
        """, uuid.UUID(worksheet_id))
    return _row_to_dict(row) if row else None
async def delete_worksheets_for_session_db(session_id: str) -> int:
    """Delete every worksheet of a session; returns the number removed."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM vocab_worksheets WHERE session_id = $1
        """, uuid.UUID(session_id))
    # asyncpg command tag looks like "DELETE <n>".
    return int(status.split()[-1]) if status else 0
# =============================================================================
# PDF CACHE OPERATIONS
# =============================================================================
# Transient in-memory store for generated PDFs until they are downloaded.
_pdf_cache: Dict[str, bytes] = {}


def cache_pdf_data(worksheet_id: str, pdf_data: bytes) -> None:
    """Keep a generated PDF in memory for a later download request."""
    _pdf_cache[worksheet_id] = pdf_data


def get_cached_pdf_data(worksheet_id: str) -> Optional[bytes]:
    """Return the cached PDF bytes, or None when nothing is cached."""
    return _pdf_cache.get(worksheet_id)


def clear_cached_pdf_data(worksheet_id: str) -> None:
    """Drop a cached PDF; silently ignores unknown IDs."""
    _pdf_cache.pop(worksheet_id, None)
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def _row_to_dict(row: "asyncpg.Record") -> Dict[str, Any]:
    """Convert an asyncpg Record into a JSON-friendly dict.

    UUIDs become strings, datetimes become ISO-8601 strings, and JSONB
    columns that arrive as raw JSON text are parsed back into Python values.
    A None row yields an empty dict.
    """
    if row is None:
        return {}
    result = dict(row)
    for uuid_key in ('id', 'session_id'):
        if result.get(uuid_key) is not None:
            result[uuid_key] = str(result[uuid_key])
    for ts_key in ('created_at', 'updated_at', 'generated_at'):
        if result.get(ts_key) is not None:
            result[ts_key] = result[ts_key].isoformat()
    for json_key in ('ocr_prompts', 'processed_pages', 'successful_pages',
                     'failed_pages', 'worksheet_types'):
        value = result.get(json_key)
        if isinstance(value, str):
            result[json_key] = json.loads(value)
    return result