feat(klausur-service): Add Tesseract OCR, DSFA RAG, TrOCR, grid detection and vocab session store
New modules: - tesseract_vocab_extractor.py: Bounding-box OCR with multi-PSM pipeline - grid_detection_service.py: CV-based grid/table detection for worksheets - vocab_session_store.py: PostgreSQL persistence for vocab sessions - trocr_api.py: TrOCR handwriting recognition endpoint - dsfa_rag_api.py + dsfa_corpus_ingestion.py: DSFA RAG corpus search Changes: - Dockerfile: Install tesseract-ocr + deu/eng language packs - requirements.txt: Add PyMuPDF, pytesseract, Pillow - main.py: Register new routers, init DB pools + Qdrant collections Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,9 +13,12 @@ FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
# Install system dependencies (incl. Tesseract OCR for bounding-box extraction)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
tesseract-ocr \
|
||||
tesseract-ocr-deu \
|
||||
tesseract-ocr-eng \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies
|
||||
|
||||
1501
klausur-service/backend/dsfa_corpus_ingestion.py
Normal file
1501
klausur-service/backend/dsfa_corpus_ingestion.py
Normal file
File diff suppressed because it is too large
Load Diff
715
klausur-service/backend/dsfa_rag_api.py
Normal file
715
klausur-service/backend/dsfa_rag_api.py
Normal file
@@ -0,0 +1,715 @@
|
||||
"""
|
||||
DSFA RAG API Endpoints.
|
||||
|
||||
Provides REST API for searching DSFA corpus with full source attribution.
|
||||
|
||||
Endpoints:
|
||||
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
|
||||
- GET /api/v1/dsfa-rag/sources - List all registered sources
|
||||
- POST /api/v1/dsfa-rag/sources/{code}/ingest - Trigger source ingestion
|
||||
- GET /api/v1/dsfa-rag/chunks/{id} - Get single chunk with attribution
|
||||
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Embedding service configuration
|
||||
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||
|
||||
# Import from ingestion module
|
||||
from dsfa_corpus_ingestion import (
|
||||
DSFACorpusStore,
|
||||
DSFAQdrantService,
|
||||
DSFASearchResult,
|
||||
LICENSE_REGISTRY,
|
||||
DSFA_SOURCES,
|
||||
generate_attribution_notice,
|
||||
get_license_label,
|
||||
DSFA_COLLECTION,
|
||||
chunk_document
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pydantic Models
|
||||
# =============================================================================
|
||||
|
||||
class DSFASourceResponse(BaseModel):
|
||||
"""Response model for DSFA source."""
|
||||
id: str
|
||||
source_code: str
|
||||
name: str
|
||||
full_name: Optional[str] = None
|
||||
organization: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
attribution_text: str
|
||||
document_type: Optional[str] = None
|
||||
language: str = "de"
|
||||
|
||||
|
||||
class DSFAChunkResponse(BaseModel):
|
||||
"""Response model for a single chunk with attribution."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
category: Optional[str] = None
|
||||
|
||||
# Document info
|
||||
document_id: str
|
||||
document_title: Optional[str] = None
|
||||
|
||||
# Attribution (always included)
|
||||
source_id: str
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
document_type: Optional[str] = None
|
||||
|
||||
|
||||
class DSFASearchResultResponse(BaseModel):
|
||||
"""Response model for search result."""
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
|
||||
# Attribution
|
||||
source_code: str
|
||||
source_name: str
|
||||
attribution_text: str
|
||||
license_code: str
|
||||
license_name: str
|
||||
license_url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
source_url: Optional[str] = None
|
||||
|
||||
# Metadata
|
||||
document_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
section_title: Optional[str] = None
|
||||
page_number: Optional[int] = None
|
||||
|
||||
|
||||
class DSFASearchResponse(BaseModel):
|
||||
"""Response model for search endpoint."""
|
||||
query: str
|
||||
results: List[DSFASearchResultResponse]
|
||||
total_results: int
|
||||
|
||||
# Aggregated licenses for footer
|
||||
licenses_used: List[str]
|
||||
attribution_notice: str
|
||||
|
||||
|
||||
class DSFASourceStatsResponse(BaseModel):
|
||||
"""Response model for source statistics."""
|
||||
source_id: str
|
||||
source_code: str
|
||||
name: str
|
||||
organization: Optional[str] = None
|
||||
license_code: str
|
||||
document_type: Optional[str] = None
|
||||
document_count: int
|
||||
chunk_count: int
|
||||
last_indexed_at: Optional[str] = None
|
||||
|
||||
|
||||
class DSFACorpusStatsResponse(BaseModel):
|
||||
"""Response model for corpus statistics."""
|
||||
sources: List[DSFASourceStatsResponse]
|
||||
total_sources: int
|
||||
total_documents: int
|
||||
total_chunks: int
|
||||
qdrant_collection: str
|
||||
qdrant_points_count: int
|
||||
qdrant_status: str
|
||||
|
||||
|
||||
class IngestRequest(BaseModel):
|
||||
"""Request model for ingestion."""
|
||||
document_url: Optional[str] = None
|
||||
document_text: Optional[str] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class IngestResponse(BaseModel):
|
||||
"""Response model for ingestion."""
|
||||
source_code: str
|
||||
document_id: Optional[str] = None
|
||||
chunks_created: int
|
||||
message: str
|
||||
|
||||
|
||||
class LicenseInfo(BaseModel):
|
||||
"""License information."""
|
||||
code: str
|
||||
name: str
|
||||
url: Optional[str] = None
|
||||
attribution_required: bool
|
||||
modification_allowed: bool
|
||||
commercial_use: bool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Dependency Injection
|
||||
# =============================================================================
|
||||
|
||||
# Database pool (will be set from main.py)
|
||||
_db_pool = None
|
||||
|
||||
|
||||
def set_db_pool(pool):
|
||||
"""Set the database pool for API endpoints."""
|
||||
global _db_pool
|
||||
_db_pool = pool
|
||||
|
||||
|
||||
async def get_store() -> DSFACorpusStore:
|
||||
"""Get DSFA corpus store."""
|
||||
if _db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not initialized")
|
||||
return DSFACorpusStore(_db_pool)
|
||||
|
||||
|
||||
async def get_qdrant() -> DSFAQdrantService:
|
||||
"""Get Qdrant service."""
|
||||
return DSFAQdrantService()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Service Integration
|
||||
# =============================================================================
|
||||
|
||||
async def get_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Get embedding for text using the embedding-service.
|
||||
|
||||
Uses BGE-M3 model which produces 1024-dimensional vectors.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed-single",
|
||||
json={"text": text}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embedding", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service error: {e}")
|
||||
# Fallback to hash-based pseudo-embedding for development
|
||||
return _generate_fallback_embedding(text)
|
||||
|
||||
|
||||
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
|
||||
"""
|
||||
Get embeddings for multiple texts in batch.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/embed",
|
||||
json={"texts": texts}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("embeddings", [])
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Embedding service batch error: {e}")
|
||||
# Fallback
|
||||
return [_generate_fallback_embedding(t) for t in texts]
|
||||
|
||||
|
||||
async def extract_text_from_url(url: str) -> str:
|
||||
"""
|
||||
Extract text from a document URL (PDF, HTML, etc.).
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
try:
|
||||
# First try to use the embedding-service's extract-pdf endpoint
|
||||
response = await client.post(
|
||||
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
|
||||
json={"url": url}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"PDF extraction error for {url}: {e}")
|
||||
# Fallback: try to fetch HTML content directly
|
||||
try:
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "html" in content_type:
|
||||
# Simple HTML text extraction
|
||||
import re
|
||||
html = response.text
|
||||
# Remove scripts and styles
|
||||
html = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
html = re.sub(r'<style[^>]*>.*?</style>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||
# Remove tags
|
||||
text = re.sub(r'<[^>]+>', ' ', html)
|
||||
# Clean whitespace
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return text
|
||||
else:
|
||||
return ""
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"Fallback fetch error for {url}: {fetch_err}")
|
||||
return ""
|
||||
|
||||
|
||||
def _generate_fallback_embedding(text: str) -> List[float]:
|
||||
"""
|
||||
Generate deterministic pseudo-embedding from text hash.
|
||||
Used as fallback when embedding service is unavailable.
|
||||
"""
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
embedding = []
|
||||
for i in range(0, min(len(hash_bytes), 128), 4):
|
||||
val = struct.unpack('f', hash_bytes[i:i+4])[0]
|
||||
embedding.append(val % 1.0)
|
||||
|
||||
# Pad to 1024 dimensions
|
||||
while len(embedding) < 1024:
|
||||
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
|
||||
|
||||
return embedding[:1024]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/search", response_model=DSFASearchResponse)
|
||||
async def search_dsfa_corpus(
|
||||
query: str = Query(..., min_length=3, description="Search query"),
|
||||
source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
|
||||
document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
|
||||
categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
|
||||
limit: int = Query(10, ge=1, le=50, description="Maximum results"),
|
||||
include_attribution: bool = Query(True, description="Include attribution in results"),
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Search DSFA corpus with full attribution.
|
||||
|
||||
Returns matching chunks with source/license information for compliance.
|
||||
"""
|
||||
# Get query embedding
|
||||
query_embedding = await get_embedding(query)
|
||||
|
||||
# Search Qdrant
|
||||
raw_results = await qdrant.search(
|
||||
query_embedding=query_embedding,
|
||||
source_codes=source_codes,
|
||||
document_types=document_types,
|
||||
categories=categories,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Transform results
|
||||
results = []
|
||||
licenses_used = set()
|
||||
|
||||
for r in raw_results:
|
||||
license_code = r.get("license_code", "")
|
||||
license_info = LICENSE_REGISTRY.get(license_code, {})
|
||||
|
||||
result = DSFASearchResultResponse(
|
||||
chunk_id=r.get("chunk_id", ""),
|
||||
content=r.get("content", ""),
|
||||
score=r.get("score", 0.0),
|
||||
source_code=r.get("source_code", ""),
|
||||
source_name=r.get("source_name", ""),
|
||||
attribution_text=r.get("attribution_text", ""),
|
||||
license_code=license_code,
|
||||
license_name=license_info.get("name", license_code),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=r.get("attribution_required", True),
|
||||
source_url=r.get("source_url"),
|
||||
document_type=r.get("document_type"),
|
||||
category=r.get("category"),
|
||||
section_title=r.get("section_title"),
|
||||
page_number=r.get("page_number")
|
||||
)
|
||||
results.append(result)
|
||||
licenses_used.add(license_code)
|
||||
|
||||
# Generate attribution notice
|
||||
search_results = [
|
||||
DSFASearchResult(
|
||||
chunk_id=r.chunk_id,
|
||||
content=r.content,
|
||||
score=r.score,
|
||||
source_code=r.source_code,
|
||||
source_name=r.source_name,
|
||||
attribution_text=r.attribution_text,
|
||||
license_code=r.license_code,
|
||||
license_url=r.license_url,
|
||||
attribution_required=r.attribution_required,
|
||||
source_url=r.source_url,
|
||||
document_type=r.document_type or "",
|
||||
category=r.category or "",
|
||||
section_title=r.section_title,
|
||||
page_number=r.page_number
|
||||
)
|
||||
for r in results
|
||||
]
|
||||
attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""
|
||||
|
||||
return DSFASearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total_results=len(results),
|
||||
licenses_used=list(licenses_used),
|
||||
attribution_notice=attribution_notice
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sources", response_model=List[DSFASourceResponse])
|
||||
async def list_dsfa_sources(
|
||||
document_type: Optional[str] = Query(None, description="Filter by document type"),
|
||||
license_code: Optional[str] = Query(None, description="Filter by license"),
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""List all registered DSFA sources with license info."""
|
||||
sources = await store.list_sources()
|
||||
|
||||
result = []
|
||||
for s in sources:
|
||||
# Apply filters
|
||||
if document_type and s.get("document_type") != document_type:
|
||||
continue
|
||||
if license_code and s.get("license_code") != license_code:
|
||||
continue
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(s.get("license_code", ""), {})
|
||||
|
||||
result.append(DSFASourceResponse(
|
||||
id=str(s["id"]),
|
||||
source_code=s["source_code"],
|
||||
name=s["name"],
|
||||
full_name=s.get("full_name"),
|
||||
organization=s.get("organization"),
|
||||
source_url=s.get("source_url"),
|
||||
license_code=s.get("license_code", ""),
|
||||
license_name=license_info.get("name", s.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=s.get("attribution_required", True),
|
||||
attribution_text=s.get("attribution_text", ""),
|
||||
document_type=s.get("document_type"),
|
||||
language=s.get("language", "de")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/sources/available")
|
||||
async def list_available_sources():
|
||||
"""List all available source definitions (from DSFA_SOURCES constant)."""
|
||||
return [
|
||||
{
|
||||
"source_code": s["source_code"],
|
||||
"name": s["name"],
|
||||
"organization": s.get("organization"),
|
||||
"license_code": s["license_code"],
|
||||
"document_type": s.get("document_type")
|
||||
}
|
||||
for s in DSFA_SOURCES
|
||||
]
|
||||
|
||||
|
||||
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
|
||||
async def get_dsfa_source(
|
||||
source_code: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get details for a specific source."""
|
||||
source = await store.get_source_by_code(source_code)
|
||||
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})
|
||||
|
||||
return DSFASourceResponse(
|
||||
id=str(source["id"]),
|
||||
source_code=source["source_code"],
|
||||
name=source["name"],
|
||||
full_name=source.get("full_name"),
|
||||
organization=source.get("organization"),
|
||||
source_url=source.get("source_url"),
|
||||
license_code=source.get("license_code", ""),
|
||||
license_name=license_info.get("name", source.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=source.get("attribution_required", True),
|
||||
attribution_text=source.get("attribution_text", ""),
|
||||
document_type=source.get("document_type"),
|
||||
language=source.get("language", "de")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
|
||||
async def ingest_dsfa_source(
|
||||
source_code: str,
|
||||
request: IngestRequest,
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Trigger ingestion for a specific source.
|
||||
|
||||
Can provide document via URL or direct text.
|
||||
"""
|
||||
# Get source
|
||||
source = await store.get_source_by_code(source_code)
|
||||
if not source:
|
||||
raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")
|
||||
|
||||
# Need either URL or text
|
||||
if not request.document_text and not request.document_url:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either document_text or document_url must be provided"
|
||||
)
|
||||
|
||||
# Ensure Qdrant collection exists
|
||||
await qdrant.ensure_collection()
|
||||
|
||||
# Get text content
|
||||
text_content = request.document_text
|
||||
if request.document_url and not text_content:
|
||||
# Download and extract text from URL
|
||||
logger.info(f"Extracting text from URL: {request.document_url}")
|
||||
text_content = await extract_text_from_url(request.document_url)
|
||||
if not text_content:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Could not extract text from URL: {request.document_url}"
|
||||
)
|
||||
|
||||
if not text_content or len(text_content.strip()) < 50:
|
||||
raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")
|
||||
|
||||
# Create document record
|
||||
doc_title = request.title or f"Document for {source_code}"
|
||||
document_id = await store.create_document(
|
||||
source_id=str(source["id"]),
|
||||
title=doc_title,
|
||||
file_type="text",
|
||||
metadata={"ingested_via": "api", "source_code": source_code}
|
||||
)
|
||||
|
||||
# Chunk the document
|
||||
chunks = chunk_document(text_content, source_code)
|
||||
|
||||
if not chunks:
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=0,
|
||||
message="Document created but no chunks generated"
|
||||
)
|
||||
|
||||
# Generate embeddings in batch for efficiency
|
||||
chunk_texts = [chunk["content"] for chunk in chunks]
|
||||
logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
|
||||
embeddings = await get_embeddings_batch(chunk_texts)
|
||||
|
||||
# Create chunk records in PostgreSQL and prepare for Qdrant
|
||||
chunk_records = []
|
||||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||||
# Create chunk in PostgreSQL
|
||||
chunk_id = await store.create_chunk(
|
||||
document_id=document_id,
|
||||
source_id=str(source["id"]),
|
||||
content=chunk["content"],
|
||||
chunk_index=i,
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category")
|
||||
)
|
||||
|
||||
chunk_records.append({
|
||||
"chunk_id": chunk_id,
|
||||
"document_id": document_id,
|
||||
"source_id": str(source["id"]),
|
||||
"content": chunk["content"],
|
||||
"section_title": chunk.get("section_title"),
|
||||
"source_code": source_code,
|
||||
"source_name": source["name"],
|
||||
"attribution_text": source["attribution_text"],
|
||||
"license_code": source["license_code"],
|
||||
"attribution_required": source.get("attribution_required", True),
|
||||
"document_type": source.get("document_type", ""),
|
||||
"category": chunk.get("category", ""),
|
||||
"language": source.get("language", "de"),
|
||||
"page_number": chunk.get("page_number")
|
||||
})
|
||||
|
||||
# Index in Qdrant
|
||||
indexed_count = await qdrant.index_chunks(chunk_records, embeddings)
|
||||
|
||||
# Update document record
|
||||
await store.update_document_indexed(document_id, len(chunks))
|
||||
|
||||
return IngestResponse(
|
||||
source_code=source_code,
|
||||
document_id=document_id,
|
||||
chunks_created=indexed_count,
|
||||
message=f"Successfully ingested {indexed_count} chunks from document"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
|
||||
async def get_chunk_with_attribution(
|
||||
chunk_id: str,
|
||||
store: DSFACorpusStore = Depends(get_store)
|
||||
):
|
||||
"""Get single chunk with full source attribution."""
|
||||
chunk = await store.get_chunk_with_attribution(chunk_id)
|
||||
|
||||
if not chunk:
|
||||
raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")
|
||||
|
||||
license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})
|
||||
|
||||
return DSFAChunkResponse(
|
||||
chunk_id=str(chunk["chunk_id"]),
|
||||
content=chunk.get("content", ""),
|
||||
section_title=chunk.get("section_title"),
|
||||
page_number=chunk.get("page_number"),
|
||||
category=chunk.get("category"),
|
||||
document_id=str(chunk.get("document_id", "")),
|
||||
document_title=chunk.get("document_title"),
|
||||
source_id=str(chunk.get("source_id", "")),
|
||||
source_code=chunk.get("source_code", ""),
|
||||
source_name=chunk.get("source_name", ""),
|
||||
attribution_text=chunk.get("attribution_text", ""),
|
||||
license_code=chunk.get("license_code", ""),
|
||||
license_name=license_info.get("name", chunk.get("license_code", "")),
|
||||
license_url=license_info.get("url"),
|
||||
attribution_required=chunk.get("attribution_required", True),
|
||||
source_url=chunk.get("source_url"),
|
||||
document_type=chunk.get("document_type")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DSFACorpusStatsResponse)
|
||||
async def get_corpus_stats(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""Get corpus statistics for dashboard."""
|
||||
# Get PostgreSQL stats
|
||||
source_stats = await store.get_source_stats()
|
||||
|
||||
total_docs = 0
|
||||
total_chunks = 0
|
||||
stats_response = []
|
||||
|
||||
for s in source_stats:
|
||||
doc_count = s.get("document_count", 0) or 0
|
||||
chunk_count = s.get("chunk_count", 0) or 0
|
||||
total_docs += doc_count
|
||||
total_chunks += chunk_count
|
||||
|
||||
last_indexed = s.get("last_indexed_at")
|
||||
|
||||
stats_response.append(DSFASourceStatsResponse(
|
||||
source_id=str(s.get("source_id", "")),
|
||||
source_code=s.get("source_code", ""),
|
||||
name=s.get("name", ""),
|
||||
organization=s.get("organization"),
|
||||
license_code=s.get("license_code", ""),
|
||||
document_type=s.get("document_type"),
|
||||
document_count=doc_count,
|
||||
chunk_count=chunk_count,
|
||||
last_indexed_at=last_indexed.isoformat() if last_indexed else None
|
||||
))
|
||||
|
||||
# Get Qdrant stats
|
||||
qdrant_stats = await qdrant.get_stats()
|
||||
|
||||
return DSFACorpusStatsResponse(
|
||||
sources=stats_response,
|
||||
total_sources=len(source_stats),
|
||||
total_documents=total_docs,
|
||||
total_chunks=total_chunks,
|
||||
qdrant_collection=DSFA_COLLECTION,
|
||||
qdrant_points_count=qdrant_stats.get("points_count", 0),
|
||||
qdrant_status=qdrant_stats.get("status", "unknown")
|
||||
)
|
||||
|
||||
|
||||
@router.get("/licenses")
|
||||
async def list_licenses():
|
||||
"""List all supported licenses with their terms."""
|
||||
return [
|
||||
LicenseInfo(
|
||||
code=code,
|
||||
name=info.get("name", code),
|
||||
url=info.get("url"),
|
||||
attribution_required=info.get("attribution_required", True),
|
||||
modification_allowed=info.get("modification_allowed", True),
|
||||
commercial_use=info.get("commercial_use", True)
|
||||
)
|
||||
for code, info in LICENSE_REGISTRY.items()
|
||||
]
|
||||
|
||||
|
||||
@router.post("/init")
|
||||
async def initialize_dsfa_corpus(
|
||||
store: DSFACorpusStore = Depends(get_store),
|
||||
qdrant: DSFAQdrantService = Depends(get_qdrant)
|
||||
):
|
||||
"""
|
||||
Initialize DSFA corpus.
|
||||
|
||||
- Creates Qdrant collection
|
||||
- Registers all predefined sources
|
||||
"""
|
||||
# Ensure Qdrant collection exists
|
||||
qdrant_ok = await qdrant.ensure_collection()
|
||||
|
||||
# Register all sources
|
||||
registered = 0
|
||||
for source in DSFA_SOURCES:
|
||||
try:
|
||||
await store.register_source(source)
|
||||
registered += 1
|
||||
except Exception as e:
|
||||
print(f"Error registering source {source['source_code']}: {e}")
|
||||
|
||||
return {
|
||||
"qdrant_collection_created": qdrant_ok,
|
||||
"sources_registered": registered,
|
||||
"total_sources": len(DSFA_SOURCES)
|
||||
}
|
||||
@@ -20,6 +20,7 @@ This is the main entry point. All functionality is organized in modular packages
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncpg
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -36,7 +37,19 @@ from admin_api import router as admin_router
|
||||
from zeugnis_api import router as zeugnis_router
|
||||
from training_api import router as training_router
|
||||
from mail.api import router as mail_router
|
||||
try:
|
||||
from trocr_api import router as trocr_router
|
||||
except ImportError:
|
||||
trocr_router = None
|
||||
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
|
||||
try:
|
||||
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
|
||||
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
|
||||
except ImportError:
|
||||
dsfa_rag_router = None
|
||||
set_dsfa_db_pool = None
|
||||
DSFAQdrantService = None
|
||||
DSFA_DATABASE_URL = None
|
||||
|
||||
# BYOEH Qdrant initialization
|
||||
from qdrant_service import init_qdrant_collection
|
||||
@@ -51,12 +64,42 @@ async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager for startup and shutdown events."""
|
||||
print("Klausur-Service starting...")
|
||||
|
||||
# Initialize database pool for Vocab Sessions
|
||||
vocab_db_pool = None
|
||||
try:
|
||||
vocab_db_pool = await asyncpg.create_pool(VOCAB_DATABASE_URL, min_size=2, max_size=5)
|
||||
set_vocab_db_pool(vocab_db_pool)
|
||||
await _init_vocab_table()
|
||||
await _load_all_sessions()
|
||||
print(f"Vocab sessions database initialized")
|
||||
except Exception as e:
|
||||
print(f"Warning: Vocab sessions database initialization failed: {e}")
|
||||
|
||||
# Initialize database pool for DSFA RAG
|
||||
dsfa_db_pool = None
|
||||
if DSFA_DATABASE_URL and set_dsfa_db_pool:
|
||||
try:
|
||||
dsfa_db_pool = await asyncpg.create_pool(DSFA_DATABASE_URL, min_size=2, max_size=10)
|
||||
set_dsfa_db_pool(dsfa_db_pool)
|
||||
print(f"DSFA database pool initialized: {DSFA_DATABASE_URL}")
|
||||
except Exception as e:
|
||||
print(f"Warning: DSFA database pool initialization failed: {e}")
|
||||
|
||||
# Initialize Qdrant collection for BYOEH
|
||||
try:
|
||||
await init_qdrant_collection()
|
||||
print("Qdrant BYOEH collection initialized")
|
||||
except Exception as e:
|
||||
print(f"Warning: Qdrant initialization failed: {e}")
|
||||
print(f"Warning: Qdrant BYOEH initialization failed: {e}")
|
||||
|
||||
# Initialize Qdrant collection for DSFA RAG
|
||||
if DSFAQdrantService:
|
||||
try:
|
||||
dsfa_qdrant = DSFAQdrantService()
|
||||
await dsfa_qdrant.ensure_collection()
|
||||
print("Qdrant DSFA corpus collection initialized")
|
||||
except Exception as e:
|
||||
print(f"Warning: Qdrant DSFA initialization failed: {e}")
|
||||
|
||||
# Ensure EH upload directory exists
|
||||
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
|
||||
@@ -65,6 +108,16 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
print("Klausur-Service shutting down...")
|
||||
|
||||
# Close Vocab sessions database pool
|
||||
if vocab_db_pool:
|
||||
await vocab_db_pool.close()
|
||||
print("Vocab sessions database pool closed")
|
||||
|
||||
# Close DSFA database pool
|
||||
if dsfa_db_pool:
|
||||
await dsfa_db_pool.close()
|
||||
print("DSFA database pool closed")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Klausur-Service",
|
||||
@@ -94,7 +147,11 @@ app.include_router(admin_router) # NiBiS Ingestion
|
||||
app.include_router(zeugnis_router) # Zeugnis Rights-Aware Crawler
|
||||
app.include_router(training_router) # Training Management
|
||||
app.include_router(mail_router) # Unified Inbox Mail
|
||||
if trocr_router:
|
||||
app.include_router(trocr_router) # TrOCR Handwriting OCR
|
||||
app.include_router(vocab_router) # Vocabulary Worksheet Generator
|
||||
if dsfa_rag_router:
|
||||
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search
|
||||
|
||||
|
||||
# =============================================
|
||||
|
||||
@@ -9,6 +9,7 @@ python-dotenv>=1.0.0
|
||||
qdrant-client>=1.7.0
|
||||
cryptography>=41.0.0
|
||||
PyPDF2>=3.0.0
|
||||
PyMuPDF>=1.24.0
|
||||
|
||||
# PyTorch CPU-only (smaller, no CUDA needed for Docker on Mac)
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
@@ -23,6 +24,10 @@ minio>=7.2.0
|
||||
# OpenCV for handwriting detection (headless = no GUI, smaller for CI)
|
||||
opencv-python-headless>=4.8.0
|
||||
|
||||
# Tesseract OCR Python binding (requires system tesseract-ocr package)
|
||||
pytesseract>=0.3.10
|
||||
Pillow>=10.0.0
|
||||
|
||||
# PostgreSQL (for metrics storage)
|
||||
psycopg2-binary>=2.9.0
|
||||
asyncpg>=0.29.0
|
||||
|
||||
509
klausur-service/backend/services/grid_detection_service.py
Normal file
509
klausur-service/backend/services/grid_detection_service.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Grid Detection Service v4
|
||||
|
||||
Detects table/grid structure from OCR bounding-box data.
|
||||
Converts pixel coordinates to percentage and mm coordinates (A4 format).
|
||||
Supports deskew correction, column detection, and multi-line cell handling.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import math
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Physical A4 page dimensions in millimetres. All percentage coordinates
# (0-100) in this module are projected onto this page size by the *_mm
# properties below.
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0

# Margin between detected columns (1 mm), also expressed as a percentage of
# the A4 page width.
# NOTE(review): COLUMN_MARGIN_PCT is not referenced anywhere in this module's
# visible code — confirm whether external callers use it.
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
|
||||
|
||||
|
||||
class CellStatus(str, Enum):
    """Recognition state of a grid cell (see GridDetectionService.detect_grid)."""
    EMPTY = "empty"              # no OCR region fell inside the cell
    RECOGNIZED = "recognized"    # text found with average confidence >= 0.5
    PROBLEMATIC = "problematic"  # text found but average confidence < 0.5
    MANUAL = "manual"            # never assigned by this service — presumably
                                 # set downstream for user corrections; verify
|
||||
|
||||
|
||||
class ColumnType(str, Enum):
    """Semantic role of a worksheet column, assigned left-to-right
    (first column = english, second = german, third = example)."""
    ENGLISH = "english"
    GERMAN = "german"
    EXAMPLE = "example"
    UNKNOWN = "unknown"  # any column beyond the first three
|
||||
|
||||
|
||||
@dataclass
class OCRRegion:
    """A word/phrase detected by OCR.

    All stored coordinates are percentages of the page (0-100); the ``*_mm``
    properties project them onto a physical A4 page.
    """
    text: str
    confidence: float
    x: float       # left edge, % of page width
    y: float       # top edge, % of page height
    width: float   # % of page width
    height: float  # % of page height

    # --- derived geometry in percentage space -------------------------------

    @property
    def center_x(self) -> float:
        return self.x + self.width / 2

    @property
    def center_y(self) -> float:
        return self.y + self.height / 2

    @property
    def right(self) -> float:
        return self.x + self.width

    @property
    def bottom(self) -> float:
        return self.y + self.height

    # --- projections onto an A4 page in millimetres -------------------------

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # Note: height is rounded to 2 decimals, unlike the other mm values.
        return round(self.height / 100 * A4_HEIGHT_MM, 2)
|
||||
|
||||
|
||||
@dataclass
class GridCell:
    """One cell of the detected grid; coordinates are page percentages (0-100)."""
    row: int
    col: int
    x: float
    y: float
    width: float
    height: float
    text: str = ""
    confidence: float = 0.0
    status: CellStatus = CellStatus.EMPTY
    column_type: ColumnType = ColumnType.UNKNOWN
    logical_row: int = 0   # row index after multi-line merging
    logical_col: int = 0   # column index after multi-line merging
    is_continuation: bool = False  # True for wrapped continuation lines

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # Height keeps 2 decimals of mm precision (matches OCRRegion).
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    def to_dict(self) -> dict:
        """Serialize the cell for JSON transport (percentages and mm)."""
        payload = {
            "row": self.row,
            "col": self.col,
            "x": round(self.x, 2),
            "y": round(self.y, 2),
            "width": round(self.width, 2),
            "height": round(self.height, 2),
            "x_mm": self.x_mm,
            "y_mm": self.y_mm,
            "width_mm": self.width_mm,
            "height_mm": self.height_mm,
            "text": self.text,
            "confidence": self.confidence,
            "status": self.status.value,
            "column_type": self.column_type.value,
            "logical_row": self.logical_row,
            "logical_col": self.logical_col,
            "is_continuation": self.is_continuation,
        }
        return payload
|
||||
|
||||
|
||||
@dataclass
class GridResult:
    """Aggregate result of grid detection."""
    rows: int = 0
    columns: int = 0
    cells: List[List[GridCell]] = field(default_factory=list)
    column_types: List[str] = field(default_factory=list)
    column_boundaries: List[float] = field(default_factory=list)
    row_boundaries: List[float] = field(default_factory=list)
    deskew_angle: float = 0.0
    stats: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize the whole grid, including fixed A4 page metadata."""
        return {
            "rows": self.rows,
            "columns": self.columns,
            "cells": [[cell.to_dict() for cell in row] for row in self.cells],
            "column_types": self.column_types,
            "column_boundaries": [round(b, 2) for b in self.column_boundaries],
            "row_boundaries": [round(b, 2) for b in self.row_boundaries],
            "deskew_angle": round(self.deskew_angle, 2),
            "stats": self.stats,
            "page_dimensions": {
                "width_mm": A4_WIDTH_MM,
                "height_mm": A4_HEIGHT_MM,
                "format": "A4",
            },
        }
|
||||
|
||||
|
||||
class GridDetectionService:
    """Detect grid/table structure from OCR bounding-box regions.

    Pipeline (see detect_grid): deskew -> group regions into rows ->
    detect column boundaries -> assign column types -> build cell grid.
    All coordinates are percentages of the page (0-100) unless noted.
    """

    def __init__(self, y_tolerance_pct: float = 1.5, padding_pct: float = 0.3,
                 column_margin_mm: float = COLUMN_MARGIN_MM):
        # Max vertical distance (% of page height) for two regions to be
        # considered part of the same row.
        self.y_tolerance_pct = y_tolerance_pct
        # Padding added around detected row/column boundaries (% units).
        self.padding_pct = padding_pct
        # NOTE(review): stored but never read by any method in this class —
        # confirm whether callers use it or it is dead configuration.
        self.column_margin_mm = column_margin_mm

    def calculate_deskew_angle(self, regions: List[OCRRegion]) -> float:
        """Calculate page skew angle from OCR region positions.

        Uses left-edge alignment of regions to detect consistent tilt.
        Returns angle in degrees, clamped to ±5°.
        """
        # Fewer than 3 regions cannot establish a reliable alignment.
        if len(regions) < 3:
            return 0.0

        # Group by similar X position (same column)
        sorted_by_x = sorted(regions, key=lambda r: r.x)

        # Find regions that are vertically aligned (similar X)
        x_tolerance = 3.0  # percent
        aligned_groups: List[List[OCRRegion]] = []
        current_group = [sorted_by_x[0]]

        for r in sorted_by_x[1:]:
            # Comparison is anchored to the group's FIRST region, so a slow
            # horizontal drift cannot chain unrelated regions together.
            if abs(r.x - current_group[0].x) <= x_tolerance:
                current_group.append(r)
            else:
                # Only groups of 3+ regions count as a usable alignment.
                if len(current_group) >= 3:
                    aligned_groups.append(current_group)
                current_group = [r]

        if len(current_group) >= 3:
            aligned_groups.append(current_group)

        if not aligned_groups:
            return 0.0

        # Use the largest aligned group to calculate skew
        best_group = max(aligned_groups, key=len)
        best_group.sort(key=lambda r: r.y)

        # Linear regression: X as function of Y
        n = len(best_group)
        sum_y = sum(r.y for r in best_group)
        sum_x = sum(r.x for r in best_group)
        sum_xy = sum(r.x * r.y for r in best_group)
        sum_y2 = sum(r.y ** 2 for r in best_group)

        # Degenerate case: all regions at the same Y -> slope undefined.
        denom = n * sum_y2 - sum_y ** 2
        if denom == 0:
            return 0.0

        slope = (n * sum_xy - sum_y * sum_x) / denom

        # Convert slope to angle (slope is dx/dy in percent space)
        # Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
        aspect = A4_WIDTH_MM / A4_HEIGHT_MM
        angle_rad = math.atan(slope * aspect)
        angle_deg = math.degrees(angle_rad)

        # Clamp to ±5°
        return max(-5.0, min(5.0, round(angle_deg, 2)))

    def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:
        """Apply deskew correction to region coordinates.

        Rotates all coordinates around the page center by -angle.
        Returns new OCRRegion objects; the input list is left untouched.
        """
        # Skip the rotation entirely for negligible angles.
        if abs(angle) < 0.01:
            return regions

        angle_rad = math.radians(-angle)
        cos_a = math.cos(angle_rad)
        sin_a = math.sin(angle_rad)

        # Page center (percent space)
        cx, cy = 50.0, 50.0

        result = []
        for r in regions:
            # Rotate center of region around page center; width/height are
            # kept as-is (only the position is corrected).
            rx = r.center_x - cx
            ry = r.center_y - cy
            new_cx = rx * cos_a - ry * sin_a + cx
            new_cy = rx * sin_a + ry * cos_a + cy
            new_x = new_cx - r.width / 2
            new_y = new_cy - r.height / 2

            result.append(OCRRegion(
                text=r.text,
                confidence=r.confidence,
                x=round(new_x, 2),
                y=round(new_y, 2),
                width=r.width,
                height=r.height,
            ))

        return result

    def _group_regions_into_rows(self, regions: List[OCRRegion]) -> List[List[OCRRegion]]:
        """Group regions by Y position into rows.

        Rows are emitted top-to-bottom; within a row, regions are sorted
        left-to-right. A region joins the current row when its vertical
        center is within y_tolerance_pct of the row anchor.
        """
        if not regions:
            return []

        sorted_regions = sorted(regions, key=lambda r: r.y)
        rows: List[List[OCRRegion]] = []
        current_row = [sorted_regions[0]]
        # Anchor = vertical center of the row's first region; it is NOT
        # updated as more regions join the row.
        current_y = sorted_regions[0].center_y

        for r in sorted_regions[1:]:
            if abs(r.center_y - current_y) <= self.y_tolerance_pct:
                current_row.append(r)
            else:
                current_row.sort(key=lambda r: r.x)
                rows.append(current_row)
                current_row = [r]
                current_y = r.center_y

        if current_row:
            current_row.sort(key=lambda r: r.x)
            rows.append(current_row)

        return rows

    def _detect_column_boundaries(self, rows: List[List[OCRRegion]]) -> List[float]:
        """Detect column boundaries from row data.

        Clusters all region left edges by horizontal gaps (>5% of page
        width starts a new column). Returns the left boundary of each
        column (minus padding) plus a final 100.0 right boundary.
        """
        if not rows:
            return []

        # Collect all X starting positions
        all_x = []
        for row in rows:
            for r in row:
                all_x.append(r.x)

        if not all_x:
            return []

        all_x.sort()

        # Gap-based clustering
        min_gap = 5.0  # percent
        clusters: List[List[float]] = []
        current = [all_x[0]]

        for x in all_x[1:]:
            if x - current[-1] > min_gap:
                clusters.append(current)
                current = [x]
            else:
                current.append(x)

        if current:
            clusters.append(current)

        # Column boundaries: start of each cluster
        boundaries = [min(c) - self.padding_pct for c in clusters]
        # Add right boundary
        boundaries.append(100.0)

        return boundaries

    def _assign_column_types(self, boundaries: List[float]) -> List[str]:
        """Assign column types by position: left-to-right english, german,
        example; any further column is 'unknown'."""
        num_cols = max(0, len(boundaries) - 1)
        type_map = [ColumnType.ENGLISH, ColumnType.GERMAN, ColumnType.EXAMPLE]
        result = []
        for i in range(num_cols):
            if i < len(type_map):
                result.append(type_map[i].value)
            else:
                result.append(ColumnType.UNKNOWN.value)
        return result

    def detect_grid(self, regions: List[OCRRegion]) -> GridResult:
        """Detect grid structure from OCR regions.

        Args:
            regions: List of OCR regions with percentage-based coordinates.

        Returns:
            GridResult with detected rows, columns, and cells. 'manual'
            status is never assigned here (stats['manual'] is always 0).
        """
        if not regions:
            return GridResult(stats={"recognized": 0, "problematic": 0, "empty": 0, "manual": 0, "total": 0, "coverage": 0.0})

        # Step 1: Calculate and apply deskew
        deskew_angle = self.calculate_deskew_angle(regions)
        corrected_regions = self.apply_deskew_to_regions(regions, deskew_angle)

        # Step 2: Group into rows
        rows = self._group_regions_into_rows(corrected_regions)

        # Step 3: Detect column boundaries
        col_boundaries = self._detect_column_boundaries(rows)
        column_types = self._assign_column_types(col_boundaries)
        num_cols = max(1, len(col_boundaries) - 1)

        # Step 4: Build cell grid
        num_rows = len(rows)
        row_boundaries = []
        cells = []

        recognized = 0
        problematic = 0
        empty = 0

        for row_idx, row_regions in enumerate(rows):
            # Row Y boundary: tight bounds around the row's regions, padded.
            if row_regions:
                row_y = min(r.y for r in row_regions) - self.padding_pct
                row_bottom = max(r.bottom for r in row_regions) + self.padding_pct
            else:
                # Fallback: evenly divide the page (defensive — rows from
                # _group_regions_into_rows are never empty).
                row_y = row_idx / num_rows * 100
                row_bottom = (row_idx + 1) / num_rows * 100

            row_boundaries.append(row_y)
            row_height = row_bottom - row_y

            row_cells = []
            for col_idx in range(num_cols):
                col_x = col_boundaries[col_idx]
                col_right = col_boundaries[col_idx + 1] if col_idx + 1 < len(col_boundaries) else 100.0
                col_width = col_right - col_x

                # Find regions in this cell (assignment by horizontal center)
                cell_regions = []
                for r in row_regions:
                    r_center = r.center_x
                    if col_x <= r_center < col_right:
                        cell_regions.append(r)

                if cell_regions:
                    text = " ".join(r.text for r in cell_regions)
                    avg_conf = sum(r.confidence for r in cell_regions) / len(cell_regions)
                    # Confidence threshold 0.5 separates recognized/problematic.
                    status = CellStatus.RECOGNIZED if avg_conf >= 0.5 else CellStatus.PROBLEMATIC
                    # Use actual bounding box from regions
                    actual_x = min(r.x for r in cell_regions)
                    actual_y = min(r.y for r in cell_regions)
                    actual_right = max(r.right for r in cell_regions)
                    actual_bottom = max(r.bottom for r in cell_regions)

                    cell = GridCell(
                        row=row_idx,
                        col=col_idx,
                        x=actual_x,
                        y=actual_y,
                        width=actual_right - actual_x,
                        height=actual_bottom - actual_y,
                        text=text,
                        confidence=round(avg_conf, 3),
                        status=status,
                        column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                        logical_row=row_idx,
                        logical_col=col_idx,
                    )

                    if status == CellStatus.RECOGNIZED:
                        recognized += 1
                    else:
                        problematic += 1
                else:
                    # Empty cell spans the full column/row slot.
                    cell = GridCell(
                        row=row_idx,
                        col=col_idx,
                        x=col_x,
                        y=row_y,
                        width=col_width,
                        height=row_height,
                        status=CellStatus.EMPTY,
                        column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                        logical_row=row_idx,
                        logical_col=col_idx,
                    )
                    empty += 1

                row_cells.append(cell)
            cells.append(row_cells)

        # Add final row boundary
        if rows and rows[-1]:
            row_boundaries.append(max(r.bottom for r in rows[-1]) + self.padding_pct)
        else:
            row_boundaries.append(100.0)

        total = num_rows * num_cols
        # Coverage = fraction of cells with any recognized text.
        coverage = (recognized + problematic) / max(total, 1)

        return GridResult(
            rows=num_rows,
            columns=num_cols,
            cells=cells,
            column_types=column_types,
            column_boundaries=col_boundaries,
            row_boundaries=row_boundaries,
            deskew_angle=deskew_angle,
            stats={
                "recognized": recognized,
                "problematic": problematic,
                "empty": empty,
                "manual": 0,
                "total": total,
                "coverage": round(coverage, 3),
            },
        )

    def convert_tesseract_regions(self, tess_words: List[dict],
                                  image_width: int, image_height: int) -> List[OCRRegion]:
        """Convert Tesseract word data (pixels) to OCRRegions (percentages).

        Args:
            tess_words: Word list from tesseract_vocab_extractor.extract_bounding_boxes.
            image_width: Image width in pixels.
            image_height: Image height in pixels.

        Returns:
            List of OCRRegion with percentage-based coordinates. Empty list
            when input is empty or either image dimension is zero.
        """
        if not tess_words or image_width == 0 or image_height == 0:
            return []

        regions = []
        for w in tess_words:
            regions.append(OCRRegion(
                text=w["text"],
                # Tesseract confidence is 0-100; OCRRegion uses 0-1.
                # Missing 'conf' defaults to a neutral 50 (-> 0.5).
                confidence=w.get("conf", 50) / 100.0,
                x=w["left"] / image_width * 100,
                y=w["top"] / image_height * 100,
                width=w["width"] / image_width * 100,
                height=w["height"] / image_height * 100,
            ))

        return regions
|
||||
346
klausur-service/backend/tesseract_vocab_extractor.py
Normal file
346
klausur-service/backend/tesseract_vocab_extractor.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
Tesseract-based OCR extraction with word-level bounding boxes.
|
||||
|
||||
Uses Tesseract for spatial information (WHERE text is) while
|
||||
the Vision LLM handles semantic understanding (WHAT the text means).
|
||||
|
||||
Tesseract runs natively on ARM64 via Debian's apt package.
|
||||
|
||||
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||
"""
|
||||
|
||||
import asyncio
import io
import logging
from difflib import SequenceMatcher
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
TESSERACT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TESSERACT_AVAILABLE = False
|
||||
logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")
|
||||
|
||||
|
||||
async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Run Tesseract OCR and return word-level bounding boxes.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string (e.g. "eng+deu").

    Returns:
        Dict with 'words' list and 'image_width'/'image_height'; on missing
        Tesseract, an 'error' key with zeroed dimensions.
    """
    if not TESSERACT_AVAILABLE:
        return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}

    # Close the PIL image deterministically once OCR and size reads are done.
    with Image.open(io.BytesIO(image_bytes)) as image:
        # Tesseract OCR is CPU-bound and blocking — run it off the event
        # loop so concurrent requests are not stalled.
        data = await asyncio.to_thread(
            pytesseract.image_to_data,
            image,
            lang=lang,
            output_type=pytesseract.Output.DICT,
        )
        image_width, image_height = image.width, image.height

    words = []
    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        # Depending on the pytesseract/Tesseract version, 'conf' entries can
        # be ints or float-formatted strings ("-1", "96", "96.33"); parse via
        # float so none of them raise ValueError.
        conf = int(float(data['conf'][i]))
        # Drop empty tokens and low-confidence noise (< 20).
        if not text or conf < 20:
            continue
        words.append({
            "text": text,
            "left": data['left'][i],
            "top": data['top'][i],
            "width": data['width'][i],
            "height": data['height'][i],
            "conf": conf,
            "block_num": data['block_num'][i],
            "par_num": data['par_num'][i],
            "line_num": data['line_num'][i],
            "word_num": data['word_num'][i],
        })

    return {
        "words": words,
        "image_width": image_width,
        "image_height": image_height,
    }
|
||||
|
||||
|
||||
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
    """Group OCR words into text lines by vertical position.

    A word joins the current line when its 'top' is within y_tolerance_px of
    the line's first word; otherwise it starts a new line. Each emitted line
    is sorted left-to-right.

    Args:
        words: List of word dicts from extract_bounding_boxes.
        y_tolerance_px: Max pixel distance to consider words on the same line.

    Returns:
        List of lines, each line a list of words sorted by X position.
    """
    if not words:
        return []

    # Scan top-to-bottom; ties broken left-to-right.
    ordered = sorted(words, key=lambda w: (w['top'], w['left']))

    lines: List[List[dict]] = []
    line: List[dict] = []
    anchor_top = 0

    for word in ordered:
        if line and abs(word['top'] - anchor_top) <= y_tolerance_px:
            line.append(word)
            continue
        if line:
            line.sort(key=lambda w: w['left'])
            lines.append(line)
        # Start a new line; the anchor stays fixed at its first word's top.
        line = [word]
        anchor_top = word['top']

    line.sort(key=lambda w: w['left'])
    lines.append(line)

    return lines
|
||||
|
||||
|
||||
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
    """Detect column boundaries from word positions.

    Typical vocab table: Left=English, Middle=German, Right=Example sentences.
    A horizontal gap larger than 8% of the image width between word left
    edges starts a new column.

    Returns:
        Dict with 'columns' (start positions + word counts) and
        'column_types' (english/german/example/unknown, left to right).
    """
    if not lines or image_width == 0:
        return {"columns": [], "column_types": []}

    # All word left edges, flattened and sorted.
    xs = sorted(word['left'] for line in lines for word in line)
    if not xs:
        return {"columns": [], "column_types": []}

    # Gap-based clustering of X positions into column-start groups.
    gap_threshold = image_width * 0.08
    clusters: List[List[int]] = [[xs[0]]]
    for x in xs[1:]:
        if x - clusters[-1][-1] > gap_threshold:
            clusters.append([x])
        else:
            clusters[-1].append(x)

    columns = [
        {
            "x_start": min(cluster),
            "x_start_pct": min(cluster) / image_width * 100,
            "word_count": len(cluster),
        }
        for cluster in clusters
    ]

    # Assign semantic roles left-to-right; extras are "unknown".
    known_types = ["english", "german", "example"]
    column_types = [
        known_types[i] if i < len(known_types) else "unknown"
        for i in range(len(columns))
    ]

    return {
        "columns": columns,
        "column_types": column_types,
    }
|
||||
|
||||
|
||||
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
                           column_types: List[str], image_width: int,
                           image_height: int) -> List[dict]:
    """Convert grouped words into vocabulary entries using column positions.

    Each line yields at most one entry; a word is assigned to the column
    whose horizontal span contains its center. Entries without english or
    german text are dropped.

    Args:
        lines: Grouped word lines from group_words_into_lines.
        columns: Column boundaries from detect_columns.
        column_types: Column type assignments.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        List of vocabulary entry dicts with english/german/example fields and
        optional per-column '<type>_bbox' percentage boxes.
    """
    if not columns or not lines:
        return []

    # Precompute (start, end, type) spans, one per column; the last column
    # extends to the image's right edge.
    spans = []
    for idx, col in enumerate(columns):
        end = columns[idx + 1]['x_start'] if idx + 1 < len(columns) else image_width
        col_type = column_types[idx] if idx < len(column_types) else "unknown"
        spans.append((col['x_start'], end, col_type))

    entries = []
    for line in lines:
        texts: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
        boxes: Dict[str, Optional[dict]] = {}

        for word in line:
            center = word['left'] + word['width'] / 2
            col_type = "unknown"
            for start, end, span_type in spans:
                if start <= center < end:
                    col_type = span_type
                    break

            # Words landing in an "unknown" column are ignored entirely.
            if col_type not in texts:
                continue

            texts[col_type].append(word['text'])

            right = word['left'] + word['width']
            bottom = word['top'] + word['height']
            box = boxes.get(col_type)
            if box is None:
                boxes[col_type] = {
                    "left": word['left'],
                    "top": word['top'],
                    "right": right,
                    "bottom": bottom,
                }
            else:
                # Grow the column's bounding box to include this word.
                box['left'] = min(box['left'], word['left'])
                box['top'] = min(box['top'], word['top'])
                box['right'] = max(box['right'], right)
                box['bottom'] = max(box['bottom'], bottom)

        entry = {"english": "", "german": "", "example": ""}
        for col_type in ("english", "german", "example"):
            if not texts[col_type]:
                continue
            entry[col_type] = " ".join(texts[col_type])
            box = boxes.get(col_type)
            if box:
                entry[f"{col_type}_bbox"] = {
                    "x_pct": box['left'] / image_width * 100,
                    "y_pct": box['top'] / image_height * 100,
                    "w_pct": (box['right'] - box['left']) / image_width * 100,
                    "h_pct": (box['bottom'] - box['top']) / image_height * 100,
                }

        # Keep the line only if at least one translation column has content.
        if entry["english"] or entry["german"]:
            entries.append(entry)

    return entries
|
||||
|
||||
|
||||
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
                             image_w: int, image_h: int,
                             threshold: float = 0.6) -> List[dict]:
    """Match Tesseract bounding boxes to LLM vocabulary entries.

    For each LLM vocab entry, find the Tesseract word with the highest fuzzy
    similarity (english tried before german) and attach its bounding box as
    page percentages. Entries are mutated in place and the list is returned.

    Args:
        tess_words: Word list from Tesseract with pixel coordinates.
        llm_vocab: Vocabulary list from Vision LLM.
        image_w: Image width in pixels.
        image_h: Image height in pixels.
        threshold: Minimum similarity ratio for a match.

    Returns:
        llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
    """
    if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
        return llm_vocab

    for entry in llm_vocab:
        candidates = [
            (field, entry.get(field, "").lower().strip())
            for field in ("english", "german")
        ]
        # Nothing to match against for this entry.
        if not any(text for _, text in candidates):
            continue

        for field, needle in candidates:
            if not needle:
                continue

            # Linear scan for the most similar Tesseract token; earlier
            # words win ties (strict > comparison).
            best_ratio = 0.0
            best_word = None
            for cand in tess_words:
                ratio = SequenceMatcher(None, needle, cand['text'].lower()).ratio()
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_word = cand

            if best_word is None or best_ratio < threshold:
                continue

            entry["bbox_x_pct"] = best_word['left'] / image_w * 100
            entry["bbox_y_pct"] = best_word['top'] / image_h * 100
            entry["bbox_w_pct"] = best_word['width'] / image_w * 100
            entry["bbox_h_pct"] = best_word['height'] / image_h * 100
            entry["bbox_match_field"] = field
            entry["bbox_match_ratio"] = round(best_ratio, 3)
            break  # Found a match; skip the remaining field.

    return llm_vocab
|
||||
|
||||
|
||||
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Full Tesseract pipeline: words -> lines -> columns -> vocab entries.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string.

    Returns:
        Dict with 'vocabulary', 'words', 'lines_count', 'columns',
        'column_types', 'image_width', 'image_height', 'word_count' — or the
        raw error payload from extract_bounding_boxes on failure.
    """
    bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
    if bbox_data.get("error"):
        # Propagate the failure payload unchanged (e.g. Tesseract missing).
        return bbox_data

    words = bbox_data["words"]
    image_w = bbox_data["image_width"]
    image_h = bbox_data["image_height"]

    lines = group_words_into_lines(words)
    col_info = detect_columns(lines, image_w)
    vocabulary = words_to_vocab_entries(
        lines, col_info["columns"], col_info["column_types"], image_w, image_h
    )

    return {
        "vocabulary": vocabulary,
        "words": words,
        "lines_count": len(lines),
        "columns": col_info["columns"],
        "column_types": col_info["column_types"],
        "image_width": image_w,
        "image_height": image_h,
        "word_count": len(words),
    }
|
||||
261
klausur-service/backend/trocr_api.py
Normal file
261
klausur-service/backend/trocr_api.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
TrOCR API - REST endpoints for TrOCR handwriting OCR.
|
||||
|
||||
Provides:
|
||||
- /ocr/trocr - Single image OCR
|
||||
- /ocr/trocr/batch - Batch image processing
|
||||
- /ocr/trocr/status - Model status
|
||||
- /ocr/trocr/cache - Cache statistics
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
import json
|
||||
import logging
|
||||
|
||||
from services.trocr_service import (
|
||||
run_trocr_ocr_enhanced,
|
||||
run_trocr_batch,
|
||||
run_trocr_batch_stream,
|
||||
get_model_status,
|
||||
get_cache_stats,
|
||||
preload_trocr_model,
|
||||
OCRResult,
|
||||
BatchOCRResult
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)

# All TrOCR endpoints below are mounted under this prefix.
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MODELS
|
||||
# =============================================================================
|
||||
|
||||
class TrOCRResponse(BaseModel):
    """Response model for single image OCR.

    Fields mirror the OCRResult returned by services.trocr_service (see the
    mapping in run_trocr); word_count is derived from the text here.
    """
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
|
||||
|
||||
|
||||
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR (one TrOCRResponse per input image)."""
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
|
||||
|
||||
|
||||
class ModelStatusResponse(BaseModel):
    """Response model for model status (optional fields are None when the
    model has not been loaded yet)."""
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
|
||||
|
||||
|
||||
class CacheStatsResponse(BaseModel):
    """Response model for OCR result cache statistics."""
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
    """
    Get TrOCR model status.

    Returns information about whether the model is loaded and available.
    Thin delegation to services.trocr_service.get_model_status().
    """
    return get_model_status()
|
||||
|
||||
|
||||
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
    """
    Get TrOCR cache statistics.

    Returns information about the OCR result cache.
    Thin delegation to services.trocr_service.get_cache_stats().
    """
    return get_cache_stats()
|
||||
|
||||
|
||||
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
    """
    Preload TrOCR model into memory.

    This speeds up the first OCR request by loading the model ahead of time.
    Responds 500 when the underlying service fails to load the model.
    """
    # Guard clause: fail fast, then fall through to the success payload.
    if not preload_trocr_model(handwritten=handwritten):
        raise HTTPException(status_code=500, detail="Failed to preload model")
    return {"status": "success", "message": "Model preloaded successfully"}
|
||||
|
||||
|
||||
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
    file: UploadFile = File(..., description="Image file to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split image into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on a single image.

    Supports PNG, JPG, and other common image formats.

    Raises:
        HTTPException 400: if the upload is not an image.
        HTTPException 500: if OCR processing fails.
    """
    # Validate file type before reading any bytes.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        image_data = await file.read()

        result = await run_trocr_ocr_enhanced(
            image_data,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return TrOCRResponse(
            text=result.text,
            confidence=result.confidence,
            processing_time_ms=result.processing_time_ms,
            model=result.model,
            has_lora_adapter=result.has_lora_adapter,
            from_cache=result.from_cache,
            image_hash=result.image_hash,
            # Word count derived from the recognized text; 0 for empty text.
            word_count=len(result.text.split()) if result.text else 0
        )

    except HTTPException:
        # Pass deliberate HTTP errors through unchanged (consistent with
        # the /batch endpoints) instead of wrapping them in a 500.
        raise
    except Exception as e:
        # logger.exception preserves the traceback; lazy %-args avoid
        # formatting cost when the log level is disabled.
        logger.exception("TrOCR API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
||||
|
||||
|
||||
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images.

    Processes images sequentially and returns all results.

    Raises:
        HTTPException 400: no files, more than 50 files, or a non-image upload.
        HTTPException 500: if OCR processing fails.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    # Hard cap so a single request cannot monopolize the service.
    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    try:
        # Validate and read every upload before any OCR work starts.
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())

        batch_result = await run_trocr_batch(
            images,
            handwritten=handwritten,
            split_lines=split_lines,
            use_cache=use_cache
        )

        return BatchOCRResponse(
            results=[
                TrOCRResponse(
                    text=r.text,
                    confidence=r.confidence,
                    processing_time_ms=r.processing_time_ms,
                    model=r.model,
                    has_lora_adapter=r.has_lora_adapter,
                    from_cache=r.from_cache,
                    image_hash=r.image_hash,
                    word_count=len(r.text.split()) if r.text else 0
                )
                for r in batch_result.results
            ],
            total_time_ms=batch_result.total_time_ms,
            processed_count=batch_result.processed_count,
            cached_count=batch_result.cached_count,
            error_count=batch_result.error_count
        )

    except HTTPException:
        # Re-raise our own 400s untouched instead of wrapping them in a 500.
        raise
    except Exception as e:
        # BUGFIX(minor): logger.exception keeps the traceback for debugging,
        # unlike logger.error(f"...").
        logger.exception("TrOCR batch API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
||||
|
||||
|
||||
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
    files: List[UploadFile] = File(..., description="Image files to process"),
    handwritten: bool = Query(True, description="Use handwritten model"),
    split_lines: bool = Query(True, description="Split images into lines"),
    use_cache: bool = Query(True, description="Use result caching")
):
    """
    Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.

    Returns a stream of progress events as images are processed.

    Raises:
        HTTPException 400: no files, more than 50 files, or a non-image upload.
        HTTPException 500: if setup fails before streaming starts.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    if len(files) > 50:
        raise HTTPException(status_code=400, detail="Maximum 50 images per batch")

    try:
        # Read all uploads up front; validation errors must be raised here,
        # before the streaming response begins.
        images = []
        for file in files:
            if not file.content_type or not file.content_type.startswith("image/"):
                raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
            images.append(await file.read())

        async def event_generator():
            # NOTE(review): exceptions raised inside this generator occur
            # after headers are sent and cannot become an HTTP 500 --
            # presumably run_trocr_batch_stream reports per-image errors
            # as events; verify.
            async for update in run_trocr_batch_stream(
                images,
                handwritten=handwritten,
                split_lines=split_lines,
                use_cache=use_cache
            ):
                # One SSE "data:" frame per progress update.
                yield f"data: {json.dumps(update)}\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={
                # Disable caching/buffering so events arrive as they happen.
                "Cache-Control": "no-cache",
                "Connection": "keep-alive"
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        # BUGFIX(minor): logger.exception keeps the traceback for debugging.
        logger.exception("TrOCR stream API error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
||||
428
klausur-service/backend/vocab_session_store.py
Normal file
428
klausur-service/backend/vocab_session_store.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Vocabulary Session Store - PostgreSQL persistence for vocab extraction sessions.
|
||||
|
||||
Replaces in-memory storage with database persistence.
|
||||
See migrations/001_vocab_sessions.sql for schema.
|
||||
"""
|
||||
|
||||
import asyncio
import json
import logging
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any

import asyncpg
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Database configuration
DATABASE_URL = os.getenv(
    "DATABASE_URL",
    "postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
)

# Connection pool (initialized lazily; creation guarded by a lock)
_pool: Optional[asyncpg.Pool] = None
_pool_lock = asyncio.Lock()


async def get_pool() -> asyncpg.Pool:
    """Get or create the database connection pool.

    BUGFIX: the unguarded lazy init raced -- two coroutines awaiting the
    first call could each see ``_pool is None`` and create two pools,
    leaking one. Double-checked locking with an asyncio.Lock makes
    creation happen exactly once.
    """
    global _pool
    if _pool is None:
        async with _pool_lock:
            # Re-check under the lock: another coroutine may have won.
            if _pool is None:
                _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
||||
|
||||
|
||||
async def init_vocab_tables():
    """
    Initialize vocab tables if they don't exist.
    This is called at startup.

    Probes information_schema for the 'vocab_sessions' table and, when it
    is absent, executes the bundled SQL migration file. A missing
    migration file only logs a warning (best-effort bootstrap, no raise).
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        # Check if tables exist (vocab_sessions acts as the sentinel table
        # for the whole migration).
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'vocab_sessions'
            )
        """)

        if not tables_exist:
            logger.info("Creating vocab tables...")
            # Read and execute migration shipped next to this module.
            migration_path = os.path.join(
                os.path.dirname(__file__),
                "migrations/001_vocab_sessions.sql"
            )
            if os.path.exists(migration_path):
                with open(migration_path, "r") as f:
                    sql = f.read()
                # Execute the whole migration script in one call.
                await conn.execute(sql)
                logger.info("Vocab tables created successfully")
            else:
                logger.warning(f"Migration file not found: {migration_path}")
        else:
            logger.debug("Vocab tables already exist")
||||
|
||||
|
||||
# =============================================================================
|
||||
# SESSION OPERATIONS
|
||||
# =============================================================================
|
||||
|
||||
async def create_session_db(
    session_id: str,
    name: str,
    description: str = "",
    source_language: str = "en",
    target_language: str = "de"
) -> Dict[str, Any]:
    """Insert a new vocabulary session and return it as a dict.

    Every session starts in status 'pending' with a vocabulary count of 0.
    """
    pool = await get_pool()
    params = (
        uuid.UUID(session_id),
        name,
        description,
        source_language,
        target_language,
    )
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO vocab_sessions (
                id, name, description, source_language, target_language,
                status, vocabulary_count
            ) VALUES ($1, $2, $3, $4, $5, 'pending', 0)
            RETURNING *
        """, *params)
    return _row_to_dict(row)
||||
|
||||
|
||||
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Look up a single session by its UUID string; None when absent."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            SELECT * FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    return _row_to_dict(row) if row else None
||||
|
||||
|
||||
async def list_sessions_db(
    limit: int = 50,
    offset: int = 0,
    status: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Return sessions newest-first, optionally filtered by status."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        if not status:
            # No filter: page through all sessions.
            rows = await conn.fetch("""
                SELECT * FROM vocab_sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
            """, limit, offset)
        else:
            rows = await conn.fetch("""
                SELECT * FROM vocab_sessions
                WHERE status = $1
                ORDER BY created_at DESC
                LIMIT $2 OFFSET $3
            """, status, limit, offset)
    return [_row_to_dict(r) for r in rows]
||||
|
||||
|
||||
async def update_session_db(
    session_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update a session with given fields.

    Only whitelisted columns may be set (field names are interpolated
    into the SQL, so the whitelist is what prevents injection). JSONB
    columns are serialized with json.dumps.

    Returns the updated row, the unchanged row when no valid field was
    supplied, or None when the session does not exist.
    """
    pool = await get_pool()

    # Build dynamic UPDATE query from whitelisted fields only.
    fields = []
    values = []
    param_idx = 1

    allowed_fields = [
        'name', 'description', 'status', 'vocabulary_count',
        'extraction_confidence', 'image_path', 'pdf_path', 'pdf_page_count',
        'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'
    ]
    # Columns stored as JSONB: their dict/list values must be dumped.
    jsonb_fields = {'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'}

    for key, value in kwargs.items():
        if key in allowed_fields:
            fields.append(f"{key} = ${param_idx}")
            if key in jsonb_fields:
                # BUGFIX: compare against None, not truthiness -- an empty
                # list/dict is a legitimate JSONB value ('[]'/'{}') and
                # must not silently become SQL NULL.
                value = json.dumps(value) if value is not None else None
            values.append(value)
            param_idx += 1

    if not fields:
        # Nothing to change; return the current state.
        return await get_session_db(session_id)

    values.append(uuid.UUID(session_id))

    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE vocab_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING *
        """, *values)

    if row:
        return _row_to_dict(row)
    return None
||||
|
||||
|
||||
async def delete_session_db(session_id: str) -> bool:
    """Delete a session; related data is removed by the database (cascades)."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # asyncpg reports e.g. "DELETE 1"; exactly one deleted row means success.
    return status == "DELETE 1"
||||
|
||||
|
||||
# =============================================================================
|
||||
# VOCABULARY OPERATIONS
|
||||
# =============================================================================
|
||||
|
||||
async def add_vocabulary_db(
    session_id: str,
    vocab_list: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Add vocabulary entries to a session.

    BUGFIX: the inserts and the denormalized vocabulary_count refresh now
    run inside one transaction, so a failure midway cannot leave the
    counter out of sync with the actual entries.

    Returns the inserted rows as dicts (empty list for empty input).
    """
    if not vocab_list:
        return []

    pool = await get_pool()
    results = []

    async with pool.acquire() as conn:
        # Atomic: either every entry lands and the count is refreshed,
        # or nothing changes.
        async with conn.transaction():
            for vocab in vocab_list:
                vocab_id = str(uuid.uuid4())
                row = await conn.fetchrow("""
                    INSERT INTO vocab_entries (
                        id, session_id, english, german, example_sentence,
                        example_sentence_gap, word_type, source_page
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    RETURNING *
                """,
                    uuid.UUID(vocab_id),
                    uuid.UUID(session_id),
                    vocab.get('english', ''),
                    vocab.get('german', ''),
                    vocab.get('example_sentence'),
                    vocab.get('example_sentence_gap'),
                    vocab.get('word_type'),
                    vocab.get('source_page')
                )
                results.append(_row_to_dict(row))

            # Keep the denormalized counter in sync with the real count.
            await conn.execute("""
                UPDATE vocab_sessions
                SET vocabulary_count = (
                    SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
                )
                WHERE id = $1
            """, uuid.UUID(session_id))

    return results
||||
|
||||
|
||||
async def get_vocabulary_db(
    session_id: str,
    source_page: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Fetch a session's vocabulary, optionally restricted to one page."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        if source_page is None:
            # All pages; entries without a page sort last.
            rows = await conn.fetch("""
                SELECT * FROM vocab_entries
                WHERE session_id = $1
                ORDER BY source_page NULLS LAST, created_at
            """, uuid.UUID(session_id))
        else:
            rows = await conn.fetch("""
                SELECT * FROM vocab_entries
                WHERE session_id = $1 AND source_page = $2
                ORDER BY created_at
            """, uuid.UUID(session_id), source_page)
    return [_row_to_dict(r) for r in rows]
||||
|
||||
|
||||
async def update_vocabulary_db(
    entry_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update a single vocabulary entry.

    Returns the updated row, or None when no whitelisted field was
    supplied or the entry does not exist. Field names are restricted to
    a whitelist because they are interpolated into the SQL.
    """
    allowed_fields = [
        'english', 'german', 'example_sentence', 'example_sentence_gap',
        'word_type', 'source_page'
    ]

    assignments = []
    params = []
    for key, value in kwargs.items():
        if key in allowed_fields:
            # Placeholder number is the value's 1-based position.
            assignments.append(f"{key} = ${len(params) + 1}")
            params.append(value)

    if not assignments:
        return None

    # The WHERE-clause id takes the next placeholder slot.
    params.append(uuid.UUID(entry_id))
    id_placeholder = len(params)

    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE vocab_entries
            SET {', '.join(assignments)}
            WHERE id = ${id_placeholder}
            RETURNING *
        """, *params)

    return _row_to_dict(row) if row else None
||||
|
||||
|
||||
async def clear_page_vocabulary_db(session_id: str, page: int) -> int:
    """Clear all vocabulary for a specific page.

    BUGFIX: the delete and the vocabulary_count refresh now run inside
    one transaction so the denormalized counter cannot drift from the
    real entry count if the second statement fails.

    Returns the number of deleted rows.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():
            result = await conn.execute("""
                DELETE FROM vocab_entries
                WHERE session_id = $1 AND source_page = $2
            """, uuid.UUID(session_id), page)

            # Update vocabulary count to match the remaining entries.
            await conn.execute("""
                UPDATE vocab_sessions
                SET vocabulary_count = (
                    SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
                )
                WHERE id = $1
            """, uuid.UUID(session_id))

        # asyncpg status is e.g. "DELETE 3" -> parse the trailing count.
        count = int(result.split()[-1]) if result else 0
        return count
||||
|
||||
|
||||
# =============================================================================
|
||||
# WORKSHEET OPERATIONS
|
||||
# =============================================================================
|
||||
|
||||
async def create_worksheet_db(
    session_id: str,
    worksheet_types: List[str],
    pdf_path: Optional[str] = None,
    solution_path: Optional[str] = None
) -> Dict[str, Any]:
    """Insert a worksheet record with a fresh UUID and return it."""
    new_id = uuid.uuid4()
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO vocab_worksheets (
                id, session_id, worksheet_types, pdf_path, solution_path
            ) VALUES ($1, $2, $3, $4, $5)
            RETURNING *
        """,
            new_id,
            uuid.UUID(session_id),
            # worksheet_types is a JSONB column -> dump the list.
            json.dumps(worksheet_types),
            pdf_path,
            solution_path
        )
    return _row_to_dict(row)
||||
|
||||
|
||||
async def get_worksheet_db(worksheet_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one worksheet by ID, or None when it does not exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow("""
            SELECT * FROM vocab_worksheets WHERE id = $1
        """, uuid.UUID(worksheet_id))
    return _row_to_dict(row) if row else None
||||
|
||||
|
||||
async def delete_worksheets_for_session_db(session_id: str) -> int:
    """Remove every worksheet of a session; returns how many were deleted."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        status = await conn.execute("""
            DELETE FROM vocab_worksheets WHERE session_id = $1
        """, uuid.UUID(session_id))
    # asyncpg status string looks like "DELETE <n>".
    return int(status.split()[-1]) if status else 0
||||
|
||||
|
||||
# =============================================================================
|
||||
# PDF CACHE OPERATIONS
|
||||
# =============================================================================
|
||||
|
||||
# Simple in-memory cache for PDF data (temporary until served)
|
||||
_pdf_cache: Dict[str, bytes] = {}
|
||||
|
||||
|
||||
def cache_pdf_data(worksheet_id: str, pdf_data: bytes) -> None:
|
||||
"""Cache PDF data temporarily for download."""
|
||||
_pdf_cache[worksheet_id] = pdf_data
|
||||
|
||||
|
||||
def get_cached_pdf_data(worksheet_id: str) -> Optional[bytes]:
    """Return the cached PDF bytes for a worksheet, or None when absent."""
    data = _pdf_cache.get(worksheet_id)
    return data
||||
|
||||
|
||||
def clear_cached_pdf_data(worksheet_id: str) -> None:
    """Drop a cached PDF; unknown IDs are silently ignored."""
    _pdf_cache.pop(worksheet_id, None)
||||
|
||||
|
||||
# =============================================================================
|
||||
# HELPER FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Convert an asyncpg Record into a JSON-friendly dict.

    UUID columns become strings, datetimes become ISO-8601 strings, and
    JSONB columns that arrive as raw JSON strings are parsed back into
    Python objects. A None row yields an empty dict.
    """
    if row is None:
        return {}

    out = dict(row)

    # UUID columns -> plain strings.
    for uuid_key in ('id', 'session_id'):
        if out.get(uuid_key) is not None:
            out[uuid_key] = str(out[uuid_key])

    # Timestamp columns -> ISO strings.
    for ts_key in ('created_at', 'updated_at', 'generated_at'):
        if out.get(ts_key) is not None:
            out[ts_key] = out[ts_key].isoformat()

    # JSONB columns may come back as strings; decode those.
    for json_key in ('ocr_prompts', 'processed_pages', 'successful_pages',
                     'failed_pages', 'worksheet_types'):
        value = out.get(json_key)
        if isinstance(value, str):
            out[json_key] = json.loads(value)

    return out
||||
Reference in New Issue
Block a user