feat(klausur-service): Add Tesseract OCR, DSFA RAG, TrOCR, grid detection and vocab session store
New modules: - tesseract_vocab_extractor.py: Bounding-box OCR with multi-PSM pipeline - grid_detection_service.py: CV-based grid/table detection for worksheets - vocab_session_store.py: PostgreSQL persistence for vocab sessions - trocr_api.py: TrOCR handwriting recognition endpoint - dsfa_rag_api.py + dsfa_corpus_ingestion.py: DSFA RAG corpus search Changes: - Dockerfile: Install tesseract-ocr + deu/eng language packs - requirements.txt: Add PyMuPDF, pytesseract, Pillow - main.py: Register new routers, init DB pools + Qdrant collections Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,9 +13,12 @@ FROM python:3.11-slim
|
|||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install system dependencies
|
# Install system dependencies (incl. Tesseract OCR for bounding-box extraction)
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
curl \
|
curl \
|
||||||
|
tesseract-ocr \
|
||||||
|
tesseract-ocr-deu \
|
||||||
|
tesseract-ocr-eng \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install Python dependencies
|
# Install Python dependencies
|
||||||
|
|||||||
1501
klausur-service/backend/dsfa_corpus_ingestion.py
Normal file
1501
klausur-service/backend/dsfa_corpus_ingestion.py
Normal file
File diff suppressed because it is too large
Load Diff
715
klausur-service/backend/dsfa_rag_api.py
Normal file
715
klausur-service/backend/dsfa_rag_api.py
Normal file
@@ -0,0 +1,715 @@
|
|||||||
|
"""
|
||||||
|
DSFA RAG API Endpoints.
|
||||||
|
|
||||||
|
Provides REST API for searching DSFA corpus with full source attribution.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- GET /api/v1/dsfa-rag/search - Semantic search with attribution
|
||||||
|
- GET /api/v1/dsfa-rag/sources - List all registered sources
|
||||||
|
- POST /api/v1/dsfa-rag/sources/{code}/ingest - Trigger source ingestion
|
||||||
|
- GET /api/v1/dsfa-rag/chunks/{id} - Get single chunk with attribution
|
||||||
|
- GET /api/v1/dsfa-rag/stats - Get corpus statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException, Query, Depends
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Embedding service configuration
|
||||||
|
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://172.18.0.13:8087")
|
||||||
|
|
||||||
|
# Import from ingestion module
|
||||||
|
from dsfa_corpus_ingestion import (
|
||||||
|
DSFACorpusStore,
|
||||||
|
DSFAQdrantService,
|
||||||
|
DSFASearchResult,
|
||||||
|
LICENSE_REGISTRY,
|
||||||
|
DSFA_SOURCES,
|
||||||
|
generate_attribution_notice,
|
||||||
|
get_license_label,
|
||||||
|
DSFA_COLLECTION,
|
||||||
|
chunk_document
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/dsfa-rag", tags=["DSFA RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Pydantic Models
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class DSFASourceResponse(BaseModel):
    """Response model for a registered DSFA source.

    Mirrors the stored source row plus license display fields
    (license_name / license_url) resolved from LICENSE_REGISTRY.
    """
    id: str
    source_code: str
    name: str
    full_name: Optional[str] = None
    organization: Optional[str] = None
    source_url: Optional[str] = None
    license_code: str
    license_name: str
    license_url: Optional[str] = None
    attribution_required: bool
    attribution_text: str
    document_type: Optional[str] = None
    language: str = "de"  # corpus is German-first; ISO 639-1 code
|
||||||
|
|
||||||
|
|
||||||
|
class DSFAChunkResponse(BaseModel):
    """Response model for a single chunk with attribution.

    Attribution fields are always included so clients can render
    license-compliant citations for any chunk they display.
    """
    chunk_id: str
    content: str
    section_title: Optional[str] = None
    page_number: Optional[int] = None
    category: Optional[str] = None

    # Document info
    document_id: str
    document_title: Optional[str] = None

    # Attribution (always included)
    source_id: str
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None
    document_type: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DSFASearchResultResponse(BaseModel):
    """Response model for one semantic-search hit, with attribution."""
    chunk_id: str
    content: str
    score: float  # Qdrant similarity score for the query embedding

    # Attribution
    source_code: str
    source_name: str
    attribution_text: str
    license_code: str
    license_name: str
    license_url: Optional[str] = None
    attribution_required: bool
    source_url: Optional[str] = None

    # Metadata
    document_type: Optional[str] = None  # e.g. guideline, checklist, regulation
    category: Optional[str] = None  # e.g. threshold_analysis, risk_assessment, mitigation
    section_title: Optional[str] = None
    page_number: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DSFASearchResponse(BaseModel):
    """Response model for the /search endpoint."""
    query: str
    results: List[DSFASearchResultResponse]
    total_results: int

    # Aggregated licenses for footer
    licenses_used: List[str]  # distinct license codes across all results
    attribution_notice: str  # pre-rendered notice text; empty if attribution disabled
|
||||||
|
|
||||||
|
|
||||||
|
class DSFASourceStatsResponse(BaseModel):
    """Per-source statistics row for the corpus dashboard."""
    source_id: str
    source_code: str
    name: str
    organization: Optional[str] = None
    license_code: str
    document_type: Optional[str] = None
    document_count: int
    chunk_count: int
    last_indexed_at: Optional[str] = None  # ISO 8601 timestamp, None if never indexed
|
||||||
|
|
||||||
|
|
||||||
|
class DSFACorpusStatsResponse(BaseModel):
    """Aggregate corpus statistics combining PostgreSQL and Qdrant views."""
    sources: List[DSFASourceStatsResponse]
    total_sources: int
    total_documents: int
    total_chunks: int
    qdrant_collection: str  # collection name (DSFA_COLLECTION constant)
    qdrant_points_count: int
    qdrant_status: str  # collection status as reported by Qdrant; "unknown" if missing
|
||||||
|
|
||||||
|
|
||||||
|
class IngestRequest(BaseModel):
    """Request model for source ingestion.

    At least one of document_url / document_text must be set; the endpoint
    rejects the request with 400 otherwise.
    """
    document_url: Optional[str] = None
    document_text: Optional[str] = None
    title: Optional[str] = None  # defaults to "Document for {source_code}" when omitted
|
||||||
|
|
||||||
|
|
||||||
|
class IngestResponse(BaseModel):
    """Response model for source ingestion."""
    source_code: str
    document_id: Optional[str] = None
    chunks_created: int  # number of chunks actually indexed in Qdrant
    message: str
|
||||||
|
|
||||||
|
|
||||||
|
class LicenseInfo(BaseModel):
    """License terms, as recorded in LICENSE_REGISTRY."""
    code: str
    name: str
    url: Optional[str] = None
    attribution_required: bool
    modification_allowed: bool
    commercial_use: bool
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dependency Injection
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Database pool (will be set from main.py)
|
||||||
|
_db_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_db_pool(pool):
    """Set the database pool for API endpoints.

    Called once from main.py at application startup. Until then,
    get_store() rejects requests with HTTP 503.
    """
    global _db_pool
    _db_pool = pool
|
||||||
|
|
||||||
|
|
||||||
|
async def get_store() -> DSFACorpusStore:
    """FastAPI dependency: build a DSFACorpusStore over the shared pool.

    Raises:
        HTTPException: 503 when set_db_pool() has not been called yet.
    """
    if _db_pool is None:
        raise HTTPException(status_code=503, detail="Database not initialized")
    return DSFACorpusStore(_db_pool)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_qdrant() -> DSFAQdrantService:
    """FastAPI dependency: return a fresh DSFAQdrantService per request."""
    return DSFAQdrantService()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Embedding Service Integration
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
async def get_embedding(text: str) -> List[float]:
    """
    Get embedding for text using the embedding-service.

    Uses BGE-M3 model which produces 1024-dimensional vectors.

    On any httpx.HTTPError the call degrades to a deterministic hash-based
    pseudo-embedding so development/search does not hard-fail.
    """
    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed-single",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            # Empty list if the service replies without an "embedding" key.
            return data.get("embedding", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service error: {e}")
            # Fallback to hash-based pseudo-embedding for development
            return _generate_fallback_embedding(text)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_embeddings_batch(texts: List[str]) -> List[List[float]]:
    """
    Get embeddings for multiple texts in batch.

    Issues one POST to the embedding-service /embed endpoint; on HTTP error
    every text falls back to a deterministic hash-based pseudo-embedding.
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/embed",
                json={"texts": texts}
            )
            response.raise_for_status()
            data = response.json()
            # NOTE(review): no length check — a short/empty "embeddings"
            # list is returned as-is; callers zip() against their inputs.
            return data.get("embeddings", [])
        except httpx.HTTPError as e:
            logger.error(f"Embedding service batch error: {e}")
            # Fallback
            return [_generate_fallback_embedding(t) for t in texts]
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_text_from_url(url: str) -> str:
    """
    Extract plain text from a document URL (PDF, HTML, etc.).

    Strategy:
        1. Delegate to the embedding-service /extract-pdf endpoint (handles
           PDF parsing remotely).
        2. On HTTP failure, fetch the URL directly and strip HTML locally.

    Returns:
        Extracted text, or "" when nothing could be extracted (callers
        treat "" as an ingestion failure).
    """
    async with httpx.AsyncClient(timeout=120.0) as client:
        try:
            # First try to use the embedding-service's extract-pdf endpoint
            response = await client.post(
                f"{EMBEDDING_SERVICE_URL}/extract-pdf",
                json={"url": url}
            )
            response.raise_for_status()
            data = response.json()
            return data.get("text", "")
        except httpx.HTTPError as e:
            logger.error(f"PDF extraction error for {url}: {e}")
            # Fallback: try to fetch HTML content directly
            try:
                response = await client.get(url, follow_redirects=True)
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                if "html" in content_type:
                    # Simple HTML text extraction
                    import html as html_lib
                    import re
                    raw = response.text
                    # Remove scripts and styles — their contents are not
                    # document text.
                    raw = re.sub(r'<script[^>]*>.*?</script>', '', raw, flags=re.DOTALL | re.IGNORECASE)
                    raw = re.sub(r'<style[^>]*>.*?</style>', '', raw, flags=re.DOTALL | re.IGNORECASE)
                    # Remove tags
                    text = re.sub(r'<[^>]+>', ' ', raw)
                    # Decode HTML entities (&amp;, &uuml;, ...) so the text
                    # that gets chunked/indexed matches what a reader sees.
                    # Previously entities were left encoded in the output.
                    text = html_lib.unescape(text)
                    # Clean whitespace
                    text = re.sub(r'\s+', ' ', text).strip()
                    return text
                else:
                    return ""
            except Exception as fetch_err:
                logger.error(f"Fallback fetch error for {url}: {fetch_err}")
                return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_fallback_embedding(text: str) -> List[float]:
|
||||||
|
"""
|
||||||
|
Generate deterministic pseudo-embedding from text hash.
|
||||||
|
Used as fallback when embedding service is unavailable.
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import struct
|
||||||
|
|
||||||
|
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||||
|
embedding = []
|
||||||
|
for i in range(0, min(len(hash_bytes), 128), 4):
|
||||||
|
val = struct.unpack('f', hash_bytes[i:i+4])[0]
|
||||||
|
embedding.append(val % 1.0)
|
||||||
|
|
||||||
|
# Pad to 1024 dimensions
|
||||||
|
while len(embedding) < 1024:
|
||||||
|
embedding.extend(embedding[:min(len(embedding), 1024 - len(embedding))])
|
||||||
|
|
||||||
|
return embedding[:1024]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# API Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/search", response_model=DSFASearchResponse)
async def search_dsfa_corpus(
    query: str = Query(..., min_length=3, description="Search query"),
    source_codes: Optional[List[str]] = Query(None, description="Filter by source codes"),
    document_types: Optional[List[str]] = Query(None, description="Filter by document types (guideline, checklist, regulation)"),
    categories: Optional[List[str]] = Query(None, description="Filter by categories (threshold_analysis, risk_assessment, mitigation)"),
    limit: int = Query(10, ge=1, le=50, description="Maximum results"),
    include_attribution: bool = Query(True, description="Include attribution in results"),
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Search DSFA corpus with full attribution.

    Returns matching chunks with source/license information for compliance.
    """
    # Get query embedding (BGE-M3 via embedding-service; hash fallback)
    query_embedding = await get_embedding(query)

    # Search Qdrant
    raw_results = await qdrant.search(
        query_embedding=query_embedding,
        source_codes=source_codes,
        document_types=document_types,
        categories=categories,
        limit=limit
    )

    # Transform raw Qdrant payloads into response models, collecting the
    # distinct license codes seen for the footer.
    results = []
    licenses_used = set()

    for r in raw_results:
        license_code = r.get("license_code", "")
        # Resolve display name/url from the registry; unknown codes fall
        # back to the raw code string.
        license_info = LICENSE_REGISTRY.get(license_code, {})

        result = DSFASearchResultResponse(
            chunk_id=r.get("chunk_id", ""),
            content=r.get("content", ""),
            score=r.get("score", 0.0),
            source_code=r.get("source_code", ""),
            source_name=r.get("source_name", ""),
            attribution_text=r.get("attribution_text", ""),
            license_code=license_code,
            license_name=license_info.get("name", license_code),
            license_url=license_info.get("url"),
            attribution_required=r.get("attribution_required", True),
            source_url=r.get("source_url"),
            document_type=r.get("document_type"),
            category=r.get("category"),
            section_title=r.get("section_title"),
            page_number=r.get("page_number")
        )
        results.append(result)
        licenses_used.add(license_code)

    # Re-wrap the response models as DSFASearchResult dataclasses, the
    # input type generate_attribution_notice() expects.
    # NOTE(review): this list is built even when include_attribution is
    # False — only the notice generation below is skipped.
    search_results = [
        DSFASearchResult(
            chunk_id=r.chunk_id,
            content=r.content,
            score=r.score,
            source_code=r.source_code,
            source_name=r.source_name,
            attribution_text=r.attribution_text,
            license_code=r.license_code,
            license_url=r.license_url,
            attribution_required=r.attribution_required,
            source_url=r.source_url,
            document_type=r.document_type or "",
            category=r.category or "",
            section_title=r.section_title,
            page_number=r.page_number
        )
        for r in results
    ]
    attribution_notice = generate_attribution_notice(search_results) if include_attribution else ""

    return DSFASearchResponse(
        query=query,
        results=results,
        total_results=len(results),
        licenses_used=list(licenses_used),
        attribution_notice=attribution_notice
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources", response_model=List[DSFASourceResponse])
async def list_dsfa_sources(
    document_type: Optional[str] = Query(None, description="Filter by document type"),
    license_code: Optional[str] = Query(None, description="Filter by license"),
    store: DSFACorpusStore = Depends(get_store)
):
    """List all registered DSFA sources with license info."""
    sources = await store.list_sources()

    # Apply the optional filters up front, then build response models.
    matching = [
        src for src in sources
        if (not document_type or src.get("document_type") == document_type)
        and (not license_code or src.get("license_code") == license_code)
    ]

    responses = []
    for src in matching:
        # Resolve license display data; unknown codes fall back to the raw code.
        license_info = LICENSE_REGISTRY.get(src.get("license_code", ""), {})

        responses.append(DSFASourceResponse(
            id=str(src["id"]),
            source_code=src["source_code"],
            name=src["name"],
            full_name=src.get("full_name"),
            organization=src.get("organization"),
            source_url=src.get("source_url"),
            license_code=src.get("license_code", ""),
            license_name=license_info.get("name", src.get("license_code", "")),
            license_url=license_info.get("url"),
            attribution_required=src.get("attribution_required", True),
            attribution_text=src.get("attribution_text", ""),
            document_type=src.get("document_type"),
            language=src.get("language", "de")
        ))

    return responses
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources/available")
async def list_available_sources():
    """List all available source definitions (from DSFA_SOURCES constant)."""
    available = []
    for definition in DSFA_SOURCES:
        available.append({
            "source_code": definition["source_code"],
            "name": definition["name"],
            "organization": definition.get("organization"),
            "license_code": definition["license_code"],
            "document_type": definition.get("document_type")
        })
    return available
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sources/{source_code}", response_model=DSFASourceResponse)
async def get_dsfa_source(
    source_code: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get details for a specific source.

    Raises:
        HTTPException: 404 when no source row matches *source_code*.
    """
    source = await store.get_source_by_code(source_code)

    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    # Resolve license display data; unknown codes fall back to the raw code.
    license_info = LICENSE_REGISTRY.get(source.get("license_code", ""), {})

    return DSFASourceResponse(
        id=str(source["id"]),
        source_code=source["source_code"],
        name=source["name"],
        full_name=source.get("full_name"),
        organization=source.get("organization"),
        source_url=source.get("source_url"),
        license_code=source.get("license_code", ""),
        license_name=license_info.get("name", source.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=source.get("attribution_required", True),
        attribution_text=source.get("attribution_text", ""),
        document_type=source.get("document_type"),
        language=source.get("language", "de")
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sources/{source_code}/ingest", response_model=IngestResponse)
async def ingest_dsfa_source(
    source_code: str,
    request: IngestRequest,
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Trigger ingestion for a specific source.

    Can provide document via URL or direct text. Pipeline:
    resolve source -> obtain text -> create document row -> chunk ->
    batch-embed -> persist chunks in PostgreSQL -> index in Qdrant.

    Raises:
        HTTPException: 404 for unknown source; 400 for missing/too-short
        text or failed URL extraction.
    """
    # Get source
    source = await store.get_source_by_code(source_code)
    if not source:
        raise HTTPException(status_code=404, detail=f"Source not found: {source_code}")

    # Need either URL or text
    if not request.document_text and not request.document_url:
        raise HTTPException(
            status_code=400,
            detail="Either document_text or document_url must be provided"
        )

    # Ensure Qdrant collection exists
    await qdrant.ensure_collection()

    # Get text content — direct text wins over URL when both are given.
    text_content = request.document_text
    if request.document_url and not text_content:
        # Download and extract text from URL
        logger.info(f"Extracting text from URL: {request.document_url}")
        text_content = await extract_text_from_url(request.document_url)
        if not text_content:
            raise HTTPException(
                status_code=400,
                detail=f"Could not extract text from URL: {request.document_url}"
            )

    if not text_content or len(text_content.strip()) < 50:
        raise HTTPException(status_code=400, detail="Document text too short (min 50 chars)")

    # Create document record
    doc_title = request.title or f"Document for {source_code}"
    document_id = await store.create_document(
        source_id=str(source["id"]),
        title=doc_title,
        file_type="text",
        metadata={"ingested_via": "api", "source_code": source_code}
    )

    # Chunk the document
    chunks = chunk_document(text_content, source_code)

    if not chunks:
        return IngestResponse(
            source_code=source_code,
            document_id=document_id,
            chunks_created=0,
            message="Document created but no chunks generated"
        )

    # Generate embeddings in batch for efficiency
    chunk_texts = [chunk["content"] for chunk in chunks]
    logger.info(f"Generating embeddings for {len(chunk_texts)} chunks...")
    embeddings = await get_embeddings_batch(chunk_texts)

    # Create chunk records in PostgreSQL and prepare for Qdrant.
    # NOTE(review): zip() silently drops trailing chunks if the embedding
    # service returns fewer vectors than texts — those chunks are then
    # neither stored nor indexed. Verify the service guarantees 1:1 output.
    chunk_records = []
    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
        # Create chunk in PostgreSQL
        chunk_id = await store.create_chunk(
            document_id=document_id,
            source_id=str(source["id"]),
            content=chunk["content"],
            chunk_index=i,
            section_title=chunk.get("section_title"),
            page_number=chunk.get("page_number"),
            category=chunk.get("category")
        )

        # Denormalized payload: attribution/license fields are copied onto
        # every point so Qdrant hits can be rendered without a DB join.
        chunk_records.append({
            "chunk_id": chunk_id,
            "document_id": document_id,
            "source_id": str(source["id"]),
            "content": chunk["content"],
            "section_title": chunk.get("section_title"),
            "source_code": source_code,
            "source_name": source["name"],
            "attribution_text": source["attribution_text"],
            "license_code": source["license_code"],
            "attribution_required": source.get("attribution_required", True),
            "document_type": source.get("document_type", ""),
            "category": chunk.get("category", ""),
            "language": source.get("language", "de"),
            "page_number": chunk.get("page_number")
        })

    # Index in Qdrant
    indexed_count = await qdrant.index_chunks(chunk_records, embeddings)

    # Update document record.
    # NOTE(review): records len(chunks), while the response reports
    # indexed_count — these can diverge if Qdrant indexing partially fails.
    await store.update_document_indexed(document_id, len(chunks))

    return IngestResponse(
        source_code=source_code,
        document_id=document_id,
        chunks_created=indexed_count,
        message=f"Successfully ingested {indexed_count} chunks from document"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chunks/{chunk_id}", response_model=DSFAChunkResponse)
async def get_chunk_with_attribution(
    chunk_id: str,
    store: DSFACorpusStore = Depends(get_store)
):
    """Get single chunk with full source attribution.

    Raises:
        HTTPException: 404 when the chunk does not exist.
    """
    chunk = await store.get_chunk_with_attribution(chunk_id)

    if not chunk:
        raise HTTPException(status_code=404, detail=f"Chunk not found: {chunk_id}")

    # Resolve license display data; unknown codes fall back to the raw code.
    license_info = LICENSE_REGISTRY.get(chunk.get("license_code", ""), {})

    return DSFAChunkResponse(
        chunk_id=str(chunk["chunk_id"]),
        content=chunk.get("content", ""),
        section_title=chunk.get("section_title"),
        page_number=chunk.get("page_number"),
        category=chunk.get("category"),
        document_id=str(chunk.get("document_id", "")),
        document_title=chunk.get("document_title"),
        source_id=str(chunk.get("source_id", "")),
        source_code=chunk.get("source_code", ""),
        source_name=chunk.get("source_name", ""),
        attribution_text=chunk.get("attribution_text", ""),
        license_code=chunk.get("license_code", ""),
        license_name=license_info.get("name", chunk.get("license_code", "")),
        license_url=license_info.get("url"),
        attribution_required=chunk.get("attribution_required", True),
        source_url=chunk.get("source_url"),
        document_type=chunk.get("document_type")
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/stats", response_model=DSFACorpusStatsResponse)
async def get_corpus_stats(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """Get corpus statistics for dashboard.

    Combines per-source counts from PostgreSQL with collection-level
    point count and status from Qdrant.
    """
    # Get PostgreSQL stats
    source_stats = await store.get_source_stats()

    total_docs = 0
    total_chunks = 0
    stats_response = []

    for s in source_stats:
        # SQL aggregates may come back NULL; coerce to 0 before summing.
        doc_count = s.get("document_count", 0) or 0
        chunk_count = s.get("chunk_count", 0) or 0
        total_docs += doc_count
        total_chunks += chunk_count

        last_indexed = s.get("last_indexed_at")

        stats_response.append(DSFASourceStatsResponse(
            source_id=str(s.get("source_id", "")),
            source_code=s.get("source_code", ""),
            name=s.get("name", ""),
            organization=s.get("organization"),
            license_code=s.get("license_code", ""),
            document_type=s.get("document_type"),
            document_count=doc_count,
            chunk_count=chunk_count,
            # presumably a datetime from the DB driver — serialized as ISO 8601
            last_indexed_at=last_indexed.isoformat() if last_indexed else None
        ))

    # Get Qdrant stats
    qdrant_stats = await qdrant.get_stats()

    return DSFACorpusStatsResponse(
        sources=stats_response,
        total_sources=len(source_stats),
        total_documents=total_docs,
        total_chunks=total_chunks,
        qdrant_collection=DSFA_COLLECTION,
        qdrant_points_count=qdrant_stats.get("points_count", 0),
        qdrant_status=qdrant_stats.get("status", "unknown")
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/licenses")
async def list_licenses():
    """List all supported licenses with their terms."""
    licenses = []
    for license_code, terms in LICENSE_REGISTRY.items():
        licenses.append(LicenseInfo(
            code=license_code,
            name=terms.get("name", license_code),
            url=terms.get("url"),
            attribution_required=terms.get("attribution_required", True),
            modification_allowed=terms.get("modification_allowed", True),
            commercial_use=terms.get("commercial_use", True)
        ))
    return licenses
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/init")
async def initialize_dsfa_corpus(
    store: DSFACorpusStore = Depends(get_store),
    qdrant: DSFAQdrantService = Depends(get_qdrant)
):
    """
    Initialize DSFA corpus.

    - Creates Qdrant collection
    - Registers all predefined sources

    Registration is best-effort: a failure for one source is logged and
    does not abort registration of the rest.
    """
    # Ensure Qdrant collection exists
    qdrant_ok = await qdrant.ensure_collection()

    # Register all sources
    registered = 0
    for source in DSFA_SOURCES:
        try:
            await store.register_source(source)
            registered += 1
        except Exception as e:
            # Use the module logger (was print) so failures land in the
            # structured service logs, consistent with the rest of this module.
            logger.error(f"Error registering source {source['source_code']}: {e}")

    return {
        "qdrant_collection_created": qdrant_ok,
        "sources_registered": registered,
        "total_sources": len(DSFA_SOURCES)
    }
|
||||||
@@ -20,6 +20,7 @@ This is the main entry point. All functionality is organized in modular packages
|
|||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
@@ -36,7 +37,19 @@ from admin_api import router as admin_router
|
|||||||
from zeugnis_api import router as zeugnis_router
|
from zeugnis_api import router as zeugnis_router
|
||||||
from training_api import router as training_router
|
from training_api import router as training_router
|
||||||
from mail.api import router as mail_router
|
from mail.api import router as mail_router
|
||||||
from trocr_api import router as trocr_router
|
try:
|
||||||
|
from trocr_api import router as trocr_router
|
||||||
|
except ImportError:
|
||||||
|
trocr_router = None
|
||||||
|
from vocab_worksheet_api import router as vocab_router, set_db_pool as set_vocab_db_pool, _init_vocab_table, _load_all_sessions, DATABASE_URL as VOCAB_DATABASE_URL
|
||||||
|
try:
|
||||||
|
from dsfa_rag_api import router as dsfa_rag_router, set_db_pool as set_dsfa_db_pool
|
||||||
|
from dsfa_corpus_ingestion import DSFAQdrantService, DATABASE_URL as DSFA_DATABASE_URL
|
||||||
|
except ImportError:
|
||||||
|
dsfa_rag_router = None
|
||||||
|
set_dsfa_db_pool = None
|
||||||
|
DSFAQdrantService = None
|
||||||
|
DSFA_DATABASE_URL = None
|
||||||
|
|
||||||
# BYOEH Qdrant initialization
|
# BYOEH Qdrant initialization
|
||||||
from qdrant_service import init_qdrant_collection
|
from qdrant_service import init_qdrant_collection
|
||||||
@@ -51,12 +64,42 @@ async def lifespan(app: FastAPI):
|
|||||||
"""Application lifespan manager for startup and shutdown events."""
|
"""Application lifespan manager for startup and shutdown events."""
|
||||||
print("Klausur-Service starting...")
|
print("Klausur-Service starting...")
|
||||||
|
|
||||||
|
# Initialize database pool for Vocab Sessions
|
||||||
|
vocab_db_pool = None
|
||||||
|
try:
|
||||||
|
vocab_db_pool = await asyncpg.create_pool(VOCAB_DATABASE_URL, min_size=2, max_size=5)
|
||||||
|
set_vocab_db_pool(vocab_db_pool)
|
||||||
|
await _init_vocab_table()
|
||||||
|
await _load_all_sessions()
|
||||||
|
print(f"Vocab sessions database initialized")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Vocab sessions database initialization failed: {e}")
|
||||||
|
|
||||||
|
# Initialize database pool for DSFA RAG
|
||||||
|
dsfa_db_pool = None
|
||||||
|
if DSFA_DATABASE_URL and set_dsfa_db_pool:
|
||||||
|
try:
|
||||||
|
dsfa_db_pool = await asyncpg.create_pool(DSFA_DATABASE_URL, min_size=2, max_size=10)
|
||||||
|
set_dsfa_db_pool(dsfa_db_pool)
|
||||||
|
print(f"DSFA database pool initialized: {DSFA_DATABASE_URL}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: DSFA database pool initialization failed: {e}")
|
||||||
|
|
||||||
# Initialize Qdrant collection for BYOEH
|
# Initialize Qdrant collection for BYOEH
|
||||||
try:
|
try:
|
||||||
await init_qdrant_collection()
|
await init_qdrant_collection()
|
||||||
print("Qdrant BYOEH collection initialized")
|
print("Qdrant BYOEH collection initialized")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Qdrant initialization failed: {e}")
|
print(f"Warning: Qdrant BYOEH initialization failed: {e}")
|
||||||
|
|
||||||
|
# Initialize Qdrant collection for DSFA RAG
|
||||||
|
if DSFAQdrantService:
|
||||||
|
try:
|
||||||
|
dsfa_qdrant = DSFAQdrantService()
|
||||||
|
await dsfa_qdrant.ensure_collection()
|
||||||
|
print("Qdrant DSFA corpus collection initialized")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Qdrant DSFA initialization failed: {e}")
|
||||||
|
|
||||||
# Ensure EH upload directory exists
|
# Ensure EH upload directory exists
|
||||||
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
|
os.makedirs(EH_UPLOAD_DIR, exist_ok=True)
|
||||||
@@ -65,6 +108,16 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
print("Klausur-Service shutting down...")
|
print("Klausur-Service shutting down...")
|
||||||
|
|
||||||
|
# Close Vocab sessions database pool
|
||||||
|
if vocab_db_pool:
|
||||||
|
await vocab_db_pool.close()
|
||||||
|
print("Vocab sessions database pool closed")
|
||||||
|
|
||||||
|
# Close DSFA database pool
|
||||||
|
if dsfa_db_pool:
|
||||||
|
await dsfa_db_pool.close()
|
||||||
|
print("DSFA database pool closed")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Klausur-Service",
|
title="Klausur-Service",
|
||||||
@@ -94,7 +147,11 @@ app.include_router(admin_router) # NiBiS Ingestion
|
|||||||
app.include_router(zeugnis_router) # Zeugnis Rights-Aware Crawler
|
app.include_router(zeugnis_router) # Zeugnis Rights-Aware Crawler
|
||||||
app.include_router(training_router) # Training Management
|
app.include_router(training_router) # Training Management
|
||||||
app.include_router(mail_router) # Unified Inbox Mail
|
app.include_router(mail_router) # Unified Inbox Mail
|
||||||
app.include_router(trocr_router) # TrOCR Handwriting OCR
|
if trocr_router:
|
||||||
|
app.include_router(trocr_router) # TrOCR Handwriting OCR
|
||||||
|
app.include_router(vocab_router) # Vocabulary Worksheet Generator
|
||||||
|
if dsfa_rag_router:
|
||||||
|
app.include_router(dsfa_rag_router) # DSFA RAG Corpus Search
|
||||||
|
|
||||||
|
|
||||||
# =============================================
|
# =============================================
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ python-dotenv>=1.0.0
|
|||||||
qdrant-client>=1.7.0
|
qdrant-client>=1.7.0
|
||||||
cryptography>=41.0.0
|
cryptography>=41.0.0
|
||||||
PyPDF2>=3.0.0
|
PyPDF2>=3.0.0
|
||||||
|
PyMuPDF>=1.24.0
|
||||||
|
|
||||||
# PyTorch CPU-only (smaller, no CUDA needed for Docker on Mac)
|
# PyTorch CPU-only (smaller, no CUDA needed for Docker on Mac)
|
||||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
@@ -23,6 +24,10 @@ minio>=7.2.0
|
|||||||
# OpenCV for handwriting detection (headless = no GUI, smaller for CI)
|
# OpenCV for handwriting detection (headless = no GUI, smaller for CI)
|
||||||
opencv-python-headless>=4.8.0
|
opencv-python-headless>=4.8.0
|
||||||
|
|
||||||
|
# Tesseract OCR Python binding (requires system tesseract-ocr package)
|
||||||
|
pytesseract>=0.3.10
|
||||||
|
Pillow>=10.0.0
|
||||||
|
|
||||||
# PostgreSQL (for metrics storage)
|
# PostgreSQL (for metrics storage)
|
||||||
psycopg2-binary>=2.9.0
|
psycopg2-binary>=2.9.0
|
||||||
asyncpg>=0.29.0
|
asyncpg>=0.29.0
|
||||||
|
|||||||
509
klausur-service/backend/services/grid_detection_service.py
Normal file
509
klausur-service/backend/services/grid_detection_service.py
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
"""
|
||||||
|
Grid Detection Service v4
|
||||||
|
|
||||||
|
Detects table/grid structure from OCR bounding-box data.
|
||||||
|
Converts pixel coordinates to percentage and mm coordinates (A4 format).
|
||||||
|
Supports deskew correction, column detection, and multi-line cell handling.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Physical page size assumed for every mm conversion in this module
# (portrait A4). Percent coordinates (0-100) are projected onto this.
A4_WIDTH_MM = 210.0
A4_HEIGHT_MM = 297.0

# Gap kept between adjacent columns, expressed once in millimetres and
# once as a percentage of the A4 page width.
COLUMN_MARGIN_MM = 1.0
COLUMN_MARGIN_PCT = (COLUMN_MARGIN_MM / A4_WIDTH_MM) * 100
|
||||||
|
|
||||||
|
|
||||||
|
class CellStatus(str, Enum):
    """Recognition state of a single grid cell.

    Inherits from ``str`` so values compare equal to their plain string
    form and serialize directly into JSON.
    """
    EMPTY = "empty"              # no OCR region fell into the cell
    RECOGNIZED = "recognized"    # text found with acceptable confidence
    PROBLEMATIC = "problematic"  # text found but low confidence
    MANUAL = "manual"            # reserved for user-entered corrections
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnType(str, Enum):
    """Semantic role of a worksheet column (str-mixin for JSON friendliness)."""
    ENGLISH = "english"   # leftmost column: English term
    GERMAN = "german"     # middle column: German translation
    EXAMPLE = "example"   # rightmost column: example sentence
    UNKNOWN = "unknown"   # any additional / unclassified column
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OCRRegion:
    """A word/phrase detected by OCR with bounding box coordinates in percentage (0-100).

    ``x``/``y`` are the top-left corner; all four geometry fields are
    percentages of the page, so the region is resolution independent.
    """
    text: str
    confidence: float
    x: float       # left edge, % of page width
    y: float       # top edge, % of page height
    width: float   # % of page width
    height: float  # % of page height

    # --- derived geometry in percent space ---

    @property
    def center_x(self) -> float:
        """Horizontal midpoint (%)."""
        return self.x + self.width / 2

    @property
    def center_y(self) -> float:
        """Vertical midpoint (%)."""
        return self.y + self.height / 2

    @property
    def right(self) -> float:
        """Right edge (%)."""
        return self.x + self.width

    @property
    def bottom(self) -> float:
        """Bottom edge (%)."""
        return self.y + self.height

    # --- millimetre projections onto an A4 page ---

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # Two decimals here (heights are small relative to page height).
        return round(self.height / 100 * A4_HEIGHT_MM, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GridCell:
    """One cell of the detected worksheet grid (coordinates in % of page, 0-100)."""
    row: int                 # physical row index in the detected grid
    col: int                 # physical column index
    x: float
    y: float
    width: float
    height: float
    text: str = ""
    confidence: float = 0.0
    status: CellStatus = CellStatus.EMPTY
    column_type: ColumnType = ColumnType.UNKNOWN
    logical_row: int = 0     # row index after multi-line cells are merged
    logical_col: int = 0
    is_continuation: bool = False  # True if this cell continues the row above

    # Millimetre projections assume an A4 page.

    @property
    def x_mm(self) -> float:
        return round(self.x / 100 * A4_WIDTH_MM, 1)

    @property
    def y_mm(self) -> float:
        return round(self.y / 100 * A4_HEIGHT_MM, 1)

    @property
    def width_mm(self) -> float:
        return round(self.width / 100 * A4_WIDTH_MM, 1)

    @property
    def height_mm(self) -> float:
        # NOTE: two decimals here, unlike the other mm values.
        return round(self.height / 100 * A4_HEIGHT_MM, 2)

    def to_dict(self) -> dict:
        """Serialize the cell for JSON transport (stable key order)."""
        out = {"row": self.row, "col": self.col}
        out["x"] = round(self.x, 2)
        out["y"] = round(self.y, 2)
        out["width"] = round(self.width, 2)
        out["height"] = round(self.height, 2)
        out["x_mm"] = self.x_mm
        out["y_mm"] = self.y_mm
        out["width_mm"] = self.width_mm
        out["height_mm"] = self.height_mm
        out["text"] = self.text
        out["confidence"] = self.confidence
        out["status"] = self.status.value
        out["column_type"] = self.column_type.value
        out["logical_row"] = self.logical_row
        out["logical_col"] = self.logical_col
        out["is_continuation"] = self.is_continuation
        return out
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GridResult:
    """Outcome of grid detection: dimensions, cells and bookkeeping stats."""
    rows: int = 0
    columns: int = 0
    cells: List[List[GridCell]] = field(default_factory=list)
    column_types: List[str] = field(default_factory=list)
    column_boundaries: List[float] = field(default_factory=list)  # % X positions
    row_boundaries: List[float] = field(default_factory=list)     # % Y positions
    deskew_angle: float = 0.0  # applied correction, degrees
    stats: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize for JSON transport; page dimensions are fixed to A4."""
        return {
            "rows": self.rows,
            "columns": self.columns,
            "cells": [[cell.to_dict() for cell in row] for row in self.cells],
            "column_types": self.column_types,
            "column_boundaries": [round(b, 2) for b in self.column_boundaries],
            "row_boundaries": [round(b, 2) for b in self.row_boundaries],
            "deskew_angle": round(self.deskew_angle, 2),
            "stats": self.stats,
            "page_dimensions": {
                "width_mm": A4_WIDTH_MM,
                "height_mm": A4_HEIGHT_MM,
                "format": "A4",
            },
        }
|
||||||
|
|
||||||
|
|
||||||
|
class GridDetectionService:
    """Detect grid/table structure from OCR bounding-box regions.

    Pipeline (see :meth:`detect_grid`): deskew correction -> row grouping
    by Y position -> gap-based column boundary detection -> cell assembly
    with per-cell status and coverage statistics.  All coordinates are
    percentages of the page (0-100); mm values derive from the module's
    A4 constants.
    """

    def __init__(self, y_tolerance_pct: float = 1.5, padding_pct: float = 0.3,
                 column_margin_mm: float = COLUMN_MARGIN_MM):
        # Max vertical distance (% of page height) for two regions to be
        # treated as belonging to the same row.
        self.y_tolerance_pct = y_tolerance_pct
        # Padding (%) added around detected row/cell boundaries.
        self.padding_pct = padding_pct
        # Column margin in mm; stored but not used by the methods below.
        self.column_margin_mm = column_margin_mm

    def calculate_deskew_angle(self, regions: List[OCRRegion]) -> float:
        """Calculate page skew angle from OCR region positions.

        Uses left-edge alignment of regions to detect consistent tilt.
        Returns angle in degrees, clamped to ±5°.
        """
        # Need at least 3 regions for a meaningful regression.
        if len(regions) < 3:
            return 0.0

        # Group by similar X position (same column)
        sorted_by_x = sorted(regions, key=lambda r: r.x)

        # Find regions that are vertically aligned (similar X)
        x_tolerance = 3.0  # percent
        aligned_groups: List[List[OCRRegion]] = []
        current_group = [sorted_by_x[0]]

        for r in sorted_by_x[1:]:
            # Compare against the group's first member, not a running mean.
            if abs(r.x - current_group[0].x) <= x_tolerance:
                current_group.append(r)
            else:
                # Only groups with >= 3 members are usable for regression.
                if len(current_group) >= 3:
                    aligned_groups.append(current_group)
                current_group = [r]

        if len(current_group) >= 3:
            aligned_groups.append(current_group)

        if not aligned_groups:
            return 0.0

        # Use the largest aligned group to calculate skew
        best_group = max(aligned_groups, key=len)
        best_group.sort(key=lambda r: r.y)

        # Linear regression: X as function of Y
        n = len(best_group)
        sum_y = sum(r.y for r in best_group)
        sum_x = sum(r.x for r in best_group)
        sum_xy = sum(r.x * r.y for r in best_group)
        sum_y2 = sum(r.y ** 2 for r in best_group)

        denom = n * sum_y2 - sum_y ** 2
        if denom == 0:
            # All Ys identical -> vertical regression undefined.
            return 0.0

        slope = (n * sum_xy - sum_y * sum_x) / denom

        # Convert slope to angle (slope is dx/dy in percent space)
        # Adjust for aspect ratio: A4 is 210/297 ≈ 0.707
        aspect = A4_WIDTH_MM / A4_HEIGHT_MM
        angle_rad = math.atan(slope * aspect)
        angle_deg = math.degrees(angle_rad)

        # Clamp to ±5°
        return max(-5.0, min(5.0, round(angle_deg, 2)))

    def apply_deskew_to_regions(self, regions: List[OCRRegion], angle: float) -> List[OCRRegion]:
        """Apply deskew correction to region coordinates.

        Rotates all coordinates around the page center by -angle.
        Width/height of each region are preserved; only the position moves.
        """
        # Negligible skew: return the input list unchanged (same objects).
        if abs(angle) < 0.01:
            return regions

        angle_rad = math.radians(-angle)
        cos_a = math.cos(angle_rad)
        sin_a = math.sin(angle_rad)

        # Page center
        cx, cy = 50.0, 50.0

        result = []
        for r in regions:
            # Rotate center of region around page center
            rx = r.center_x - cx
            ry = r.center_y - cy
            new_cx = rx * cos_a - ry * sin_a + cx
            new_cy = rx * sin_a + ry * cos_a + cy
            # Convert the rotated center back to a top-left corner.
            new_x = new_cx - r.width / 2
            new_y = new_cy - r.height / 2

            result.append(OCRRegion(
                text=r.text,
                confidence=r.confidence,
                x=round(new_x, 2),
                y=round(new_y, 2),
                width=r.width,
                height=r.height,
            ))

        return result

    def _group_regions_into_rows(self, regions: List[OCRRegion]) -> List[List[OCRRegion]]:
        """Group regions by Y position into rows.

        Regions whose vertical centers lie within ``y_tolerance_pct`` of
        the row's first member are merged; each row is returned sorted
        left-to-right.
        """
        if not regions:
            return []

        sorted_regions = sorted(regions, key=lambda r: r.y)
        rows: List[List[OCRRegion]] = []
        current_row = [sorted_regions[0]]
        # Anchor: center Y of the region that opened the current row.
        current_y = sorted_regions[0].center_y

        for r in sorted_regions[1:]:
            if abs(r.center_y - current_y) <= self.y_tolerance_pct:
                current_row.append(r)
            else:
                current_row.sort(key=lambda r: r.x)
                rows.append(current_row)
                current_row = [r]
                current_y = r.center_y

        if current_row:
            current_row.sort(key=lambda r: r.x)
            rows.append(current_row)

        return rows

    def _detect_column_boundaries(self, rows: List[List[OCRRegion]]) -> List[float]:
        """Detect column boundaries from row data.

        Clusters all region start-X positions by gaps > 5% of the page
        width; returns the left edge of each cluster (minus padding) plus
        a final 100.0 right boundary.  N columns -> N+1 boundaries.
        """
        if not rows:
            return []

        # Collect all X starting positions
        all_x = []
        for row in rows:
            for r in row:
                all_x.append(r.x)

        if not all_x:
            return []

        all_x.sort()

        # Gap-based clustering
        min_gap = 5.0  # percent
        clusters: List[List[float]] = []
        current = [all_x[0]]

        for x in all_x[1:]:
            if x - current[-1] > min_gap:
                clusters.append(current)
                current = [x]
            else:
                current.append(x)

        if current:
            clusters.append(current)

        # Column boundaries: start of each cluster
        boundaries = [min(c) - self.padding_pct for c in clusters]
        # Add right boundary
        boundaries.append(100.0)

        return boundaries

    def _assign_column_types(self, boundaries: List[float]) -> List[str]:
        """Assign column types based on position.

        Fixed left-to-right convention: english, german, example; any
        further columns become "unknown".  Returns ``.value`` strings.
        """
        num_cols = max(0, len(boundaries) - 1)
        type_map = [ColumnType.ENGLISH, ColumnType.GERMAN, ColumnType.EXAMPLE]
        result = []
        for i in range(num_cols):
            if i < len(type_map):
                result.append(type_map[i].value)
            else:
                result.append(ColumnType.UNKNOWN.value)
        return result

    def detect_grid(self, regions: List[OCRRegion]) -> GridResult:
        """Detect grid structure from OCR regions.

        Args:
            regions: List of OCR regions with percentage-based coordinates.

        Returns:
            GridResult with detected rows, columns, and cells.
        """
        if not regions:
            # Empty input: empty grid with zeroed stats.
            return GridResult(stats={"recognized": 0, "problematic": 0, "empty": 0, "manual": 0, "total": 0, "coverage": 0.0})

        # Step 1: Calculate and apply deskew
        deskew_angle = self.calculate_deskew_angle(regions)
        corrected_regions = self.apply_deskew_to_regions(regions, deskew_angle)

        # Step 2: Group into rows
        rows = self._group_regions_into_rows(corrected_regions)

        # Step 3: Detect column boundaries
        col_boundaries = self._detect_column_boundaries(rows)
        column_types = self._assign_column_types(col_boundaries)
        num_cols = max(1, len(col_boundaries) - 1)

        # Step 4: Build cell grid
        num_rows = len(rows)
        row_boundaries = []
        cells = []

        recognized = 0
        problematic = 0
        empty = 0

        for row_idx, row_regions in enumerate(rows):
            # Row Y boundary
            if row_regions:
                row_y = min(r.y for r in row_regions) - self.padding_pct
                row_bottom = max(r.bottom for r in row_regions) + self.padding_pct
            else:
                # Fallback: evenly divide the page (rows from grouping are
                # never empty, so this branch is defensive only).
                row_y = row_idx / num_rows * 100
                row_bottom = (row_idx + 1) / num_rows * 100

            row_boundaries.append(row_y)
            row_height = row_bottom - row_y

            row_cells = []
            for col_idx in range(num_cols):
                col_x = col_boundaries[col_idx]
                col_right = col_boundaries[col_idx + 1] if col_idx + 1 < len(col_boundaries) else 100.0
                col_width = col_right - col_x

                # Find regions in this cell
                cell_regions = []
                for r in row_regions:
                    # A region belongs to the column containing its center X.
                    r_center = r.center_x
                    if col_x <= r_center < col_right:
                        cell_regions.append(r)

                if cell_regions:
                    text = " ".join(r.text for r in cell_regions)
                    avg_conf = sum(r.confidence for r in cell_regions) / len(cell_regions)
                    # Confidence threshold 0.5 separates recognized/problematic.
                    status = CellStatus.RECOGNIZED if avg_conf >= 0.5 else CellStatus.PROBLEMATIC
                    # Use actual bounding box from regions
                    actual_x = min(r.x for r in cell_regions)
                    actual_y = min(r.y for r in cell_regions)
                    actual_right = max(r.right for r in cell_regions)
                    actual_bottom = max(r.bottom for r in cell_regions)

                    cell = GridCell(
                        row=row_idx,
                        col=col_idx,
                        x=actual_x,
                        y=actual_y,
                        width=actual_right - actual_x,
                        height=actual_bottom - actual_y,
                        text=text,
                        confidence=round(avg_conf, 3),
                        status=status,
                        column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                        logical_row=row_idx,
                        logical_col=col_idx,
                    )

                    if status == CellStatus.RECOGNIZED:
                        recognized += 1
                    else:
                        problematic += 1
                else:
                    # No content: cell spans the full column/row slot.
                    cell = GridCell(
                        row=row_idx,
                        col=col_idx,
                        x=col_x,
                        y=row_y,
                        width=col_width,
                        height=row_height,
                        status=CellStatus.EMPTY,
                        column_type=ColumnType(column_types[col_idx]) if col_idx < len(column_types) else ColumnType.UNKNOWN,
                        logical_row=row_idx,
                        logical_col=col_idx,
                    )
                    empty += 1

                row_cells.append(cell)
            cells.append(row_cells)

        # Add final row boundary
        if rows and rows[-1]:
            row_boundaries.append(max(r.bottom for r in rows[-1]) + self.padding_pct)
        else:
            row_boundaries.append(100.0)

        total = num_rows * num_cols
        # Coverage = fraction of cells that contain any text at all.
        coverage = (recognized + problematic) / max(total, 1)

        return GridResult(
            rows=num_rows,
            columns=num_cols,
            cells=cells,
            column_types=column_types,
            column_boundaries=col_boundaries,
            row_boundaries=row_boundaries,
            deskew_angle=deskew_angle,
            stats={
                "recognized": recognized,
                "problematic": problematic,
                "empty": empty,
                "manual": 0,
                "total": total,
                "coverage": round(coverage, 3),
            },
        )

    def convert_tesseract_regions(self, tess_words: List[dict],
                                  image_width: int, image_height: int) -> List[OCRRegion]:
        """Convert Tesseract word data (pixels) to OCRRegions (percentages).

        Args:
            tess_words: Word list from tesseract_vocab_extractor.extract_bounding_boxes.
            image_width: Image width in pixels.
            image_height: Image height in pixels.

        Returns:
            List of OCRRegion with percentage-based coordinates.
        """
        if not tess_words or image_width == 0 or image_height == 0:
            return []

        regions = []
        for w in tess_words:
            regions.append(OCRRegion(
                text=w["text"],
                # Tesseract conf is 0-100; default 50 (=0.5) when missing.
                confidence=w.get("conf", 50) / 100.0,
                x=w["left"] / image_width * 100,
                y=w["top"] / image_height * 100,
                width=w["width"] / image_width * 100,
                height=w["height"] / image_height * 100,
            ))

        return regions
|
||||||
346
klausur-service/backend/tesseract_vocab_extractor.py
Normal file
346
klausur-service/backend/tesseract_vocab_extractor.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
"""
|
||||||
|
Tesseract-based OCR extraction with word-level bounding boxes.
|
||||||
|
|
||||||
|
Uses Tesseract for spatial information (WHERE text is) while
|
||||||
|
the Vision LLM handles semantic understanding (WHAT the text means).
|
||||||
|
|
||||||
|
Tesseract runs natively on ARM64 via Debian's apt package.
|
||||||
|
|
||||||
|
Lizenz: Apache 2.0 (kommerziell nutzbar)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Optional dependency guard: the service must start even when the OCR
# stack (pytesseract + Pillow) is not installed.
try:
    import pytesseract
    from PIL import Image
    TESSERACT_AVAILABLE = True
except ImportError:
    TESSERACT_AVAILABLE = False
    logger.warning("pytesseract or Pillow not installed - Tesseract OCR unavailable")


async def extract_bounding_boxes(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Run Tesseract OCR and return word-level bounding boxes.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string (e.g. "eng+deu").

    Returns:
        Dict with 'words' list and 'image_width'/'image_height'.
        When Tesseract is unavailable, returns empty words plus an
        'error' key instead of raising.
    """
    if not TESSERACT_AVAILABLE:
        return {"words": [], "image_width": 0, "image_height": 0, "error": "Tesseract not available"}

    image = Image.open(io.BytesIO(image_bytes))
    data = pytesseract.image_to_data(image, lang=lang, output_type=pytesseract.Output.DICT)

    words = []
    for i in range(len(data['text'])):
        text = data['text'][i].strip()
        # pytesseract reports 'conf' as int, float, or numeric string
        # depending on version; int(float(...)) accepts all of them
        # (plain int('96.0') would raise ValueError).
        conf = int(float(data['conf'][i]))
        # Drop empty tokens and very-low-confidence noise (< 20).
        if not text or conf < 20:
            continue
        words.append({
            "text": text,
            "left": data['left'][i],
            "top": data['top'][i],
            "width": data['width'][i],
            "height": data['height'][i],
            "conf": conf,
            # Tesseract layout hierarchy, useful for later grouping.
            "block_num": data['block_num'][i],
            "par_num": data['par_num'][i],
            "line_num": data['line_num'][i],
            "word_num": data['word_num'][i],
        })

    return {
        "words": words,
        "image_width": image.width,
        "image_height": image.height,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def group_words_into_lines(words: List[dict], y_tolerance_px: int = 15) -> List[List[dict]]:
    """Cluster OCR words into text lines by vertical position.

    A word joins the current line when its 'top' lies within
    ``y_tolerance_px`` of the word that opened the line; otherwise it
    starts a new line.  Each returned line is sorted left-to-right.

    Args:
        words: List of word dicts from extract_bounding_boxes.
        y_tolerance_px: Max pixel distance to consider words on the same line.

    Returns:
        List of lines, each line is a list of words sorted by X position.
    """
    if not words:
        return []

    ordered = sorted(words, key=lambda w: (w['top'], w['left']))

    grouped: List[List[dict]] = []
    line: List[dict] = []
    anchor_y = 0

    for w in ordered:
        if line and abs(w['top'] - anchor_y) <= y_tolerance_px:
            line.append(w)
        else:
            if line:
                grouped.append(sorted(line, key=lambda t: t['left']))
            line = [w]
            anchor_y = w['top']

    grouped.append(sorted(line, key=lambda t: t['left']))
    return grouped
|
||||||
|
|
||||||
|
|
||||||
|
def detect_columns(lines: List[List[dict]], image_width: int) -> Dict[str, Any]:
    """Detect column boundaries from word positions.

    Typical vocab table: Left=English, Middle=German, Right=Example sentences.
    Column starts are found by clustering word left-edges; a horizontal
    gap wider than 8% of the page width separates two columns.

    Returns:
        Dict with column boundaries and type assignments.
    """
    if not lines or image_width == 0:
        return {"columns": [], "column_types": []}

    xs = sorted(word['left'] for line in lines for word in line)
    if not xs:
        return {"columns": [], "column_types": []}

    # Gap threshold: 8% of page width marks a column break.
    gap_threshold = image_width * 0.08
    clusters: List[list] = [[xs[0]]]
    for x in xs[1:]:
        if x - clusters[-1][-1] > gap_threshold:
            clusters.append([x])
        else:
            clusters[-1].append(x)

    # One column per cluster; x_start is the cluster's leftmost edge.
    columns = [
        {
            "x_start": min(cluster),
            "x_start_pct": min(cluster) / image_width * 100,
            "word_count": len(cluster),
        }
        for cluster in clusters
    ]

    # Fixed role order left->right: English, German, Example.
    roles = ["english", "german", "example"]
    column_types = [
        roles[i] if i < len(roles) else "unknown"
        for i in range(len(columns))
    ]

    return {
        "columns": columns,
        "column_types": column_types,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def words_to_vocab_entries(lines: List[List[dict]], columns: List[dict],
                           column_types: List[str], image_width: int,
                           image_height: int) -> List[dict]:
    """Convert grouped words into vocabulary entries using column positions.

    Each OCR line becomes one entry; every word is routed to the column
    span containing its horizontal center, and a per-column bounding box
    (in % of the page) is accumulated alongside the text.

    Args:
        lines: Grouped word lines from group_words_into_lines.
        columns: Column boundaries from detect_columns.
        column_types: Column type assignments.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        List of vocabulary entry dicts with english/german/example fields.
    """
    if not columns or not lines:
        return []

    # Precompute (left, right, role) spans; the last column runs to the
    # right image edge.
    spans = []
    for idx, col in enumerate(columns):
        left = col['x_start']
        right = columns[idx + 1]['x_start'] if idx + 1 < len(columns) else image_width
        role = column_types[idx] if idx < len(column_types) else "unknown"
        spans.append((left, right, role))

    entries = []
    for line in lines:
        texts: Dict[str, List[str]] = {"english": [], "german": [], "example": []}
        boxes: Dict[str, Optional[dict]] = {}

        for word in line:
            cx = word['left'] + word['width'] / 2
            role = "unknown"
            for span_left, span_right, span_role in spans:
                if span_left <= cx < span_right:
                    role = span_role
                    break

            if role not in texts:
                # "unknown" columns are ignored for vocab entries.
                continue

            texts[role].append(word['text'])
            box = boxes.get(role)
            if box is None:
                boxes[role] = {
                    "left": word['left'],
                    "top": word['top'],
                    "right": word['left'] + word['width'],
                    "bottom": word['top'] + word['height'],
                }
            else:
                box['left'] = min(box['left'], word['left'])
                box['top'] = min(box['top'], word['top'])
                box['right'] = max(box['right'], word['left'] + word['width'])
                box['bottom'] = max(box['bottom'], word['top'] + word['height'])

        entry = {"english": "", "german": "", "example": ""}
        for role in ("english", "german", "example"):
            if texts[role]:
                entry[role] = " ".join(texts[role])
                box = boxes.get(role)
                if box:
                    entry[f"{role}_bbox"] = {
                        "x_pct": box['left'] / image_width * 100,
                        "y_pct": box['top'] / image_height * 100,
                        "w_pct": (box['right'] - box['left']) / image_width * 100,
                        "h_pct": (box['bottom'] - box['top']) / image_height * 100,
                    }

        # Only keep lines that produced at least one vocab column.
        if entry["english"] or entry["german"]:
            entries.append(entry)

    return entries
|
||||||
|
|
||||||
|
|
||||||
|
def match_positions_to_vocab(tess_words: List[dict], llm_vocab: List[dict],
                             image_w: int, image_h: int,
                             threshold: float = 0.6) -> List[dict]:
    """Match Tesseract bounding boxes to LLM vocabulary entries.

    For each LLM vocab entry, find the best-matching Tesseract word by
    fuzzy string similarity and attach its bounding box as percentage
    coordinates relative to the image size.

    Args:
        tess_words: Word list from Tesseract with pixel coordinates
            (dicts with 'text', 'left', 'top', 'width', 'height').
        llm_vocab: Vocabulary list from Vision LLM (mutated in place).
        image_w: Image width in pixels.
        image_h: Image height in pixels.
        threshold: Minimum similarity ratio for a match.

    Returns:
        llm_vocab list with bbox_x_pct/bbox_y_pct/bbox_w_pct/bbox_h_pct added.
    """
    # Guard against empty inputs and division by zero on degenerate images.
    if not tess_words or not llm_vocab or image_w == 0 or image_h == 0:
        return llm_vocab

    for entry in llm_vocab:
        english = entry.get("english", "").lower().strip()
        german = entry.get("german", "").lower().strip()

        if not english and not german:
            continue

        # Try to match the English word first, then the German one.
        for field in ["english", "german"]:
            search_text = entry.get(field, "").lower().strip()
            if not search_text:
                continue

            best_word = None
            best_ratio = 0.0

            for word in tess_words:
                ratio = SequenceMatcher(None, search_text, word['text'].lower()).ratio()
                if ratio > best_ratio:
                    best_ratio = ratio
                    best_word = word

            if best_word and best_ratio >= threshold:
                # FIX: the original used f-strings with no placeholders
                # (f"bbox_x_pct") as keys; plain string literals are
                # equivalent and not misleading.
                entry["bbox_x_pct"] = best_word['left'] / image_w * 100
                entry["bbox_y_pct"] = best_word['top'] / image_h * 100
                entry["bbox_w_pct"] = best_word['width'] / image_w * 100
                entry["bbox_h_pct"] = best_word['height'] / image_h * 100
                entry["bbox_match_field"] = field
                entry["bbox_match_ratio"] = round(best_ratio, 3)
                break  # Found a match, no need to try the other field
    
    return llm_vocab
|
||||||
|
|
||||||
|
|
||||||
|
async def run_tesseract_pipeline(image_bytes: bytes, lang: str = "eng+deu") -> dict:
    """Run the full Tesseract pipeline on a single image.

    Extracts word bounding boxes, groups them into text lines, detects the
    table columns, and assembles vocabulary entries from the column layout.

    Args:
        image_bytes: PNG/JPEG image as bytes.
        lang: Tesseract language string.

    Returns:
        Dict with 'vocabulary', 'words', 'lines_count', 'columns',
        'column_types', 'image_width', 'image_height', 'word_count' —
        or the error dict produced by the extraction step.
    """
    bbox_data = await extract_bounding_boxes(image_bytes, lang=lang)
    if bbox_data.get("error"):
        # Propagate extraction failures to the caller unchanged.
        return bbox_data

    words = bbox_data["words"]
    width = bbox_data["image_width"]
    height = bbox_data["image_height"]

    lines = group_words_into_lines(words)
    col_info = detect_columns(lines, width)
    vocabulary = words_to_vocab_entries(
        lines,
        col_info["columns"],
        col_info["column_types"],
        width,
        height,
    )

    return {
        "vocabulary": vocabulary,
        "words": words,
        "lines_count": len(lines),
        "columns": col_info["columns"],
        "column_types": col_info["column_types"],
        "image_width": width,
        "image_height": height,
        "word_count": len(words),
    }
|
||||||
261
klausur-service/backend/trocr_api.py
Normal file
261
klausur-service/backend/trocr_api.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
"""
|
||||||
|
TrOCR API - REST endpoints for TrOCR handwriting OCR.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- /ocr/trocr - Single image OCR
|
||||||
|
- /ocr/trocr/batch - Batch image processing
|
||||||
|
- /ocr/trocr/status - Model status
|
||||||
|
- /ocr/trocr/cache - Cache statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from services.trocr_service import (
|
||||||
|
run_trocr_ocr_enhanced,
|
||||||
|
run_trocr_batch,
|
||||||
|
run_trocr_batch_stream,
|
||||||
|
get_model_status,
|
||||||
|
get_cache_stats,
|
||||||
|
preload_trocr_model,
|
||||||
|
OCRResult,
|
||||||
|
BatchOCRResult
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MODELS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class TrOCRResponse(BaseModel):
    """Response model for single image OCR."""
    # Field descriptions double as the OpenAPI schema documentation.
    text: str = Field(..., description="Extracted text")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
    model: str = Field(..., description="Model used for OCR")
    has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
    from_cache: bool = Field(False, description="Whether result was from cache")
    image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
    word_count: int = Field(0, description="Number of words detected")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchOCRResponse(BaseModel):
    """Response model for batch OCR."""
    # One TrOCRResponse per submitted image, in submission order.
    results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
    total_time_ms: int = Field(..., ge=0, description="Total processing time")
    processed_count: int = Field(..., ge=0, description="Number of images processed")
    cached_count: int = Field(0, description="Number of results from cache")
    error_count: int = Field(0, description="Number of errors")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelStatusResponse(BaseModel):
    """Response model for model status."""
    # Fields after `status` are None until the model has been loaded.
    status: str = Field(..., description="Model status: available, not_installed")
    is_loaded: bool = Field(..., description="Whether model is loaded in memory")
    model_name: Optional[str] = Field(None, description="Name of loaded model")
    device: Optional[str] = Field(None, description="Device model is running on")
    loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
|
||||||
|
|
||||||
|
|
||||||
|
class CacheStatsResponse(BaseModel):
    """Response model for cache statistics."""
    size: int = Field(..., ge=0, description="Current cache size")
    max_size: int = Field(..., ge=0, description="Maximum cache size")
    ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# ENDPOINTS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/status", response_model=ModelStatusResponse)
|
||||||
|
async def get_trocr_status():
|
||||||
|
"""
|
||||||
|
Get TrOCR model status.
|
||||||
|
|
||||||
|
Returns information about whether the model is loaded and available.
|
||||||
|
"""
|
||||||
|
return get_model_status()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/cache", response_model=CacheStatsResponse)
|
||||||
|
async def get_trocr_cache_stats():
|
||||||
|
"""
|
||||||
|
Get TrOCR cache statistics.
|
||||||
|
|
||||||
|
Returns information about the OCR result cache.
|
||||||
|
"""
|
||||||
|
return get_cache_stats()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/preload")
|
||||||
|
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
|
||||||
|
"""
|
||||||
|
Preload TrOCR model into memory.
|
||||||
|
|
||||||
|
This speeds up the first OCR request by loading the model ahead of time.
|
||||||
|
"""
|
||||||
|
success = preload_trocr_model(handwritten=handwritten)
|
||||||
|
if success:
|
||||||
|
return {"status": "success", "message": "Model preloaded successfully"}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to preload model")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=TrOCRResponse)
|
||||||
|
async def run_trocr(
|
||||||
|
file: UploadFile = File(..., description="Image file to process"),
|
||||||
|
handwritten: bool = Query(True, description="Use handwritten model"),
|
||||||
|
split_lines: bool = Query(True, description="Split image into lines"),
|
||||||
|
use_cache: bool = Query(True, description="Use result caching")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run TrOCR on a single image.
|
||||||
|
|
||||||
|
Supports PNG, JPG, and other common image formats.
|
||||||
|
"""
|
||||||
|
# Validate file type
|
||||||
|
if not file.content_type or not file.content_type.startswith("image/"):
|
||||||
|
raise HTTPException(status_code=400, detail="File must be an image")
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_data = await file.read()
|
||||||
|
|
||||||
|
result = await run_trocr_ocr_enhanced(
|
||||||
|
image_data,
|
||||||
|
handwritten=handwritten,
|
||||||
|
split_lines=split_lines,
|
||||||
|
use_cache=use_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrOCRResponse(
|
||||||
|
text=result.text,
|
||||||
|
confidence=result.confidence,
|
||||||
|
processing_time_ms=result.processing_time_ms,
|
||||||
|
model=result.model,
|
||||||
|
has_lora_adapter=result.has_lora_adapter,
|
||||||
|
from_cache=result.from_cache,
|
||||||
|
image_hash=result.image_hash,
|
||||||
|
word_count=len(result.text.split()) if result.text else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TrOCR API error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch", response_model=BatchOCRResponse)
|
||||||
|
async def run_trocr_batch_endpoint(
|
||||||
|
files: List[UploadFile] = File(..., description="Image files to process"),
|
||||||
|
handwritten: bool = Query(True, description="Use handwritten model"),
|
||||||
|
split_lines: bool = Query(True, description="Split images into lines"),
|
||||||
|
use_cache: bool = Query(True, description="Use result caching")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run TrOCR on multiple images.
|
||||||
|
|
||||||
|
Processes images sequentially and returns all results.
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
|
if len(files) > 50:
|
||||||
|
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
|
||||||
|
|
||||||
|
try:
|
||||||
|
images = []
|
||||||
|
for file in files:
|
||||||
|
if not file.content_type or not file.content_type.startswith("image/"):
|
||||||
|
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
|
||||||
|
images.append(await file.read())
|
||||||
|
|
||||||
|
batch_result = await run_trocr_batch(
|
||||||
|
images,
|
||||||
|
handwritten=handwritten,
|
||||||
|
split_lines=split_lines,
|
||||||
|
use_cache=use_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
return BatchOCRResponse(
|
||||||
|
results=[
|
||||||
|
TrOCRResponse(
|
||||||
|
text=r.text,
|
||||||
|
confidence=r.confidence,
|
||||||
|
processing_time_ms=r.processing_time_ms,
|
||||||
|
model=r.model,
|
||||||
|
has_lora_adapter=r.has_lora_adapter,
|
||||||
|
from_cache=r.from_cache,
|
||||||
|
image_hash=r.image_hash,
|
||||||
|
word_count=len(r.text.split()) if r.text else 0
|
||||||
|
)
|
||||||
|
for r in batch_result.results
|
||||||
|
],
|
||||||
|
total_time_ms=batch_result.total_time_ms,
|
||||||
|
processed_count=batch_result.processed_count,
|
||||||
|
cached_count=batch_result.cached_count,
|
||||||
|
error_count=batch_result.error_count
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TrOCR batch API error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch/stream")
|
||||||
|
async def run_trocr_batch_stream_endpoint(
|
||||||
|
files: List[UploadFile] = File(..., description="Image files to process"),
|
||||||
|
handwritten: bool = Query(True, description="Use handwritten model"),
|
||||||
|
split_lines: bool = Query(True, description="Split images into lines"),
|
||||||
|
use_cache: bool = Query(True, description="Use result caching")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.
|
||||||
|
|
||||||
|
Returns a stream of progress events as images are processed.
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
|
if len(files) > 50:
|
||||||
|
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
|
||||||
|
|
||||||
|
try:
|
||||||
|
images = []
|
||||||
|
for file in files:
|
||||||
|
if not file.content_type or not file.content_type.startswith("image/"):
|
||||||
|
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
|
||||||
|
images.append(await file.read())
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
async for update in run_trocr_batch_stream(
|
||||||
|
images,
|
||||||
|
handwritten=handwritten,
|
||||||
|
split_lines=split_lines,
|
||||||
|
use_cache=use_cache
|
||||||
|
):
|
||||||
|
yield f"data: {json.dumps(update)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TrOCR stream API error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
428
klausur-service/backend/vocab_session_store.py
Normal file
428
klausur-service/backend/vocab_session_store.py
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
"""
|
||||||
|
Vocabulary Session Store - PostgreSQL persistence for vocab extraction sessions.
|
||||||
|
|
||||||
|
Replaces in-memory storage with database persistence.
|
||||||
|
See migrations/001_vocab_sessions.sql for schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
import json
import logging
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any

import asyncpg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Database configuration
|
||||||
|
DATABASE_URL = os.getenv(
|
||||||
|
"DATABASE_URL",
|
||||||
|
"postgresql://breakpilot:breakpilot@postgres:5432/breakpilot_db"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connection pool (initialized lazily)
|
||||||
|
_pool: Optional[asyncpg.Pool] = None
|
||||||
|
|
||||||
|
|
||||||
|
# Serializes first-time pool creation across concurrent coroutines.
_pool_lock = asyncio.Lock()


async def get_pool() -> asyncpg.Pool:
    """Return the shared asyncpg connection pool, creating it on first use.

    FIX: the original lazy init was unguarded, so several coroutines awaiting
    the first call concurrently could each create (and leak) a pool. The
    double-checked lock ensures exactly one pool is ever created.
    """
    global _pool
    if _pool is None:
        async with _pool_lock:
            if _pool is None:  # re-check: another coroutine may have won the race
                _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    return _pool
|
||||||
|
|
||||||
|
|
||||||
|
async def init_vocab_tables():
    """Create the vocab tables at startup if they are missing.

    Checks information_schema for ``vocab_sessions`` and, when absent,
    executes the bundled SQL migration file.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        tables_exist = await conn.fetchval("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'vocab_sessions'
            )
        """)

        if tables_exist:
            logger.debug("Vocab tables already exist")
            return

        logger.info("Creating vocab tables...")
        migration_path = os.path.join(
            os.path.dirname(__file__),
            "migrations/001_vocab_sessions.sql"
        )
        if not os.path.exists(migration_path):
            # Best-effort: warn rather than crash startup on a missing file.
            logger.warning(f"Migration file not found: {migration_path}")
            return

        with open(migration_path, "r") as f:
            sql = f.read()
        await conn.execute(sql)
        logger.info("Vocab tables created successfully")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SESSION OPERATIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
async def create_session_db(
    session_id: str,
    name: str,
    description: str = "",
    source_language: str = "en",
    target_language: str = "de"
) -> Dict[str, Any]:
    """Insert a new vocabulary session row and return it as a dict.

    The session starts in status 'pending' with a vocabulary count of 0.
    """
    insert_sql = """
            INSERT INTO vocab_sessions (
                id, name, description, source_language, target_language,
                status, vocabulary_count
            ) VALUES ($1, $2, $3, $4, $5, 'pending', 0)
            RETURNING *
        """
    pool = await get_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow(
            insert_sql,
            uuid.UUID(session_id), name, description, source_language, target_language,
        )
    return _row_to_dict(record)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_db(session_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a single session by ID, or None when it does not exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT * FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    return _row_to_dict(record) if record else None
|
||||||
|
|
||||||
|
|
||||||
|
async def list_sessions_db(
    limit: int = 50,
    offset: int = 0,
    status: Optional[str] = None
) -> List[Dict[str, Any]]:
    """List sessions newest-first, optionally filtered by status."""
    # Choose query + arguments first, then run a single fetch.
    if status:
        query = """
                SELECT * FROM vocab_sessions
                WHERE status = $1
                ORDER BY created_at DESC
                LIMIT $2 OFFSET $3
            """
        args = (status, limit, offset)
    else:
        query = """
                SELECT * FROM vocab_sessions
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
            """
        args = (limit, offset)

    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch(query, *args)
    return [_row_to_dict(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_session_db(
    session_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update a session with the given fields.

    Only whitelisted columns are applied; unknown keyword arguments are
    silently ignored. JSONB columns are serialized before writing.

    Args:
        session_id: UUID string of the session to update.
        **kwargs: Column/value pairs to set.

    Returns:
        The updated session dict, the unchanged session when no updatable
        field was supplied, or None when the session does not exist.
    """
    pool = await get_pool()

    # Build dynamic UPDATE query. The whitelist prevents SQL injection
    # through keyword-argument names.
    fields = []
    values = []
    param_idx = 1

    allowed_fields = [
        'name', 'description', 'status', 'vocabulary_count',
        'extraction_confidence', 'image_path', 'pdf_path', 'pdf_page_count',
        'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'
    ]
    # These columns are JSONB and need explicit serialization.
    jsonb_fields = {'ocr_prompts', 'processed_pages', 'successful_pages', 'failed_pages'}

    for key, value in kwargs.items():
        if key in allowed_fields:
            fields.append(f"{key} = ${param_idx}")
            if key in jsonb_fields:
                # FIX: serialize on `is not None` instead of truthiness so an
                # empty list/dict is stored as '[]'/'{}' rather than NULL.
                value = json.dumps(value) if value is not None else None
            values.append(value)
            param_idx += 1

    if not fields:
        # Nothing to update — return the current row (or None if missing).
        return await get_session_db(session_id)

    values.append(uuid.UUID(session_id))

    async with pool.acquire() as conn:
        row = await conn.fetchrow(f"""
            UPDATE vocab_sessions
            SET {', '.join(fields)}
            WHERE id = ${param_idx}
            RETURNING *
        """, *values)

    if row:
        return _row_to_dict(row)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_session_db(session_id: str) -> bool:
    """Delete a session and all related data (cascades); True if a row went away."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        command_tag = await conn.execute("""
            DELETE FROM vocab_sessions WHERE id = $1
        """, uuid.UUID(session_id))
    # asyncpg returns the command tag, e.g. "DELETE 1".
    return command_tag == "DELETE 1"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# VOCABULARY OPERATIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
async def add_vocabulary_db(
    session_id: str,
    vocab_list: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Add vocabulary entries to a session and refresh its entry count.

    FIX: all inserts plus the count update now run inside a single
    transaction, so a mid-batch failure cannot leave a partially written
    batch or a stale vocabulary_count behind.

    Args:
        session_id: UUID string of the owning session.
        vocab_list: Entries with 'english'/'german' plus optional
            'example_sentence', 'example_sentence_gap', 'word_type',
            'source_page'.

    Returns:
        The inserted rows as dicts (empty list for empty input).
    """
    if not vocab_list:
        return []

    pool = await get_pool()
    results = []

    async with pool.acquire() as conn:
        async with conn.transaction():  # atomic: inserts + count update
            for vocab in vocab_list:
                vocab_id = str(uuid.uuid4())
                row = await conn.fetchrow("""
                    INSERT INTO vocab_entries (
                        id, session_id, english, german, example_sentence,
                        example_sentence_gap, word_type, source_page
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                    RETURNING *
                """,
                    uuid.UUID(vocab_id),
                    uuid.UUID(session_id),
                    vocab.get('english', ''),
                    vocab.get('german', ''),
                    vocab.get('example_sentence'),
                    vocab.get('example_sentence_gap'),
                    vocab.get('word_type'),
                    vocab.get('source_page')
                )
                results.append(_row_to_dict(row))

            # Keep the denormalized count in sync with the entries table.
            await conn.execute("""
                UPDATE vocab_sessions
                SET vocabulary_count = (
                    SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
                )
                WHERE id = $1
            """, uuid.UUID(session_id))

    return results
|
||||||
|
|
||||||
|
|
||||||
|
async def get_vocabulary_db(
    session_id: str,
    source_page: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Fetch vocabulary entries for a session, optionally limited to one page."""
    # Pick query + arguments first, then run a single fetch.
    if source_page is not None:
        query = """
                SELECT * FROM vocab_entries
                WHERE session_id = $1 AND source_page = $2
                ORDER BY created_at
            """
        args = (uuid.UUID(session_id), source_page)
    else:
        query = """
                SELECT * FROM vocab_entries
                WHERE session_id = $1
                ORDER BY source_page NULLS LAST, created_at
            """
        args = (uuid.UUID(session_id),)

    pool = await get_pool()
    async with pool.acquire() as conn:
        records = await conn.fetch(query, *args)
    return [_row_to_dict(record) for record in records]
|
||||||
|
|
||||||
|
|
||||||
|
async def update_vocabulary_db(
    entry_id: str,
    **kwargs
) -> Optional[Dict[str, Any]]:
    """Update a single vocabulary entry.

    Only whitelisted columns are applied. Returns the updated row as a
    dict, or None when no updatable field was supplied or the entry is
    missing.
    """
    allowed_fields = [
        'english', 'german', 'example_sentence', 'example_sentence_gap',
        'word_type', 'source_page'
    ]
    # Keep only whitelisted columns (guards the dynamically built SQL).
    updates = {k: v for k, v in kwargs.items() if k in allowed_fields}
    if not updates:
        return None

    assignments = [f"{col} = ${i}" for i, col in enumerate(updates, start=1)]
    values = list(updates.values())
    values.append(uuid.UUID(entry_id))

    pool = await get_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow(f"""
            UPDATE vocab_entries
            SET {', '.join(assignments)}
            WHERE id = ${len(values)}
            RETURNING *
        """, *values)

    return _row_to_dict(record) if record else None
|
||||||
|
|
||||||
|
|
||||||
|
async def clear_page_vocabulary_db(session_id: str, page: int) -> int:
    """Clear all vocabulary for a specific page of a session.

    FIX: the delete and the session-count refresh now run in one
    transaction, so the denormalized vocabulary_count cannot drift if the
    second statement fails.

    Args:
        session_id: UUID string of the owning session.
        page: Source page whose entries should be removed.

    Returns:
        Number of rows deleted.
    """
    pool = await get_pool()
    async with pool.acquire() as conn:
        async with conn.transaction():  # atomic: delete + count update
            result = await conn.execute("""
                DELETE FROM vocab_entries
                WHERE session_id = $1 AND source_page = $2
            """, uuid.UUID(session_id), page)

            # Keep the denormalized count in sync with the entries table.
            await conn.execute("""
                UPDATE vocab_sessions
                SET vocabulary_count = (
                    SELECT COUNT(*) FROM vocab_entries WHERE session_id = $1
                )
                WHERE id = $1
            """, uuid.UUID(session_id))

    # asyncpg command tag looks like "DELETE <n>"; extract the row count.
    count = int(result.split()[-1]) if result else 0
    return count
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# WORKSHEET OPERATIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
async def create_worksheet_db(
    session_id: str,
    worksheet_types: List[str],
    pdf_path: Optional[str] = None,
    solution_path: Optional[str] = None
) -> Dict[str, Any]:
    """Insert a worksheet record for a session and return it as a dict."""
    worksheet_id = str(uuid.uuid4())
    pool = await get_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow("""
            INSERT INTO vocab_worksheets (
                id, session_id, worksheet_types, pdf_path, solution_path
            ) VALUES ($1, $2, $3, $4, $5)
            RETURNING *
        """,
            uuid.UUID(worksheet_id),
            uuid.UUID(session_id),
            json.dumps(worksheet_types),  # JSONB column
            pdf_path,
            solution_path,
        )
    return _row_to_dict(record)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_worksheet_db(worksheet_id: str) -> Optional[Dict[str, Any]]:
    """Fetch a worksheet by ID, or None when it does not exist."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        record = await conn.fetchrow("""
            SELECT * FROM vocab_worksheets WHERE id = $1
        """, uuid.UUID(worksheet_id))
    return _row_to_dict(record) if record else None
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_worksheets_for_session_db(session_id: str) -> int:
    """Delete every worksheet belonging to a session; returns the number removed."""
    pool = await get_pool()
    async with pool.acquire() as conn:
        command_tag = await conn.execute("""
            DELETE FROM vocab_worksheets WHERE session_id = $1
        """, uuid.UUID(session_id))
    # asyncpg command tag looks like "DELETE <n>"; extract the row count.
    return int(command_tag.split()[-1]) if command_tag else 0
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PDF CACHE OPERATIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Simple in-memory cache for PDF data (temporary until served)
# NOTE(review): process-local and unbounded — entries persist until
# clear_cached_pdf_data() is called and are lost on restart; presumably
# acceptable for short-lived download hand-off. Confirm against callers.
_pdf_cache: Dict[str, bytes] = {}


def cache_pdf_data(worksheet_id: str, pdf_data: bytes) -> None:
    """Cache PDF data temporarily for download."""
    _pdf_cache[worksheet_id] = pdf_data


def get_cached_pdf_data(worksheet_id: str) -> Optional[bytes]:
    """Get cached PDF data."""
    return _pdf_cache.get(worksheet_id)


def clear_cached_pdf_data(worksheet_id: str) -> None:
    """Clear cached PDF data."""
    _pdf_cache.pop(worksheet_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# HELPER FUNCTIONS
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _row_to_dict(row: asyncpg.Record) -> Dict[str, Any]:
    """Convert an asyncpg Record to a dict with JSON-friendly value types.

    UUIDs become strings, datetimes become ISO-8601 strings, and JSONB
    columns stored as text are parsed back into Python objects.
    """
    if row is None:
        return {}

    result = dict(row)

    uuid_keys = ('id', 'session_id')
    datetime_keys = ('created_at', 'updated_at', 'generated_at')
    jsonb_keys = ('ocr_prompts', 'processed_pages', 'successful_pages',
                  'failed_pages', 'worksheet_types')

    for key in uuid_keys:
        if result.get(key) is not None:
            result[key] = str(result[key])

    for key in datetime_keys:
        if result.get(key) is not None:
            result[key] = result[key].isoformat()

    for key in jsonb_keys:
        value = result.get(key)
        if isinstance(value, str):  # asyncpg may hand JSONB back as text
            result[key] = json.loads(value)

    return result
|
||||||
Reference in New Issue
Block a user