[split-required] Split 500-1000 LOC files across all services

backend-lehrer (5 files):
- alerts_agent/db/repository.py (992 → 5), abitur_docs_api.py (956 → 3)
- teacher_dashboard_api.py (951 → 3), services/pdf_service.py (916 → 3)
- mail/mail_db.py (987 → 6)

klausur-service (5 files):
- legal_templates_ingestion.py (942 → 3), ocr_pipeline_postprocess.py (929 → 4)
- ocr_pipeline_words.py (876 → 3), ocr_pipeline_ocr_merge.py (616 → 2)
- KorrekturPage.tsx (956 → 6)

website (5 pages):
- mail (985 → 9), edu-search (958 → 8), mac-mini (950 → 7)
- ocr-labeling (946 → 7), audit-workspace (871 → 4)

studio-v2 (5 files + 1 deleted):
- page.tsx (946 → 5), MessagesContext.tsx (925 → 4)
- korrektur (914 → 6), worksheet-cleanup (899 → 6)
- useVocabWorksheet.ts (888 → 3)
- Deleted dead page-original.tsx (934 LOC)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-24 23:35:37 +02:00
parent 6811264756
commit b6983ab1dc
99 changed files with 13484 additions and 16106 deletions

View File

@@ -0,0 +1,282 @@
"""
Legal Templates Chunking — text splitting, type inference, and chunk creation.
Extracted from legal_templates_ingestion.py to keep files under 500 LOC.
License: Apache 2.0
"""
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional
from template_sources import SourceConfig
from github_crawler import ExtractedDocument
# Chunking configuration defaults (can be overridden by env vars in ingestion module)
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_CHUNK_OVERLAP = 200
@dataclass
class TemplateChunk:
    """A chunk of template text ready for indexing.

    Carries the chunk text plus all metadata needed for search filtering,
    license compliance, and attribution when the chunk is surfaced.
    Instances are produced by create_chunks().
    """
    # Core content and position within the source document.
    text: str
    chunk_index: int
    document_title: str
    # Classification (see infer_template_type / infer_clause_category).
    template_type: str
    clause_category: Optional[str]
    language: str
    jurisdiction: str
    # License metadata copied from the source's license_info.
    license_id: str
    license_name: str
    license_url: str
    attribution_required: bool
    share_alike: bool
    no_derivatives: bool
    commercial_use: bool
    # Provenance of the original document.
    source_name: str
    source_url: str
    source_repo: Optional[str]
    source_commit: Optional[str]
    source_file: str
    source_hash: str
    # Pre-rendered attribution (None when the license does not require it).
    attribution_text: Optional[str]
    copyright_notice: Optional[str]
    # Document-shape hints derived from the extracted document.
    is_complete_document: bool
    is_modular: bool
    requires_customization: bool
    placeholders: List[str]
    # Usage permissions derived from the license.
    training_allowed: bool
    output_allowed: bool
    modification_allowed: bool
    distortion_prohibited: bool
@dataclass
class IngestionStatus:
    """Status of a single source ingestion run: counters plus any errors."""
    source_name: str
    status: str  # "pending", "running", "completed", "failed"
    documents_found: int = 0
    chunks_created: int = 0
    chunks_indexed: int = 0  # may be lower than chunks_created if indexing fails
    errors: List[str] = field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences with basic abbreviation handling.

    Dots inside known abbreviations (e.g. "z.B.", "etc.") and decimal
    numbers are temporarily masked so they do not terminate a sentence,
    then restored afterwards.

    Returns the non-empty, stripped sentences in document order.
    """
    # Protect common German/English abbreviations. The mask is built from
    # the *matched* text rather than the canonical list entry: the match is
    # case-insensitive, and substituting the canonical form silently
    # rewrote the input's casing (e.g. "z.b." came back as "z.B.").
    abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no']
    protected = text
    for abbr in abbreviations:
        pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE)
        protected = pattern.sub(
            # m.group(0) includes the trailing dot; mask interior dots as
            # <DOT> and the terminal dot as <ABBR>.
            lambda m: m.group(0)[:-1].replace('.', '<DOT>') + '<ABBR>',
            protected,
        )
    # Protect decimal numbers such as "3.14".
    protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
    # Split on sentence-ending punctuation followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', protected)
    # Restore the masked dots and drop empty fragments.
    result = []
    for s in sentences:
        s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.')
        s = s.strip()
        if s:
            result.append(s)
    return result
def _overlap_tail(sentences: List[str], overlap: int) -> List[str]:
    """Trailing sentences of *sentences* whose combined length (one joining space each) fits within *overlap* characters."""
    tail: List[str] = []
    total = 0
    for sentence in reversed(sentences):
        cost = len(sentence) + 1  # +1 for the joining space
        if total + cost > overlap:
            break
        tail.append(sentence)
        total += cost
    tail.reverse()
    return tail


def chunk_text(
    text: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[str]:
    """
    Split text into overlapping chunks.

    Respects paragraph and sentence boundaries where possible. When an
    oversized paragraph is split by sentences, up to *overlap* characters
    of trailing sentences are carried into the next chunk for context
    continuity. (Previously the *overlap* parameter was accepted but
    ignored: a fixed one-third of the buffered sentences was carried over
    regardless of its value.)
    """
    if not text:
        return []
    if len(text) <= chunk_size:
        return [text.strip()]
    # Prefer paragraph boundaries first.
    paragraphs = text.split('\n\n')
    chunks: List[str] = []
    current_chunk: List[str] = []
    current_length = 0
    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        para_length = len(para)
        if para_length > chunk_size:
            # Oversized paragraph: flush the buffer, then pack sentences.
            if current_chunk:
                chunks.append('\n\n'.join(current_chunk))
                current_chunk = []
                current_length = 0
            for sentence in split_sentences(para):
                if current_length + len(sentence) + 1 > chunk_size:
                    if current_chunk:
                        chunks.append(' '.join(current_chunk))
                        # Carry at most *overlap* characters of trailing
                        # sentences into the next chunk.
                        current_chunk = _overlap_tail(current_chunk, overlap)
                        current_length = sum(len(s) + 1 for s in current_chunk)
                current_chunk.append(sentence)
                current_length += len(sentence) + 1
        elif current_length + para_length + 2 > chunk_size:
            # Paragraph would overflow the current chunk: flush first.
            if current_chunk:
                chunks.append('\n\n'.join(current_chunk))
            current_chunk = [para]
            current_length = para_length
        else:
            current_chunk.append(para)
            current_length += para_length + 2
    # Flush whatever remains.
    if current_chunk:
        chunks.append('\n\n'.join(current_chunk))
    return [c.strip() for c in chunks if c.strip()]
def infer_template_type(doc: ExtractedDocument, source: SourceConfig) -> str:
    """Infer the template type from document content and metadata.

    Scans the document body and title (case-insensitively) for known
    keyword indicators; falls back to the source's first declared template
    type, and finally to the generic "clause".
    """
    haystack_body = doc.text.lower()
    haystack_title = doc.title.lower()
    # Ordered mapping: earlier entries win when several types match.
    type_indicators = {
        "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
        "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
        "cookie_banner": ["cookie", "cookies", "tracking"],
        "impressum": ["impressum", "legal notice", "imprint"],
        "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
        "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
        "sla": ["service level", "availability", "uptime"],
        "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
        "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
        "acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
    }
    for candidate, keywords in type_indicators.items():
        if any(kw in haystack_body or kw in haystack_title for kw in keywords):
            return candidate
    # No indicator matched: fall back to the source's declared types.
    if source.template_types:
        return source.template_types[0]
    return "clause"  # Generic fallback
def infer_clause_category(text: str) -> Optional[str]:
    """Infer the clause category from text content, or None when no known keyword matches."""
    lowered = text.lower()
    # Ordered mapping: earlier entries win when several categories match.
    categories = {
        "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
        "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
        "widerruf": ["widerruf", "cancellation", "withdrawal"],
        "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
        "kuendigung": ["kuendigung", "termination", "beendigung"],
        "zahlung": ["zahlung", "payment", "preis", "price"],
        "gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
        "aenderungen": ["aenderung", "modification", "amendment"],
        "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
    }
    return next(
        (
            category
            for category, keywords in categories.items()
            if any(kw in lowered for kw in keywords)
        ),
        None,
    )
def create_chunks(
    doc: ExtractedDocument,
    source: SourceConfig,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[TemplateChunk]:
    """Create template chunks from an extracted document.

    The document text is chunked with chunk_text(); every chunk carries the
    full license/provenance metadata from *source* plus per-chunk
    classification via infer_clause_category().
    """
    license_info = source.license_info
    template_type = infer_template_type(doc, source)
    pieces = chunk_text(doc.text, chunk_size, chunk_overlap)
    # Document-level shape hints — identical for every chunk.
    modular = len(doc.sections) > 0 or '##' in doc.text
    needs_customization = len(doc.placeholders) > 0
    result: List[TemplateChunk] = []
    for index, piece in enumerate(pieces):
        # A single large chunk is treated as the complete document.
        complete = len(pieces) == 1 and len(piece) > 500
        # Render attribution only when the license demands it.
        attribution = None
        if license_info.attribution_required:
            attribution = license_info.get_attribution_text(
                source.name,
                doc.source_url or source.get_source_url()
            )
        result.append(
            TemplateChunk(
                text=piece,
                chunk_index=index,
                document_title=doc.title,
                template_type=template_type,
                clause_category=infer_clause_category(piece),
                language=doc.language,
                jurisdiction=source.jurisdiction,
                license_id=license_info.id.value,
                license_name=license_info.name,
                license_url=license_info.url,
                attribution_required=license_info.attribution_required,
                share_alike=license_info.share_alike,
                no_derivatives=license_info.no_derivatives,
                commercial_use=license_info.commercial_use,
                source_name=source.name,
                source_url=doc.source_url or source.get_source_url(),
                source_repo=source.repo_url,
                source_commit=doc.source_commit,
                source_file=doc.file_path,
                source_hash=doc.source_hash,
                attribution_text=attribution,
                copyright_notice=None,  # not extracted from the document yet
                is_complete_document=complete,
                is_modular=modular,
                requires_customization=needs_customization,
                placeholders=doc.placeholders,
                training_allowed=license_info.training_allowed,
                output_allowed=license_info.output_allowed,
                modification_allowed=license_info.modification_allowed,
                distortion_prohibited=license_info.distortion_prohibited,
            )
        )
    return result

View File

@@ -0,0 +1,165 @@
"""
Legal Templates CLI — command-line entry point for ingestion and search.
Extracted from legal_templates_ingestion.py to keep files under 500 LOC.
Usage:
python legal_templates_cli.py --ingest-all
python legal_templates_cli.py --ingest-source github-site-policy
python legal_templates_cli.py --status
python legal_templates_cli.py --search "Datenschutzerklaerung"
License: Apache 2.0
"""
import asyncio
import json
from template_sources import TEMPLATE_SOURCES, LicenseType
from legal_templates_ingestion import LegalTemplatesIngestion
def _build_parser():
    """Build the argument parser for the legal-templates CLI (one flag per operation)."""
    import argparse
    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
    parser.add_argument(
        "--ingest-all",
        action="store_true",
        help="Ingest all enabled sources"
    )
    parser.add_argument(
        "--ingest-source",
        type=str,
        metavar="NAME",
        help="Ingest a specific source by name"
    )
    parser.add_argument(
        "--ingest-license",
        type=str,
        choices=["cc0", "mit", "cc_by_4", "public_domain"],
        help="Ingest all sources of a specific license type"
    )
    parser.add_argument(
        "--max-priority",
        type=int,
        default=3,
        help="Maximum priority level to ingest (1=highest, 5=lowest)"
    )
    parser.add_argument(
        "--status",
        action="store_true",
        help="Show collection status"
    )
    parser.add_argument(
        "--search",
        type=str,
        metavar="QUERY",
        help="Test search query"
    )
    parser.add_argument(
        "--template-type",
        type=str,
        help="Filter search by template type"
    )
    parser.add_argument(
        "--language",
        type=str,
        help="Filter search by language"
    )
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Reset (delete and recreate) the collection"
    )
    parser.add_argument(
        "--delete-source",
        type=str,
        metavar="NAME",
        help="Delete all chunks from a source"
    )
    return parser


def _print_run_results(results) -> None:
    """Print per-source indexed-chunk counts and surface any per-source errors."""
    print("\nResults:")
    for name, status in results.items():
        print(f" {name}: {status.chunks_indexed} chunks ({status.status})")
        if status.errors:
            for error in status.errors:
                print(f" ERROR: {error}")


async def main():
    """CLI entry point: parse arguments and dispatch to LegalTemplatesIngestion.

    Exactly one command flag is honored per invocation (first match wins);
    with no flags the help text is printed. The ingestion client is always
    closed, even when a command fails.
    """
    parser = _build_parser()
    args = parser.parse_args()
    ingestion = LegalTemplatesIngestion()
    try:
        if args.reset:
            ingestion.reset_collection()
            print("Collection reset successfully")
        elif args.delete_source:
            count = ingestion.delete_source(args.delete_source)
            print(f"Deleted {count} chunks from {args.delete_source}")
        elif args.status:
            status = ingestion.get_status()
            print(json.dumps(status, indent=2, default=str))
        elif args.ingest_all:
            print(f"Ingesting all sources (max priority: {args.max_priority})...")
            results = await ingestion.ingest_all(max_priority=args.max_priority)
            _print_run_results(results)
            total = sum(s.chunks_indexed for s in results.values())
            print(f"\nTotal: {total} chunks indexed")
        elif args.ingest_source:
            source = next(
                (s for s in TEMPLATE_SOURCES if s.name == args.ingest_source),
                None
            )
            if not source:
                print(f"Unknown source: {args.ingest_source}")
                print("Available sources:")
                for s in TEMPLATE_SOURCES:
                    print(f" - {s.name}")
                return
            print(f"Ingesting: {source.name}")
            status = await ingestion.ingest_source(source)
            print(f"\nResult: {status.chunks_indexed} chunks ({status.status})")
            if status.errors:
                for error in status.errors:
                    print(f" ERROR: {error}")
        elif args.ingest_license:
            license_type = LicenseType(args.ingest_license)
            print(f"Ingesting all {license_type.value} sources...")
            results = await ingestion.ingest_by_license(license_type)
            # Fix: this branch previously dropped per-source errors; they are
            # now reported the same way as for --ingest-all.
            _print_run_results(results)
        elif args.search:
            print(f"Searching: {args.search}")
            results = await ingestion.search(
                args.search,
                template_type=args.template_type,
                language=args.language,
            )
            print(f"\nFound {len(results)} results:")
            for i, result in enumerate(results, 1):
                print(f"\n{i}. [{result['template_type']}] {result['document_title']}")
                print(f" Score: {result['score']:.3f}")
                print(f" License: {result['license_name']}")
                print(f" Source: {result['source_name']}")
                print(f" Language: {result['language']}")
                if result['attribution_required']:
                    print(f" Attribution: {result['attribution_text']}")
                print(f" Text: {result['text'][:200]}...")
        else:
            parser.print_help()
    finally:
        await ingestion.close()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -8,18 +8,16 @@ proper attribution tracking.
Collection: bp_legal_templates
Usage:
python legal_templates_ingestion.py --ingest-all
python legal_templates_ingestion.py --ingest-source github-site-policy
python legal_templates_ingestion.py --status
python legal_templates_ingestion.py --search "Datenschutzerklaerung"
python legal_templates_cli.py --ingest-all
python legal_templates_cli.py --ingest-source github-site-policy
python legal_templates_cli.py --status
python legal_templates_cli.py --search "Datenschutzerklaerung"
"""
import asyncio
import hashlib
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
@@ -50,6 +48,17 @@ from github_crawler import (
RepositoryDownloader,
)
# Re-export from chunking module for backward compatibility
from legal_templates_chunking import ( # noqa: F401
IngestionStatus,
TemplateChunk,
chunk_text,
create_chunks,
infer_clause_category,
infer_template_type,
split_sentences,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -78,54 +87,6 @@ MAX_RETRIES = 3
RETRY_DELAY = 3.0
@dataclass
class IngestionStatus:
"""Status of a source ingestion."""
source_name: str
status: str # "pending", "running", "completed", "failed"
documents_found: int = 0
chunks_created: int = 0
chunks_indexed: int = 0
errors: List[str] = field(default_factory=list)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
@dataclass
class TemplateChunk:
"""A chunk of template text ready for indexing."""
text: str
chunk_index: int
document_title: str
template_type: str
clause_category: Optional[str]
language: str
jurisdiction: str
license_id: str
license_name: str
license_url: str
attribution_required: bool
share_alike: bool
no_derivatives: bool
commercial_use: bool
source_name: str
source_url: str
source_repo: Optional[str]
source_commit: Optional[str]
source_file: str
source_hash: str
attribution_text: Optional[str]
copyright_notice: Optional[str]
is_complete_document: bool
is_modular: bool
requires_customization: bool
placeholders: List[str]
training_allowed: bool
output_allowed: bool
modification_allowed: bool
distortion_prohibited: bool
class LegalTemplatesIngestion:
"""Handles ingestion of legal templates into Qdrant."""
@@ -168,212 +129,6 @@ class LegalTemplatesIngestion:
logger.error(f"Embedding generation failed: {e}")
raise
def _chunk_text(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""
Split text into overlapping chunks.
Respects paragraph and sentence boundaries where possible.
"""
if not text:
return []
if len(text) <= chunk_size:
return [text.strip()]
# Split into paragraphs first
paragraphs = text.split('\n\n')
chunks = []
current_chunk = []
current_length = 0
for para in paragraphs:
para = para.strip()
if not para:
continue
para_length = len(para)
if para_length > chunk_size:
# Large paragraph: split by sentences
if current_chunk:
chunks.append('\n\n'.join(current_chunk))
current_chunk = []
current_length = 0
# Split long paragraph by sentences
sentences = self._split_sentences(para)
for sentence in sentences:
if current_length + len(sentence) + 1 > chunk_size:
if current_chunk:
chunks.append(' '.join(current_chunk))
# Keep overlap
overlap_count = max(1, len(current_chunk) // 3)
current_chunk = current_chunk[-overlap_count:]
current_length = sum(len(s) + 1 for s in current_chunk)
current_chunk.append(sentence)
current_length += len(sentence) + 1
elif current_length + para_length + 2 > chunk_size:
# Paragraph would exceed chunk size
if current_chunk:
chunks.append('\n\n'.join(current_chunk))
current_chunk = []
current_length = 0
current_chunk.append(para)
current_length = para_length
else:
current_chunk.append(para)
current_length += para_length + 2
# Add final chunk
if current_chunk:
chunks.append('\n\n'.join(current_chunk))
return [c.strip() for c in chunks if c.strip()]
def _split_sentences(self, text: str) -> List[str]:
"""Split text into sentences with basic abbreviation handling."""
import re
# Protect common abbreviations
abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no']
protected = text
for abbr in abbreviations:
pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE)
protected = pattern.sub(abbr.replace('.', '<DOT>') + '<ABBR>', protected)
# Protect decimal numbers
protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
# Split on sentence endings
sentences = re.split(r'(?<=[.!?])\s+', protected)
# Restore protected characters
result = []
for s in sentences:
s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.')
s = s.strip()
if s:
result.append(s)
return result
def _infer_template_type(self, doc: ExtractedDocument, source: SourceConfig) -> str:
"""Infer the template type from document content and metadata."""
text_lower = doc.text.lower()
title_lower = doc.title.lower()
# Check known indicators
type_indicators = {
"privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
"terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
"cookie_banner": ["cookie", "cookies", "tracking"],
"impressum": ["impressum", "legal notice", "imprint"],
"widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
"dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
"sla": ["service level", "availability", "uptime"],
"nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
"community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
"acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
}
for template_type, indicators in type_indicators.items():
for indicator in indicators:
if indicator in text_lower or indicator in title_lower:
return template_type
# Fall back to source's first template type
if source.template_types:
return source.template_types[0]
return "clause" # Generic fallback
def _infer_clause_category(self, text: str) -> Optional[str]:
"""Infer the clause category from text content."""
text_lower = text.lower()
categories = {
"haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
"datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
"widerruf": ["widerruf", "cancellation", "withdrawal"],
"gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
"kuendigung": ["kuendigung", "termination", "beendigung"],
"zahlung": ["zahlung", "payment", "preis", "price"],
"gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
"aenderungen": ["aenderung", "modification", "amendment"],
"schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
}
for category, indicators in categories.items():
for indicator in indicators:
if indicator in text_lower:
return category
return None
def _create_chunks(
self,
doc: ExtractedDocument,
source: SourceConfig,
) -> List[TemplateChunk]:
"""Create template chunks from an extracted document."""
license_info = source.license_info
template_type = self._infer_template_type(doc, source)
# Chunk the text
text_chunks = self._chunk_text(doc.text)
chunks = []
for i, chunk_text in enumerate(text_chunks):
# Determine if this is a complete document or a clause
is_complete = len(text_chunks) == 1 and len(chunk_text) > 500
is_modular = len(doc.sections) > 0 or '##' in doc.text
requires_customization = len(doc.placeholders) > 0
# Generate attribution text
attribution_text = None
if license_info.attribution_required:
attribution_text = license_info.get_attribution_text(
source.name,
doc.source_url or source.get_source_url()
)
chunk = TemplateChunk(
text=chunk_text,
chunk_index=i,
document_title=doc.title,
template_type=template_type,
clause_category=self._infer_clause_category(chunk_text),
language=doc.language,
jurisdiction=source.jurisdiction,
license_id=license_info.id.value,
license_name=license_info.name,
license_url=license_info.url,
attribution_required=license_info.attribution_required,
share_alike=license_info.share_alike,
no_derivatives=license_info.no_derivatives,
commercial_use=license_info.commercial_use,
source_name=source.name,
source_url=doc.source_url or source.get_source_url(),
source_repo=source.repo_url,
source_commit=doc.source_commit,
source_file=doc.file_path,
source_hash=doc.source_hash,
attribution_text=attribution_text,
copyright_notice=None, # Could be extracted from doc if present
is_complete_document=is_complete,
is_modular=is_modular,
requires_customization=requires_customization,
placeholders=doc.placeholders,
training_allowed=license_info.training_allowed,
output_allowed=license_info.output_allowed,
modification_allowed=license_info.modification_allowed,
distortion_prohibited=license_info.distortion_prohibited,
)
chunks.append(chunk)
return chunks
async def ingest_source(self, source: SourceConfig) -> IngestionStatus:
"""Ingest a single source into Qdrant."""
status = IngestionStatus(
@@ -405,7 +160,7 @@ class LegalTemplatesIngestion:
# Create chunks from all documents
all_chunks: List[TemplateChunk] = []
for doc in documents:
chunks = self._create_chunks(doc, source)
chunks = create_chunks(doc, source, CHUNK_SIZE, CHUNK_OVERLAP)
all_chunks.extend(chunks)
status.chunks_created += len(chunks)
@@ -637,21 +392,7 @@ class LegalTemplatesIngestion:
attribution_required: Optional[bool] = None,
top_k: int = 10,
) -> List[Dict[str, Any]]:
"""
Search the legal templates collection.
Args:
query: Search query text
template_type: Filter by template type (e.g., "privacy_policy")
license_types: Filter by license types (e.g., ["cc0", "mit"])
language: Filter by language (e.g., "de")
jurisdiction: Filter by jurisdiction (e.g., "DE")
attribution_required: Filter by attribution requirement
top_k: Number of results to return
Returns:
List of search results with full metadata
"""
"""Search the legal templates collection."""
# Generate query embedding
embeddings = await self._generate_embeddings([query])
query_vector = embeddings[0]
@@ -661,45 +402,27 @@ class LegalTemplatesIngestion:
if template_type:
must_conditions.append(
FieldCondition(
key="template_type",
match=MatchValue(value=template_type),
)
FieldCondition(key="template_type", match=MatchValue(value=template_type))
)
if language:
must_conditions.append(
FieldCondition(
key="language",
match=MatchValue(value=language),
)
FieldCondition(key="language", match=MatchValue(value=language))
)
if jurisdiction:
must_conditions.append(
FieldCondition(
key="jurisdiction",
match=MatchValue(value=jurisdiction),
)
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
)
if attribution_required is not None:
must_conditions.append(
FieldCondition(
key="attribution_required",
match=MatchValue(value=attribution_required),
)
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
)
# License type filter (OR condition)
should_conditions = []
if license_types:
for license_type in license_types:
for lt in license_types:
should_conditions.append(
FieldCondition(
key="license_id",
match=MatchValue(value=license_type),
)
FieldCondition(key="license_id", match=MatchValue(value=lt))
)
# Construct filter
@@ -747,196 +470,31 @@ class LegalTemplatesIngestion:
def delete_source(self, source_name: str) -> int:
"""Delete all chunks from a specific source."""
# First count how many we're deleting
count_result = self.qdrant.count(
collection_name=LEGAL_TEMPLATES_COLLECTION,
count_filter=Filter(
must=[
FieldCondition(
key="source_name",
match=MatchValue(value=source_name),
)
]
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
),
)
# Delete by filter
self.qdrant.delete(
collection_name=LEGAL_TEMPLATES_COLLECTION,
points_selector=Filter(
must=[
FieldCondition(
key="source_name",
match=MatchValue(value=source_name),
)
]
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
),
)
return count_result.count
def reset_collection(self):
"""Delete and recreate the collection."""
logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}")
# Delete collection
try:
self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION)
except Exception:
pass # Collection might not exist
# Recreate
pass
self._ensure_collection()
self._ingestion_status.clear()
logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset")
async def close(self):
"""Close HTTP client."""
await self.http_client.aclose()
async def main():
"""CLI entry point."""
import argparse
parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
parser.add_argument(
"--ingest-all",
action="store_true",
help="Ingest all enabled sources"
)
parser.add_argument(
"--ingest-source",
type=str,
metavar="NAME",
help="Ingest a specific source by name"
)
parser.add_argument(
"--ingest-license",
type=str,
choices=["cc0", "mit", "cc_by_4", "public_domain"],
help="Ingest all sources of a specific license type"
)
parser.add_argument(
"--max-priority",
type=int,
default=3,
help="Maximum priority level to ingest (1=highest, 5=lowest)"
)
parser.add_argument(
"--status",
action="store_true",
help="Show collection status"
)
parser.add_argument(
"--search",
type=str,
metavar="QUERY",
help="Test search query"
)
parser.add_argument(
"--template-type",
type=str,
help="Filter search by template type"
)
parser.add_argument(
"--language",
type=str,
help="Filter search by language"
)
parser.add_argument(
"--reset",
action="store_true",
help="Reset (delete and recreate) the collection"
)
parser.add_argument(
"--delete-source",
type=str,
metavar="NAME",
help="Delete all chunks from a source"
)
args = parser.parse_args()
ingestion = LegalTemplatesIngestion()
try:
if args.reset:
ingestion.reset_collection()
print("Collection reset successfully")
elif args.delete_source:
count = ingestion.delete_source(args.delete_source)
print(f"Deleted {count} chunks from {args.delete_source}")
elif args.status:
status = ingestion.get_status()
print(json.dumps(status, indent=2, default=str))
elif args.ingest_all:
print(f"Ingesting all sources (max priority: {args.max_priority})...")
results = await ingestion.ingest_all(max_priority=args.max_priority)
print("\nResults:")
for name, status in results.items():
print(f" {name}: {status.chunks_indexed} chunks ({status.status})")
if status.errors:
for error in status.errors:
print(f" ERROR: {error}")
total = sum(s.chunks_indexed for s in results.values())
print(f"\nTotal: {total} chunks indexed")
elif args.ingest_source:
source = next(
(s for s in TEMPLATE_SOURCES if s.name == args.ingest_source),
None
)
if not source:
print(f"Unknown source: {args.ingest_source}")
print("Available sources:")
for s in TEMPLATE_SOURCES:
print(f" - {s.name}")
return
print(f"Ingesting: {source.name}")
status = await ingestion.ingest_source(source)
print(f"\nResult: {status.chunks_indexed} chunks ({status.status})")
if status.errors:
for error in status.errors:
print(f" ERROR: {error}")
elif args.ingest_license:
license_type = LicenseType(args.ingest_license)
print(f"Ingesting all {license_type.value} sources...")
results = await ingestion.ingest_by_license(license_type)
print("\nResults:")
for name, status in results.items():
print(f" {name}: {status.chunks_indexed} chunks ({status.status})")
elif args.search:
print(f"Searching: {args.search}")
results = await ingestion.search(
args.search,
template_type=args.template_type,
language=args.language,
)
print(f"\nFound {len(results)} results:")
for i, result in enumerate(results, 1):
print(f"\n{i}. [{result['template_type']}] {result['document_title']}")
print(f" Score: {result['score']:.3f}")
print(f" License: {result['license_name']}")
print(f" Source: {result['source_name']}")
print(f" Language: {result['language']}")
if result['attribution_required']:
print(f" Attribution: {result['attribution_text']}")
print(f" Text: {result['text'][:200]}...")
else:
parser.print_help()
finally:
await ingestion.close()
if __name__ == "__main__":
asyncio.run(main())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,156 @@
"""
Mail Database - Email Account Operations.
"""
import uuid
from typing import Optional, List, Dict
from .mail_db_pool import get_pool
async def create_email_account(
    user_id: str,
    tenant_id: str,
    email: str,
    display_name: str,
    account_type: str,
    imap_host: str,
    imap_port: int,
    imap_ssl: bool,
    smtp_host: str,
    smtp_port: int,
    smtp_ssl: bool,
    vault_path: str,
) -> Optional[str]:
    """Create a new email account row.

    Returns the freshly generated account ID, or None when the pool is
    unavailable or the insert fails.
    """
    pool = await get_pool()
    if pool is None:
        return None
    account_id = str(uuid.uuid4())
    insert_sql = """
        INSERT INTO external_email_accounts
            (id, user_id, tenant_id, email, display_name, account_type,
             imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
    """
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                insert_sql,
                account_id, user_id, tenant_id, email, display_name, account_type,
                imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path,
            )
    except Exception as e:
        print(f"Failed to create email account: {e}")
        return None
    return account_id
async def get_email_accounts(
    user_id: str,
    tenant_id: Optional[str] = None,
) -> List[Dict]:
    """Get all email accounts for a user, optionally scoped to a tenant.

    Returns a list of row dicts ordered by created_at; an empty list when
    the pool is unavailable or the query fails.
    """
    pool = await get_pool()
    if pool is None:
        return []
    # Build one parameterized query instead of two near-duplicate branches
    # (the previous versions differed only in the tenant filter).
    query = "SELECT * FROM external_email_accounts WHERE user_id = $1"
    params: List = [user_id]
    if tenant_id:
        query += " AND tenant_id = $2"
        params.append(tenant_id)
    query += " ORDER BY created_at"
    try:
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [dict(r) for r in rows]
    except Exception as e:
        print(f"Failed to get email accounts: {e}")
        return []
async def get_email_account(account_id: str, user_id: str) -> Optional[Dict]:
    """Fetch a single email account owned by *user_id*.

    Returns the row as a dict, or None when it does not exist, the pool is
    unavailable, or the query fails.
    """
    pool = await get_pool()
    if pool is None:
        return None
    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM external_email_accounts
                WHERE id = $1 AND user_id = $2
                """,
                account_id, user_id,
            )
    except Exception as e:
        print(f"Failed to get email account: {e}")
        return None
    return dict(row) if row else None
async def update_account_status(
    account_id: str,
    status: str,
    sync_error: Optional[str] = None,
    email_count: Optional[int] = None,
    unread_count: Optional[int] = None,
) -> bool:
    """Update an account's sync status; returns True on success.

    COALESCE keeps the stored email/unread counts when None is passed, so
    callers can update only the status without knowing the current counts.
    last_sync and updated_at are always refreshed.
    """
    pool = await get_pool()
    if pool is None:
        return False
    update_sql = """
        UPDATE external_email_accounts SET
            status = $2,
            sync_error = $3,
            email_count = COALESCE($4, email_count),
            unread_count = COALESCE($5, unread_count),
            last_sync = NOW(),
            updated_at = NOW()
        WHERE id = $1
    """
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                update_sql,
                account_id, status, sync_error, email_count, unread_count,
            )
    except Exception as e:
        print(f"Failed to update account status: {e}")
        return False
    return True
async def delete_email_account(account_id: str, user_id: str) -> bool:
    """Delete an account owned by *user_id*; its emails cascade-delete.

    Returns True when the DELETE statement executed (a zero-row match still
    counts as success), False on pool failure or query error.
    """
    pool = await get_pool()
    if pool is None:
        return False
    try:
        async with pool.acquire() as conn:
            # asyncpg returns a command-tag string such as "DELETE 1".
            status_tag = await conn.execute(
                """
                DELETE FROM external_email_accounts
                WHERE id = $1 AND user_id = $2
                """,
                account_id, user_id
            )
    except Exception as e:
        print(f"Failed to delete email account: {e}")
        return False
    return "DELETE" in status_tag

View File

@@ -0,0 +1,225 @@
"""
Mail Database - Aggregated Email Operations.
"""
import json
import uuid
from typing import Optional, List, Dict
from datetime import datetime
from .mail_db_pool import get_pool
async def upsert_email(
    account_id: str,
    user_id: str,
    tenant_id: str,
    message_id: str,
    subject: str,
    sender_email: str,
    sender_name: Optional[str],
    recipients: List[str],
    cc: List[str],
    body_preview: Optional[str],
    body_text: Optional[str],
    body_html: Optional[str],
    has_attachments: bool,
    attachments: List[Dict],
    headers: Dict,
    folder: str,
    date_sent: datetime,
    date_received: datetime,
) -> Optional[str]:
    """Insert or update an email. Returns the email ID.

    Emails are deduplicated per account via the (account_id, message_id)
    unique constraint; on conflict the existing row's id is returned and
    only server-derived fields (subject, folder) are refreshed — local
    state such as the read flag is preserved.

    Returns:
        The email row id (freshly generated or the pre-existing one), or
        None when the database is unavailable or the statement fails.
    """
    pool = await get_pool()
    if pool is None:
        return None
    # Candidate id; ignored when the row already exists (RETURNING id then
    # yields the existing row's id).
    email_id = str(uuid.uuid4())
    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO aggregated_emails
                (id, account_id, user_id, tenant_id, message_id, subject,
                 sender_email, sender_name, recipients, cc, body_preview,
                 body_text, body_html, has_attachments, attachments, headers,
                 folder, date_sent, date_received)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
                ON CONFLICT (account_id, message_id) DO UPDATE SET
                    subject = EXCLUDED.subject,
                    folder = EXCLUDED.folder
                RETURNING id
                """,
                email_id, account_id, user_id, tenant_id, message_id, subject,
                sender_email, sender_name, json.dumps(recipients), json.dumps(cc),
                body_preview, body_text, body_html, has_attachments,
                json.dumps(attachments), json.dumps(headers), folder,
                date_sent, date_received
            )
            # BUG FIX: the previous ON CONFLICT clause also set
            # `is_read = EXCLUDED.is_read`, but is_read is not in the INSERT
            # column list, so EXCLUDED.is_read was the column default FALSE —
            # every re-sync flipped already-read mail back to unread.
            return row['id'] if row else None
    except Exception as e:
        print(f"Failed to upsert email: {e}")
        return None
async def get_unified_inbox(
    user_id: str,
    account_ids: Optional[List[str]] = None,
    categories: Optional[List[str]] = None,
    is_read: Optional[bool] = None,
    is_starred: Optional[bool] = None,
    limit: int = 50,
    offset: int = 0,
) -> List[Dict]:
    """Get unified inbox with filtering.

    Builds a dynamic WHERE clause over all of the user's emails (soft-deleted
    rows are always excluded) and returns them newest-first, joined with the
    owning account's address and display name.

    Args:
        user_id: Owner whose emails are listed.
        account_ids: Optional restriction to specific account ids.
        categories: Optional restriction to stored categories.
        is_read: Tri-state filter; None means "don't filter".
        is_starred: Tri-state filter; None means "don't filter".
        limit: Page size.
        offset: Page start.

    Returns:
        Email rows as dicts; empty list on pool failure or query error.
    """
    pool = await get_pool()
    if pool is None:
        return []
    try:
        async with pool.acquire() as conn:
            # Conditions and positional params are built in lockstep;
            # param_idx tracks the next $n placeholder number.
            conditions = ["user_id = $1", "is_deleted = FALSE"]
            params = [user_id]
            param_idx = 2
            if account_ids:
                conditions.append(f"account_id = ANY(${param_idx})")
                params.append(account_ids)
                param_idx += 1
            if categories:
                conditions.append(f"category = ANY(${param_idx})")
                params.append(categories)
                param_idx += 1
            if is_read is not None:
                conditions.append(f"is_read = ${param_idx}")
                params.append(is_read)
                param_idx += 1
            if is_starred is not None:
                conditions.append(f"is_starred = ${param_idx}")
                params.append(is_starred)
                param_idx += 1
            where_clause = " AND ".join(conditions)
            # limit/offset take the final two placeholder slots.
            params.extend([limit, offset])
            query = f"""
                SELECT e.*, a.email as account_email, a.display_name as account_name
                FROM aggregated_emails e
                JOIN external_email_accounts a ON e.account_id = a.id
                WHERE {where_clause}
                ORDER BY e.date_received DESC
                LIMIT ${param_idx} OFFSET ${param_idx + 1}
            """
            rows = await conn.fetch(query, *params)
            return [dict(r) for r in rows]
    except Exception as e:
        print(f"Failed to get unified inbox: {e}")
        return []
async def get_email(email_id: str, user_id: str) -> Optional[Dict]:
    """Load a single email, joined with its account's address and name.

    Returns the row as a dict, or None when the email is missing, owned by
    another user, or the database is unavailable.
    """
    pool = await get_pool()
    if pool is None:
        return None
    try:
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                """
                SELECT e.*, a.email as account_email, a.display_name as account_name
                FROM aggregated_emails e
                JOIN external_email_accounts a ON e.account_id = a.id
                WHERE e.id = $1 AND e.user_id = $2
                """,
                email_id, user_id
            )
    except Exception as e:
        print(f"Failed to get email: {e}")
        return None
    return dict(record) if record else None
async def update_email_ai_analysis(
    email_id: str,
    category: str,
    sender_type: str,
    sender_authority_name: Optional[str],
    detected_deadlines: List[Dict],
    suggested_priority: str,
    ai_summary: Optional[str],
) -> bool:
    """Update email with AI analysis results.

    Persists classifier output (category, sender typing, deadlines, priority,
    summary) and stamps ai_analyzed_at = NOW(). Deadlines are serialized to
    JSON text for the JSONB column.

    Returns:
        True when the UPDATE was issued, False on pool failure or error.

    NOTE(review): this query is keyed on email_id alone, without a user_id
    scope — confirm callers only pass ids they are authorized to modify.
    """
    pool = await get_pool()
    if pool is None:
        return False
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET
                    category = $2,
                    sender_type = $3,
                    sender_authority_name = $4,
                    detected_deadlines = $5,
                    suggested_priority = $6,
                    ai_summary = $7,
                    ai_analyzed_at = NOW()
                WHERE id = $1
                """,
                email_id, category, sender_type, sender_authority_name,
                json.dumps(detected_deadlines), suggested_priority, ai_summary
            )
            return True
    except Exception as e:
        print(f"Failed to update email AI analysis: {e}")
        return False
async def mark_email_read(email_id: str, user_id: str, is_read: bool = True) -> bool:
    """Set (or clear, with is_read=False) the read flag on a user's email.

    Returns True once the UPDATE was issued — even when no row matched —
    and False when the pool is unavailable or the statement fails.
    """
    pool = await get_pool()
    if pool is None:
        return False
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET is_read = $3
                WHERE id = $1 AND user_id = $2
                """,
                email_id, user_id, is_read
            )
    except Exception as e:
        print(f"Failed to mark email read: {e}")
        return False
    return True
async def mark_email_starred(email_id: str, user_id: str, is_starred: bool = True) -> bool:
    """Set (or clear, with is_starred=False) the star flag on a user's email.

    Returns True once the UPDATE was issued — even when no row matched —
    and False when the pool is unavailable or the statement fails.
    """
    pool = await get_pool()
    if pool is None:
        return False
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET is_starred = $3
                WHERE id = $1 AND user_id = $2
                """,
                email_id, user_id, is_starred
            )
    except Exception as e:
        print(f"Failed to mark email starred: {e}")
        return False
    return True

View File

@@ -0,0 +1,253 @@
"""
Mail Database - Connection Pool and Schema Initialization.
"""
import os
# Database Configuration - from Vault or environment (test default for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test")
# Flag to check if using test defaults
# (True only when DATABASE_URL was explicitly supplied, i.e. not the CI fallback)
_DB_CONFIGURED = DATABASE_URL != "postgresql://test:test@localhost:5432/test"
# Connection pool (shared with metrics_db)
# Lazily created by get_pool(); stays None when asyncpg is missing/unreachable.
_pool = None
async def get_pool():
    """Get or create database connection pool.

    Lazily builds a shared asyncpg pool (2–10 connections) on first call and
    caches it in the module-level ``_pool``. Returns None when asyncpg is not
    installed or the connection fails, so callers can degrade gracefully.

    NOTE(review): the lazy init is not guarded against concurrent first
    calls — two coroutines racing here could each create a pool; confirm
    whether that matters for this service's startup path.
    """
    global _pool
    if _pool is None:
        try:
            import asyncpg
            _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
        except ImportError:
            print("Warning: asyncpg not installed. Mail database disabled.")
            return None
        except Exception as e:
            print(f"Warning: Failed to connect to PostgreSQL: {e}")
            return None
    return _pool
async def init_mail_tables() -> bool:
    """Initialize mail tables in PostgreSQL.

    Idempotently creates (IF NOT EXISTS) the mail schema in one DDL script:
    external email accounts, aggregated emails, inbox tasks, email templates,
    the mail audit log, and sync-status tracking, plus all their indexes.

    Returns:
        True on success, False when the pool is unavailable or the DDL fails.
    """
    pool = await get_pool()
    if pool is None:
        return False
    create_tables_sql = """
    -- =============================================================================
    -- External Email Accounts
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS external_email_accounts (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,
        email VARCHAR(255) NOT NULL,
        display_name VARCHAR(255),
        account_type VARCHAR(50) DEFAULT 'personal',
        -- IMAP Settings (password stored in Vault)
        imap_host VARCHAR(255) NOT NULL,
        imap_port INTEGER DEFAULT 993,
        imap_ssl BOOLEAN DEFAULT TRUE,
        -- SMTP Settings
        smtp_host VARCHAR(255) NOT NULL,
        smtp_port INTEGER DEFAULT 465,
        smtp_ssl BOOLEAN DEFAULT TRUE,
        -- Vault path for credentials
        vault_path VARCHAR(500),
        -- Status tracking
        status VARCHAR(20) DEFAULT 'pending',
        last_sync TIMESTAMP,
        sync_error TEXT,
        email_count INTEGER DEFAULT 0,
        unread_count INTEGER DEFAULT 0,
        -- Timestamps
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW(),
        -- Constraints
        UNIQUE(user_id, email)
    );
    CREATE INDEX IF NOT EXISTS idx_mail_accounts_user ON external_email_accounts(user_id);
    CREATE INDEX IF NOT EXISTS idx_mail_accounts_tenant ON external_email_accounts(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_mail_accounts_status ON external_email_accounts(status);
    -- =============================================================================
    -- Aggregated Emails
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS aggregated_emails (
        id VARCHAR(36) PRIMARY KEY,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,
        -- Email identification
        message_id VARCHAR(500) NOT NULL,
        folder VARCHAR(100) DEFAULT 'INBOX',
        -- Email content
        subject TEXT,
        sender_email VARCHAR(255),
        sender_name VARCHAR(255),
        recipients JSONB DEFAULT '[]',
        cc JSONB DEFAULT '[]',
        body_preview TEXT,
        body_text TEXT,
        body_html TEXT,
        has_attachments BOOLEAN DEFAULT FALSE,
        attachments JSONB DEFAULT '[]',
        headers JSONB DEFAULT '{}',
        -- Status flags
        is_read BOOLEAN DEFAULT FALSE,
        is_starred BOOLEAN DEFAULT FALSE,
        is_deleted BOOLEAN DEFAULT FALSE,
        -- Dates
        date_sent TIMESTAMP,
        date_received TIMESTAMP,
        -- AI enrichment
        category VARCHAR(50),
        sender_type VARCHAR(50),
        sender_authority_name VARCHAR(255),
        detected_deadlines JSONB DEFAULT '[]',
        suggested_priority VARCHAR(20),
        ai_summary TEXT,
        ai_analyzed_at TIMESTAMP,
        created_at TIMESTAMP DEFAULT NOW(),
        -- Prevent duplicate imports
        UNIQUE(account_id, message_id)
    );
    CREATE INDEX IF NOT EXISTS idx_emails_account ON aggregated_emails(account_id);
    CREATE INDEX IF NOT EXISTS idx_emails_user ON aggregated_emails(user_id);
    CREATE INDEX IF NOT EXISTS idx_emails_tenant ON aggregated_emails(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_emails_date ON aggregated_emails(date_received DESC);
    CREATE INDEX IF NOT EXISTS idx_emails_category ON aggregated_emails(category);
    CREATE INDEX IF NOT EXISTS idx_emails_unread ON aggregated_emails(is_read) WHERE is_read = FALSE;
    CREATE INDEX IF NOT EXISTS idx_emails_starred ON aggregated_emails(is_starred) WHERE is_starred = TRUE;
    CREATE INDEX IF NOT EXISTS idx_emails_sender ON aggregated_emails(sender_email);
    -- =============================================================================
    -- Inbox Tasks (Arbeitsvorrat)
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS inbox_tasks (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,
        email_id VARCHAR(36) REFERENCES aggregated_emails(id) ON DELETE SET NULL,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE SET NULL,
        -- Task content
        title VARCHAR(500) NOT NULL,
        description TEXT,
        priority VARCHAR(20) DEFAULT 'medium',
        status VARCHAR(20) DEFAULT 'pending',
        deadline TIMESTAMP,
        -- Source information
        source_email_subject TEXT,
        source_sender VARCHAR(255),
        source_sender_type VARCHAR(50),
        -- AI extraction info
        ai_extracted BOOLEAN DEFAULT FALSE,
        confidence_score FLOAT,
        -- Completion tracking
        completed_at TIMESTAMP,
        reminder_at TIMESTAMP,
        -- Timestamps
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_tasks_user ON inbox_tasks(user_id);
    CREATE INDEX IF NOT EXISTS idx_tasks_tenant ON inbox_tasks(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_tasks_status ON inbox_tasks(status);
    CREATE INDEX IF NOT EXISTS idx_tasks_deadline ON inbox_tasks(deadline) WHERE deadline IS NOT NULL;
    CREATE INDEX IF NOT EXISTS idx_tasks_priority ON inbox_tasks(priority);
    CREATE INDEX IF NOT EXISTS idx_tasks_email ON inbox_tasks(email_id) WHERE email_id IS NOT NULL;
    -- =============================================================================
    -- Email Templates
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS email_templates (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36),  -- NULL for system templates
        tenant_id VARCHAR(36),
        name VARCHAR(255) NOT NULL,
        category VARCHAR(100),
        subject_template TEXT,
        body_template TEXT,
        variables JSONB DEFAULT '[]',
        is_system BOOLEAN DEFAULT FALSE,
        usage_count INTEGER DEFAULT 0,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_templates_user ON email_templates(user_id);
    CREATE INDEX IF NOT EXISTS idx_templates_tenant ON email_templates(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_templates_system ON email_templates(is_system);
    -- =============================================================================
    -- Mail Audit Log
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS mail_audit_log (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36),
        action VARCHAR(100) NOT NULL,
        entity_type VARCHAR(50),  -- account, email, task
        entity_id VARCHAR(36),
        details JSONB,
        ip_address VARCHAR(45),
        user_agent TEXT,
        created_at TIMESTAMP DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_mail_audit_user ON mail_audit_log(user_id);
    CREATE INDEX IF NOT EXISTS idx_mail_audit_created ON mail_audit_log(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_mail_audit_action ON mail_audit_log(action);
    -- =============================================================================
    -- Sync Status Tracking
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS mail_sync_status (
        id VARCHAR(36) PRIMARY KEY,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE,
        folder VARCHAR(100),
        last_uid INTEGER DEFAULT 0,
        last_sync TIMESTAMP,
        sync_errors INTEGER DEFAULT 0,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW(),
        UNIQUE(account_id, folder)
    );
    """
    try:
        async with pool.acquire() as conn:
            # One round-trip: asyncpg executes the whole multi-statement script.
            await conn.execute(create_tables_sql)
            print("Mail tables initialized successfully")
            return True
    except Exception as e:
        print(f"Failed to initialize mail tables: {e}")
        return False

View File

@@ -0,0 +1,118 @@
"""
Mail Database - Statistics and Audit Log Operations.
"""
import json
import uuid
from typing import Optional, Dict
from datetime import datetime
from .mail_db_pool import get_pool
async def get_mail_stats(user_id: str) -> Dict:
    """Get overall mail statistics for a user.

    Aggregates account health, email counts (total / unread / received
    today), today's AI-analysis activity, and task counts into one
    dashboard payload, plus a per-account breakdown.

    Returns:
        Stats dict as described above, or {} when the database pool is
        unavailable or any query fails.

    NOTE(review): "today" is computed from naive datetime.now() and
    compared against DB timestamps — confirm app server and database share
    a timezone.
    """
    pool = await get_pool()
    if pool is None:
        return {}
    try:
        async with pool.acquire() as conn:
            # Local midnight — lower bound for the "today" counters below.
            today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
            # Account stats
            accounts = await conn.fetch(
                """
                SELECT id, email, display_name, status, email_count, unread_count, last_sync
                FROM external_email_accounts
                WHERE user_id = $1
                """,
                user_id
            )
            # Email counts
            email_stats = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_emails,
                    COUNT(*) FILTER (WHERE is_read = FALSE) as unread_emails,
                    COUNT(*) FILTER (WHERE date_received >= $2) as emails_today,
                    COUNT(*) FILTER (WHERE ai_analyzed_at >= $2) as ai_analyses_today
                FROM aggregated_emails
                WHERE user_id = $1
                """,
                user_id, today
            )
            # Task counts
            task_stats = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_tasks,
                    COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline < NOW()) as overdue_tasks
                FROM inbox_tasks
                WHERE user_id = $1
                """,
                user_id
            )
            return {
                "total_accounts": len(accounts),
                "active_accounts": sum(1 for a in accounts if a['status'] == 'active'),
                "error_accounts": sum(1 for a in accounts if a['status'] == 'error'),
                "total_emails": email_stats['total_emails'] or 0,
                "unread_emails": email_stats['unread_emails'] or 0,
                "total_tasks": task_stats['total_tasks'] or 0,
                "pending_tasks": task_stats['pending_tasks'] or 0,
                "overdue_tasks": task_stats['overdue_tasks'] or 0,
                "emails_today": email_stats['emails_today'] or 0,
                "ai_analyses_today": email_stats['ai_analyses_today'] or 0,
                "per_account": [
                    {
                        "id": a['id'],
                        "email": a['email'],
                        "display_name": a['display_name'],
                        "status": a['status'],
                        "email_count": a['email_count'],
                        "unread_count": a['unread_count'],
                        # Timestamps are serialized to ISO strings for JSON.
                        "last_sync": a['last_sync'].isoformat() if a['last_sync'] else None,
                    }
                    for a in accounts
                ],
            }
    except Exception as e:
        print(f"Failed to get mail stats: {e}")
        return {}
async def log_mail_audit(
    user_id: str,
    action: str,
    entity_type: Optional[str] = None,
    entity_id: Optional[str] = None,
    details: Optional[Dict] = None,
    tenant_id: Optional[str] = None,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None,
) -> bool:
    """Append one entry to the mail audit trail.

    The details dict, when provided, is serialized to JSON for the JSONB
    column. Returns True on a successful insert, False when the pool is
    unavailable or the statement fails.
    """
    pool = await get_pool()
    if pool is None:
        return False
    serialized_details = json.dumps(details) if details else None
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO mail_audit_log
                (id, user_id, tenant_id, action, entity_type, entity_id, details, ip_address, user_agent)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                """,
                str(uuid.uuid4()), user_id, tenant_id, action, entity_type, entity_id,
                serialized_details, ip_address, user_agent
            )
    except Exception as e:
        print(f"Failed to log mail audit: {e}")
        return False
    return True

View File

@@ -0,0 +1,247 @@
"""
Mail Database - Inbox Task Operations.
"""
import uuid
from typing import Optional, List, Dict
from datetime import datetime, timedelta
from .mail_db_pool import get_pool
async def create_task(
    user_id: str,
    tenant_id: str,
    title: str,
    description: Optional[str] = None,
    priority: str = "medium",
    deadline: Optional[datetime] = None,
    email_id: Optional[str] = None,
    account_id: Optional[str] = None,
    source_email_subject: Optional[str] = None,
    source_sender: Optional[str] = None,
    source_sender_type: Optional[str] = None,
    ai_extracted: bool = False,
    confidence_score: Optional[float] = None,
) -> Optional[str]:
    """Create a new inbox task.

    Returns the freshly generated task id on success, or None when the
    database pool is unavailable or the INSERT fails.
    """
    pool = await get_pool()
    if pool is None:
        return None
    new_id = str(uuid.uuid4())
    row_values = (
        new_id, user_id, tenant_id, title, description, priority, deadline,
        email_id, account_id, source_email_subject, source_sender,
        source_sender_type, ai_extracted, confidence_score,
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO inbox_tasks
                (id, user_id, tenant_id, title, description, priority, deadline,
                 email_id, account_id, source_email_subject, source_sender,
                 source_sender_type, ai_extracted, confidence_score)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
                """,
                *row_values
            )
    except Exception as e:
        print(f"Failed to create task: {e}")
        return None
    return new_id
async def get_tasks(
    user_id: str,
    status: Optional[str] = None,
    priority: Optional[str] = None,
    include_completed: bool = False,
    limit: int = 50,
    offset: int = 0,
) -> List[Dict]:
    """Get tasks for a user.

    Filters dynamically on status/priority, hides completed tasks unless
    include_completed is set, and orders by priority (urgent → low), then
    nearest deadline (NULLs last), then newest creation time.

    Returns:
        Task rows as dicts; empty list on pool failure or query error.
    """
    pool = await get_pool()
    if pool is None:
        return []
    try:
        async with pool.acquire() as conn:
            # Conditions and positional params built in lockstep; param_idx
            # tracks the next $n placeholder number.
            conditions = ["user_id = $1"]
            params = [user_id]
            param_idx = 2
            if not include_completed:
                conditions.append("status != 'completed'")
            if status:
                conditions.append(f"status = ${param_idx}")
                params.append(status)
                param_idx += 1
            if priority:
                conditions.append(f"priority = ${param_idx}")
                params.append(priority)
                param_idx += 1
            where_clause = " AND ".join(conditions)
            # limit/offset take the final two placeholder slots.
            params.extend([limit, offset])
            query = f"""
                SELECT * FROM inbox_tasks
                WHERE {where_clause}
                ORDER BY
                    CASE priority
                        WHEN 'urgent' THEN 1
                        WHEN 'high' THEN 2
                        WHEN 'medium' THEN 3
                        WHEN 'low' THEN 4
                    END,
                    deadline ASC NULLS LAST,
                    created_at DESC
                LIMIT ${param_idx} OFFSET ${param_idx + 1}
            """
            rows = await conn.fetch(query, *params)
            return [dict(r) for r in rows]
    except Exception as e:
        print(f"Failed to get tasks: {e}")
        return []
async def get_task(task_id: str, user_id: str) -> Optional[Dict]:
    """Fetch one inbox task scoped to its owner.

    Returns the row as a dict, or None when the task does not exist, belongs
    to another user, or the database is unavailable.
    """
    pool = await get_pool()
    if pool is None:
        return None
    try:
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                "SELECT * FROM inbox_tasks WHERE id = $1 AND user_id = $2",
                task_id, user_id
            )
    except Exception as e:
        print(f"Failed to get task: {e}")
        return None
    return dict(record) if record else None
async def update_task(
    task_id: str,
    user_id: str,
    title: Optional[str] = None,
    description: Optional[str] = None,
    priority: Optional[str] = None,
    status: Optional[str] = None,
    deadline: Optional[datetime] = None,
) -> bool:
    """Update a task.

    Only fields passed as non-None are modified; updated_at is always
    refreshed. Setting status to "completed" additionally stamps
    completed_at = NOW().

    NOTE(review): since None means "leave unchanged", a deadline cannot be
    cleared through this function — confirm whether that is intended.

    Returns:
        True when the UPDATE was issued (even if no row matched), False on
        pool failure or query error.
    """
    pool = await get_pool()
    if pool is None:
        return False
    try:
        async with pool.acquire() as conn:
            # SET clauses and params built in lockstep; $1/$2 are reserved
            # for the WHERE key, so value placeholders start at $3.
            updates = ["updated_at = NOW()"]
            params = [task_id, user_id]
            param_idx = 3
            if title is not None:
                updates.append(f"title = ${param_idx}")
                params.append(title)
                param_idx += 1
            if description is not None:
                updates.append(f"description = ${param_idx}")
                params.append(description)
                param_idx += 1
            if priority is not None:
                updates.append(f"priority = ${param_idx}")
                params.append(priority)
                param_idx += 1
            if status is not None:
                updates.append(f"status = ${param_idx}")
                params.append(status)
                param_idx += 1
                if status == "completed":
                    updates.append("completed_at = NOW()")
            if deadline is not None:
                updates.append(f"deadline = ${param_idx}")
                params.append(deadline)
                param_idx += 1
            set_clause = ", ".join(updates)
            await conn.execute(
                f"UPDATE inbox_tasks SET {set_clause} WHERE id = $1 AND user_id = $2",
                *params
            )
            return True
    except Exception as e:
        print(f"Failed to update task: {e}")
        return False
async def get_task_dashboard_stats(user_id: str) -> Dict:
    """Get dashboard statistics for tasks.

    Computes overall counts by status, overdue/due-today/due-this-week
    buckets, and breakdowns of open tasks by priority and by sender type.

    Returns:
        Stats dict as described above, or {} on pool failure or query error.

    NOTE(review): bucket boundaries use naive datetime.now() compared with
    DB timestamps — confirm server and database share a timezone.
    """
    pool = await get_pool()
    if pool is None:
        return {}
    try:
        async with pool.acquire() as conn:
            now = datetime.now()
            # End of the current local day and a 7-day horizon.
            today_end = now.replace(hour=23, minute=59, second=59)
            week_end = now + timedelta(days=7)
            stats = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_tasks,
                    COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks,
                    COUNT(*) FILTER (WHERE status = 'in_progress') as in_progress_tasks,
                    COUNT(*) FILTER (WHERE status = 'completed') as completed_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline < $2) as overdue_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $3) as due_today,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $4) as due_this_week
                FROM inbox_tasks
                WHERE user_id = $1
                """,
                user_id, now, today_end, week_end
            )
            by_priority = await conn.fetch(
                """
                SELECT priority, COUNT(*) as count
                FROM inbox_tasks
                WHERE user_id = $1 AND status != 'completed'
                GROUP BY priority
                """,
                user_id
            )
            by_sender = await conn.fetch(
                """
                SELECT source_sender_type, COUNT(*) as count
                FROM inbox_tasks
                WHERE user_id = $1 AND status != 'completed' AND source_sender_type IS NOT NULL
                GROUP BY source_sender_type
                """,
                user_id
            )
            return {
                "total_tasks": stats['total_tasks'] or 0,
                "pending_tasks": stats['pending_tasks'] or 0,
                "in_progress_tasks": stats['in_progress_tasks'] or 0,
                "completed_tasks": stats['completed_tasks'] or 0,
                "overdue_tasks": stats['overdue_tasks'] or 0,
                "due_today": stats['due_today'] or 0,
                "due_this_week": stats['due_this_week'] or 0,
                "by_priority": {r['priority']: r['count'] for r in by_priority},
                "by_sender_type": {r['source_sender_type']: r['count'] for r in by_sender},
            }
    except Exception as e:
        print(f"Failed to get task stats: {e}")
        return {}

View File

@@ -0,0 +1,272 @@
"""
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
Extracted from ocr_pipeline_ocr_merge.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
from typing import List
logger = logging.getLogger(__name__)
def _split_paddle_multi_words(words: list) -> list:
"""Split PaddleOCR multi-word boxes into individual word boxes.
PaddleOCR often returns entire phrases as a single box, e.g.
"More than 200 singers took part in the" with one bounding box.
This splits them into individual words with proportional widths.
Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"])
and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]).
"""
import re
result = []
for w in words:
raw_text = w.get("text", "").strip()
if not raw_text:
continue
# Split on whitespace, before "[" (IPA), and after "!" before letter
tokens = re.split(
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
)
tokens = [t for t in tokens if t]
if len(tokens) <= 1:
result.append(w)
else:
# Split proportionally by character count
total_chars = sum(len(t) for t in tokens)
if total_chars == 0:
continue
n_gaps = len(tokens) - 1
gap_px = w["width"] * 0.02
usable_w = w["width"] - gap_px * n_gaps
cursor = w["left"]
for t in tokens:
token_w = max(1, usable_w * len(t) / total_chars)
result.append({
"text": t,
"left": round(cursor),
"top": w["top"],
"width": round(token_w),
"height": w["height"],
"conf": w.get("conf", 0),
})
cursor += token_w + gap_px
return result
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
"""Group words into rows by Y-position clustering.
Words whose vertical centers are within `row_gap` pixels are on the same row.
Returns list of rows, each row is a list of words sorted left-to-right.
"""
if not words:
return []
# Sort by vertical center
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
rows: list = []
current_row: list = [sorted_words[0]]
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
for w in sorted_words[1:]:
cy = w["top"] + w.get("height", 0) / 2
if abs(cy - current_cy) <= row_gap:
current_row.append(w)
else:
# Sort current row left-to-right before saving
rows.append(sorted(current_row, key=lambda w: w["left"]))
current_row = [w]
current_cy = cy
if current_row:
rows.append(sorted(current_row, key=lambda w: w["left"]))
return rows
def _row_center_y(row: list) -> float:
"""Average vertical center of a row of words."""
if not row:
return 0.0
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
    """Merge two word sequences from the same row using sequence alignment.

    Both sequences are sorted left-to-right. Walk through both simultaneously:
    - If words match (same/similar text): take Paddle text with averaged coords
    - If they don't match: the extra word is unique to one engine, include it

    Confidence defaults (Paddle 80, Tesseract 50) bias decisions toward the
    Paddle reading when a box carries no "conf" key.
    """
    merged = []
    pi, ti = 0, 0  # cursors into paddle_row / tess_row
    while pi < len(paddle_row) and ti < len(tess_row):
        pw = paddle_row[pi]
        tw = tess_row[ti]
        pt = pw.get("text", "").lower().strip()
        tt = tw.get("text", "").lower().strip()
        # Textual match: identical, or one multi-char token contains the other.
        is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
        # Spatial overlap check
        spatial_match = False
        if not is_same:
            # Different text but >=40% horizontal overlap (relative to the
            # narrower box): treat as the same word read differently.
            overlap_left = max(pw["left"], tw["left"])
            overlap_right = min(
                pw["left"] + pw.get("width", 0),
                tw["left"] + tw.get("width", 0),
            )
            overlap_w = max(0, overlap_right - overlap_left)
            min_w = min(pw.get("width", 1), tw.get("width", 1))
            if min_w > 0 and overlap_w / min_w >= 0.4:
                is_same = True
                spatial_match = True
        if is_same:
            pc = pw.get("conf", 80)
            tc = tw.get("conf", 50)
            total = pc + tc
            if total == 0:
                total = 1  # guard the weighted average against division by zero
            # Spatial-only matches pick the higher-confidence text; textual
            # matches always keep the Paddle reading.
            if spatial_match and pc < tc:
                best_text = tw["text"]
            else:
                best_text = pw["text"]
            # Coordinates become confidence-weighted averages of both boxes.
            merged.append({
                "text": best_text,
                "left": round((pw["left"] * pc + tw["left"] * tc) / total),
                "top": round((pw["top"] * pc + tw["top"] * tc) / total),
                "width": round((pw["width"] * pc + tw["width"] * tc) / total),
                "height": round((pw["height"] * pc + tw["height"] * tc) / total),
                "conf": max(pc, tc),
            })
            pi += 1
            ti += 1
        else:
            # Look ahead up to 3 words in each sequence to tell which engine
            # inserted an extra word the other one missed.
            paddle_ahead = any(
                tess_row[t].get("text", "").lower().strip() == pt
                for t in range(ti + 1, min(ti + 4, len(tess_row)))
            )
            tess_ahead = any(
                paddle_row[p].get("text", "").lower().strip() == tt
                for p in range(pi + 1, min(pi + 4, len(paddle_row)))
            )
            if paddle_ahead and not tess_ahead:
                # Tesseract has an extra word; keep it only when confident.
                if tw.get("conf", 0) >= 30:
                    merged.append(tw)
                ti += 1
            elif tess_ahead and not paddle_ahead:
                # Paddle has an extra word; Paddle words are kept unconditionally.
                merged.append(pw)
                pi += 1
            else:
                # Ambiguous: emit whichever word starts further left.
                if pw["left"] <= tw["left"]:
                    merged.append(pw)
                    pi += 1
                else:
                    if tw.get("conf", 0) >= 30:
                        merged.append(tw)
                    ti += 1
    # Flush whatever remains in either sequence (Tesseract leftovers are
    # still confidence-gated).
    while pi < len(paddle_row):
        merged.append(paddle_row[pi])
        pi += 1
    while ti < len(tess_row):
        tw = tess_row[ti]
        if tw.get("conf", 0) >= 30:
            merged.append(tw)
        ti += 1
    return merged
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment."""
if not paddle_words and not tess_words:
return []
if not paddle_words:
return [w for w in tess_words if w.get("conf", 0) >= 40]
if not tess_words:
return list(paddle_words)
paddle_rows = _group_words_into_rows(paddle_words)
tess_rows = _group_words_into_rows(tess_words)
used_tess_rows: set = set()
merged_all: list = []
for pr in paddle_rows:
pr_cy = _row_center_y(pr)
best_dist, best_tri = float("inf"), -1
for tri, tr in enumerate(tess_rows):
if tri in used_tess_rows:
continue
tr_cy = _row_center_y(tr)
dist = abs(pr_cy - tr_cy)
if dist < best_dist:
best_dist, best_tri = dist, tri
max_row_dist = max(
max((w.get("height", 20) for w in pr), default=20),
15,
)
if best_tri >= 0 and best_dist <= max_row_dist:
tr = tess_rows[best_tri]
used_tess_rows.add(best_tri)
merged_all.extend(_merge_row_sequences(pr, tr))
else:
merged_all.extend(pr)
for tri, tr in enumerate(tess_rows):
if tri not in used_tess_rows:
for tw in tr:
if tw.get("conf", 0) >= 40:
merged_all.append(tw)
return merged_all
def _deduplicate_words(words: list) -> list:
"""Remove duplicate words with same text at overlapping positions."""
if not words:
return words
result: list = []
for w in words:
wt = w.get("text", "").lower().strip()
if not wt:
continue
is_dup = False
w_right = w["left"] + w.get("width", 0)
w_bottom = w["top"] + w.get("height", 0)
for existing in result:
et = existing.get("text", "").lower().strip()
if wt != et:
continue
ox_l = max(w["left"], existing["left"])
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
ox = max(0, ox_r - ox_l)
min_w = min(w.get("width", 1), existing.get("width", 1))
if min_w <= 0 or ox / min_w < 0.5:
continue
oy_t = max(w["top"], existing["top"])
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
oy = max(0, oy_b - oy_t)
min_h = min(w.get("height", 1), existing.get("height", 1))
if min_h > 0 and oy / min_h >= 0.5:
is_dup = True
break
if not is_dup:
result.append(w)
removed = len(words) - len(result)
if removed:
logger.info("dedup: removed %d duplicate words", removed)
return result

View File

@@ -0,0 +1,209 @@
"""
OCR Pipeline LLM Review — LLM-based correction endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
from datetime import datetime
from typing import Dict, List
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
llm_review_entries,
llm_review_entries_streaming,
)
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_append_pipeline_log,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 8: LLM Review
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Loads the session's Step-5 word result, sends its vocab entries through
    the review LLM, stores the proposed changes under
    ``word_result["llm_review"]``, advances the session to step 9, and logs
    a "correction" pipeline event.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Body (optional JSON):
        model: override for the default review model (OLLAMA_REVIEW_MODEL).

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: Step 5 has not produced entries yet.
        HTTPException 502: the LLM call itself failed (non-streaming path).
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
    # Entries may live under either key depending on pipeline version.
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
    # Optional model override from request body
    body = {}
    try:
        body = await request.json()
    except Exception:
        # Missing/invalid JSON body is fine — fall back to the default model.
        pass
    model = body.get("model") or OLLAMA_REVIEW_MODEL
    if stream:
        # Streaming path: the SSE generator persists results itself once the
        # "complete" event arrives.
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )
    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)
    # Keep the in-memory cache coherent with the DB copy.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")
    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])
    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Emits one ``data:`` frame per event produced by
    ``llm_review_entries_streaming``.  When the terminal "complete" event
    arrives, the review result is persisted to the session DB (and mirrored
    into the in-memory cache) *after* that frame has been yielded to the
    client.  Any exception is reported to the client as a final
    {"type": "error"} frame instead of propagating.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            # Stop producing as soon as the client goes away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return
            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
            # On complete: persist to DB
            if event.get("type") == "complete":
                word_result["llm_review"] = {
                    "changes": event["changes"],
                    "model_used": event["model_used"],
                    "duration_ms": event["duration_ms"],
                    "entries_corrected": event["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=9)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result
                logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                            f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply selected LLM corrections to vocab entries.

    The JSON body is an object with "accepted_indices": a list of indices
    into the previously stored ``llm_review.changes`` list.  Only accepted
    changes are written back into the entries.

    Raises:
        HTTPException: 404 for an unknown session, 400 when no word result
            or no LLM review exists yet.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
    body = await request.json()
    # A non-object JSON body (null, list, ...) would crash on .get() —
    # treat it as "nothing accepted" instead of returning a 500.
    if not isinstance(body, dict):
        body = {}
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]
    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    # Build a lookup: (row_index, field) -> new_value for accepted changes
    corrections = {}
    applied_count = 0
    for idx, change in enumerate(changes):
        if idx in accepted_indices:
            key = (change["row_index"], change["field"])
            corrections[key] = change["new"]
            applied_count += 1
    # Apply corrections to entries
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            key = (row_idx, field_name)
            if key in corrections:
                entry[field_name] = corrections[key]
                entry["llm_corrected"] = True
    # Update word_result — both entry-key spellings are kept in sync.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
    await update_session_db(session_id, word_result=word_result)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }

View File

@@ -1,10 +1,8 @@
"""
OCR Merge Helpers and Kombi Endpoints.
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
Contains merge helper functions for combining PaddleOCR/RapidOCR with Tesseract
results, plus the paddle-kombi and rapid-kombi endpoints.
Extracted from ocr_pipeline_api.py for modularity.
Merge helper functions live in ocr_merge_helpers.py.
This module re-exports them for backward compatibility.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
@@ -12,10 +10,8 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
import logging
import time
from typing import Any, Dict, List
import cv2
import httpx
import numpy as np
from fastapi import APIRouter, HTTPException
@@ -23,356 +19,23 @@ from cv_words_first import build_grid_from_words
from ocr_pipeline_common import _cache, _append_pipeline_log
from ocr_pipeline_session_store import get_session_image, update_session_db
# Re-export merge helpers for backward compatibility
from ocr_merge_helpers import ( # noqa: F401
_split_paddle_multi_words,
_group_words_into_rows,
_row_center_y,
_merge_row_sequences,
_merge_paddle_tesseract,
_deduplicate_words,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Merge helper functions
# ---------------------------------------------------------------------------
def _split_paddle_multi_words(words: list) -> list:
"""Split PaddleOCR multi-word boxes into individual word boxes.
PaddleOCR often returns entire phrases as a single box, e.g.
"More than 200 singers took part in the" with one bounding box.
This splits them into individual words with proportional widths.
Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"])
and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]).
"""
import re
result = []
for w in words:
raw_text = w.get("text", "").strip()
if not raw_text:
continue
# Split on whitespace, before "[" (IPA), and after "!" before letter
tokens = re.split(
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
)
tokens = [t for t in tokens if t]
if len(tokens) <= 1:
result.append(w)
else:
# Split proportionally by character count
total_chars = sum(len(t) for t in tokens)
if total_chars == 0:
continue
n_gaps = len(tokens) - 1
gap_px = w["width"] * 0.02
usable_w = w["width"] - gap_px * n_gaps
cursor = w["left"]
for t in tokens:
token_w = max(1, usable_w * len(t) / total_chars)
result.append({
"text": t,
"left": round(cursor),
"top": w["top"],
"width": round(token_w),
"height": w["height"],
"conf": w.get("conf", 0),
})
cursor += token_w + gap_px
return result
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
"""Group words into rows by Y-position clustering.
Words whose vertical centers are within `row_gap` pixels are on the same row.
Returns list of rows, each row is a list of words sorted left-to-right.
"""
if not words:
return []
# Sort by vertical center
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
rows: list = []
current_row: list = [sorted_words[0]]
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
for w in sorted_words[1:]:
cy = w["top"] + w.get("height", 0) / 2
if abs(cy - current_cy) <= row_gap:
current_row.append(w)
else:
# Sort current row left-to-right before saving
rows.append(sorted(current_row, key=lambda w: w["left"]))
current_row = [w]
current_cy = cy
if current_row:
rows.append(sorted(current_row, key=lambda w: w["left"]))
return rows
def _row_center_y(row: list) -> float:
"""Average vertical center of a row of words."""
if not row:
return 0.0
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
    """Merge two word sequences from the same row using sequence alignment.

    Both sequences are sorted left-to-right. Walk through both simultaneously:
    - If words match (same/similar text, or >= 40% horizontal overlap): emit
      one word with confidence-weighted averaged coordinates
    - If they don't match: the extra word is unique to one engine, include it

    This prevents duplicates because both engines produce words in the same order.

    Args:
        paddle_row: PaddleOCR words of one visual row, sorted left-to-right.
        tess_row: Tesseract words of the same row, sorted left-to-right.

    Returns:
        Merged word dicts (text/left/top/width/height/conf) for the row.
    """
    merged = []
    pi, ti = 0, 0
    while pi < len(paddle_row) and ti < len(tess_row):
        pw = paddle_row[pi]
        tw = tess_row[ti]
        # Check if these are the same word
        pt = pw.get("text", "").lower().strip()
        tt = tw.get("text", "").lower().strip()
        # Same text or one contains the other
        is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
        # Spatial overlap check: if words overlap >= 40% horizontally,
        # they're the same physical word regardless of OCR text differences.
        # (40% catches borderline cases like "Stick"/"Stück" at 48% overlap)
        spatial_match = False
        if not is_same:
            overlap_left = max(pw["left"], tw["left"])
            overlap_right = min(
                pw["left"] + pw.get("width", 0),
                tw["left"] + tw.get("width", 0),
            )
            overlap_w = max(0, overlap_right - overlap_left)
            min_w = min(pw.get("width", 1), tw.get("width", 1))
            if min_w > 0 and overlap_w / min_w >= 0.4:
                is_same = True
                spatial_match = True
        if is_same:
            # Matched — average coordinates weighted by confidence
            pc = pw.get("conf", 80)
            tc = tw.get("conf", 50)
            total = pc + tc
            if total == 0:
                total = 1
            # Text: prefer higher-confidence engine when texts differ
            # (e.g. Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80)
            if spatial_match and pc < tc:
                best_text = tw["text"]
            else:
                best_text = pw["text"]
            merged.append({
                "text": best_text,
                "left": round((pw["left"] * pc + tw["left"] * tc) / total),
                "top": round((pw["top"] * pc + tw["top"] * tc) / total),
                "width": round((pw["width"] * pc + tw["width"] * tc) / total),
                "height": round((pw["height"] * pc + tw["height"] * tc) / total),
                "conf": max(pc, tc),
            })
            pi += 1
            ti += 1
        else:
            # Different text — one engine found something extra.
            # Look ahead (up to 3 words): is the current Paddle word somewhere
            # in Tesseract ahead?
            paddle_ahead = any(
                tess_row[t].get("text", "").lower().strip() == pt
                for t in range(ti + 1, min(ti + 4, len(tess_row)))
            )
            # Is the current Tesseract word somewhere in Paddle ahead?
            tess_ahead = any(
                paddle_row[p].get("text", "").lower().strip() == tt
                for p in range(pi + 1, min(pi + 4, len(paddle_row)))
            )
            if paddle_ahead and not tess_ahead:
                # Tesseract has an extra word (e.g. "!" or bullet) → include it
                if tw.get("conf", 0) >= 30:
                    merged.append(tw)
                ti += 1
            elif tess_ahead and not paddle_ahead:
                # Paddle has an extra word → include it
                merged.append(pw)
                pi += 1
            else:
                # Both have unique words or neither found ahead → take leftmost first
                if pw["left"] <= tw["left"]:
                    merged.append(pw)
                    pi += 1
                else:
                    if tw.get("conf", 0) >= 30:
                        merged.append(tw)
                    ti += 1
    # Remaining words from either engine; Tesseract leftovers still need conf >= 30
    while pi < len(paddle_row):
        merged.append(paddle_row[pi])
        pi += 1
    while ti < len(tess_row):
        tw = tess_row[ti]
        if tw.get("conf", 0) >= 30:
            merged.append(tw)
        ti += 1
    return merged
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
    """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.

    Strategy:
    1. Group each engine's words into rows (by Y-position clustering)
    2. Match rows between engines (by vertical center proximity)
    3. Within each matched row: merge sequences left-to-right, deduplicating
       words that appear in both engines at the same sequence position
    4. Unmatched rows from either engine: keep as-is (Tesseract-only words
       require conf >= 40)

    This prevents:
    - Cross-line averaging (words from different lines being merged)
    - Duplicate words (same word from both engines shown twice)

    Args:
        paddle_words: Word dicts from PaddleOCR (text/left/top/width/height/conf).
        tess_words: Word dicts from Tesseract, same shape.

    Returns:
        Merged word dicts for the whole page.
    """
    if not paddle_words and not tess_words:
        return []
    # Single-engine fallbacks: Tesseract-only output is confidence-filtered,
    # Paddle-only output is taken verbatim.
    if not paddle_words:
        return [w for w in tess_words if w.get("conf", 0) >= 40]
    if not tess_words:
        return list(paddle_words)
    # Step 1: Group into rows
    paddle_rows = _group_words_into_rows(paddle_words)
    tess_rows = _group_words_into_rows(tess_words)
    # Step 2: Match rows between engines by vertical center proximity
    used_tess_rows: set = set()
    merged_all: list = []
    for pr in paddle_rows:
        pr_cy = _row_center_y(pr)
        best_dist, best_tri = float("inf"), -1
        for tri, tr in enumerate(tess_rows):
            if tri in used_tess_rows:
                continue
            tr_cy = _row_center_y(tr)
            dist = abs(pr_cy - tr_cy)
            if dist < best_dist:
                best_dist, best_tri = dist, tri
        # Row match threshold: the tallest word height in this Paddle row,
        # but at least 15 px.
        max_row_dist = max(
            max((w.get("height", 20) for w in pr), default=20),
            15,
        )
        if best_tri >= 0 and best_dist <= max_row_dist:
            # Matched row — merge sequences
            tr = tess_rows[best_tri]
            used_tess_rows.add(best_tri)
            merged_all.extend(_merge_row_sequences(pr, tr))
        else:
            # No matching Tesseract row — keep Paddle row as-is
            merged_all.extend(pr)
    # Add unmatched Tesseract rows
    for tri, tr in enumerate(tess_rows):
        if tri not in used_tess_rows:
            for tw in tr:
                if tw.get("conf", 0) >= 40:
                    merged_all.append(tw)
    return merged_all
def _deduplicate_words(words: list) -> list:
"""Remove duplicate words with same text at overlapping positions.
PaddleOCR can return overlapping phrases (e.g. "von jm." and "jm. =")
that produce duplicate words after splitting. This pass removes them.
A word is a duplicate only when BOTH horizontal AND vertical overlap
exceed 50% — same text on the same visual line at the same position.
"""
if not words:
return words
result: list = []
for w in words:
wt = w.get("text", "").lower().strip()
if not wt:
continue
is_dup = False
w_right = w["left"] + w.get("width", 0)
w_bottom = w["top"] + w.get("height", 0)
for existing in result:
et = existing.get("text", "").lower().strip()
if wt != et:
continue
# Horizontal overlap
ox_l = max(w["left"], existing["left"])
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
ox = max(0, ox_r - ox_l)
min_w = min(w.get("width", 1), existing.get("width", 1))
if min_w <= 0 or ox / min_w < 0.5:
continue
# Vertical overlap — must also be on the same line
oy_t = max(w["top"], existing["top"])
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
oy = max(0, oy_b - oy_t)
min_h = min(w.get("height", 1), existing.get("height", 1))
if min_h > 0 and oy / min_h >= 0.5:
is_dup = True
break
if not is_dup:
result.append(w)
removed = len(words) - len(result)
if removed:
logger.info("dedup: removed %d duplicate words", removed)
return result
# ---------------------------------------------------------------------------
# Kombi endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results.
Both engines run on the same preprocessed (cropped/dewarped) image.
Word boxes are matched by IoU and coordinates are averaged weighted by
confidence. Unmatched Tesseract words (bullets, symbols) are added.
"""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode image")
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
# --- PaddleOCR ---
paddle_words = await ocr_region_paddle(img_bgr, region=None)
if not paddle_words:
paddle_words = []
# --- Tesseract ---
def _run_tesseract_words(img_bgr) -> list:
"""Run Tesseract OCR on an image and return word dicts."""
from PIL import Image
import pytesseract
@@ -397,15 +60,98 @@ async def paddle_kombi(session_id: str):
"height": data["height"][i],
"conf": conf,
})
return tess_words
def _build_kombi_word_result(
cells: list,
columns_meta: list,
img_w: int,
img_h: int,
duration: float,
engine_name: str,
raw_engine_words: list,
raw_engine_words_split: list,
tess_words: list,
merged_words: list,
raw_engine_key: str = "raw_paddle_words",
raw_split_key: str = "raw_paddle_words_split",
) -> dict:
"""Build the word_result dict for kombi endpoints."""
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
return {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": engine_name,
"grid_method": engine_name,
raw_engine_key: raw_engine_words,
raw_split_key: raw_engine_words_split,
"raw_tesseract_words": tess_words,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words),
raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
},
}
async def _load_session_image(session_id: str):
    """Fetch the best available preprocessed image for a kombi endpoint.

    Prefers the cropped image, then the dewarped one, then the original
    upload.  Returns a (png_bytes, bgr_ndarray) tuple.

    Raises:
        HTTPException: 404 when the session has no stored image, 400 when
            the stored bytes cannot be decoded.
    """
    img_png = None
    for kind in ("cropped", "dewarped", "original"):
        img_png = await get_session_image(session_id, kind)
        if img_png:
            break
    if not img_png:
        raise HTTPException(status_code=404, detail="No image found for this session")
    decoded = cv2.imdecode(np.frombuffer(img_png, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")
    return img_png, decoded
# ---------------------------------------------------------------------------
# Kombi endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/paddle-kombi")
async def paddle_kombi(session_id: str):
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_paddle
t0 = time.time()
paddle_words = await ocr_region_paddle(img_bgr, region=None)
if not paddle_words:
paddle_words = []
tess_words = _run_tesseract_words(img_bgr)
# --- Split multi-word Paddle boxes into individual words ---
paddle_words_split = _split_paddle_multi_words(paddle_words)
logger.info(
"paddle_kombi: split %d paddle boxes %d individual words",
"paddle_kombi: split %d paddle boxes -> %d individual words",
len(paddle_words), len(paddle_words_split),
)
# --- Merge ---
if not paddle_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
@@ -418,49 +164,23 @@ async def paddle_kombi(session_id: str):
for cell in cells:
cell["ocr_engine"] = "kombi"
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": "kombi",
"grid_method": "kombi",
"raw_paddle_words": paddle_words,
"raw_paddle_words_split": paddle_words_split,
"raw_tesseract_words": tess_words,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
"paddle_words": len(paddle_words),
"paddle_words_split": len(paddle_words_split),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
},
}
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "kombi",
paddle_words, paddle_words_split, tess_words, merged_words,
"raw_paddle_words", "raw_paddle_words_split",
)
await update_session_db(
session_id,
word_result=word_result,
cropped_png=img_png,
current_step=8,
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
# Update in-memory cache so detect-structure can access word_result
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[paddle=%d, tess=%d, merged=%d]",
session_id, len(cells), n_rows, n_cols, duration,
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(paddle_words), len(tess_words), len(merged_words),
)
@@ -478,24 +198,8 @@ async def paddle_kombi(session_id: str):
@router.post("/sessions/{session_id}/rapid-kombi")
async def rapid_kombi(session_id: str):
"""Run RapidOCR + Tesseract on the preprocessed image and merge results.
Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime)
instead of remote PaddleOCR service.
"""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
if not img_png:
img_png = await get_session_image(session_id, "original")
if not img_png:
raise HTTPException(status_code=404, detail="No image found for this session")
img_arr = np.frombuffer(img_png, dtype=np.uint8)
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
if img_bgr is None:
raise HTTPException(status_code=400, detail="Failed to decode image")
"""Run RapidOCR + Tesseract on the preprocessed image and merge results."""
img_png, img_bgr = await _load_session_image(session_id)
img_h, img_w = img_bgr.shape[:2]
from cv_ocr_engines import ocr_region_rapid
@@ -503,7 +207,6 @@ async def rapid_kombi(session_id: str):
t0 = time.time()
# --- RapidOCR (local, synchronous) ---
full_region = PageRegion(
type="full_page", x=0, y=0, width=img_w, height=img_h,
)
@@ -511,40 +214,14 @@ async def rapid_kombi(session_id: str):
if not rapid_words:
rapid_words = []
# --- Tesseract ---
from PIL import Image
import pytesseract
tess_words = _run_tesseract_words(img_bgr)
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
data = pytesseract.image_to_data(
pil_img, lang="eng+deu",
config="--psm 6 --oem 3",
output_type=pytesseract.Output.DICT,
)
tess_words = []
for i in range(len(data["text"])):
text = str(data["text"][i]).strip()
conf_raw = str(data["conf"][i])
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
if not text or conf < 20:
continue
tess_words.append({
"text": text,
"left": data["left"][i],
"top": data["top"][i],
"width": data["width"][i],
"height": data["height"][i],
"conf": conf,
})
# --- Split multi-word RapidOCR boxes into individual words ---
rapid_words_split = _split_paddle_multi_words(rapid_words)
logger.info(
"rapid_kombi: split %d rapid boxes %d individual words",
"rapid_kombi: split %d rapid boxes -> %d individual words",
len(rapid_words), len(rapid_words_split),
)
# --- Merge ---
if not rapid_words_split and not tess_words:
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
@@ -557,49 +234,23 @@ async def rapid_kombi(session_id: str):
for cell in cells:
cell["ocr_engine"] = "rapid_kombi"
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
n_cols = len(columns_meta)
col_types = {c.get("type") for c in columns_meta}
is_vocab = bool(col_types & {"column_en", "column_de"})
word_result = {
"cells": cells,
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
"image_height": img_h,
"duration_seconds": round(duration, 2),
"ocr_engine": "rapid_kombi",
"grid_method": "rapid_kombi",
"raw_rapid_words": rapid_words,
"raw_rapid_words_split": rapid_words_split,
"raw_tesseract_words": tess_words,
"summary": {
"total_cells": len(cells),
"non_empty_cells": sum(1 for c in cells if c.get("text")),
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
"rapid_words": len(rapid_words),
"rapid_words_split": len(rapid_words_split),
"tesseract_words": len(tess_words),
"merged_words": len(merged_words),
},
}
word_result = _build_kombi_word_result(
cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
rapid_words, rapid_words_split, tess_words, merged_words,
"raw_rapid_words", "raw_rapid_words_split",
)
await update_session_db(
session_id,
word_result=word_result,
cropped_png=img_png,
current_step=8,
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
)
# Update in-memory cache so detect-structure can access word_result
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
logger.info(
"rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
"[rapid=%d, tess=%d, merged=%d]",
session_id, len(cells), n_rows, n_cols, duration,
session_id, len(cells), word_result["grid_shape"]["rows"],
word_result["grid_shape"]["cols"], duration,
len(rapid_words), len(tess_words), len(merged_words),
)

View File

@@ -1,929 +1,26 @@
"""
OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation,
image detection/generation, and handwriting removal endpoints.
OCR Pipeline Postprocessing API — composite router assembling LLM review,
reconstruction, export, validation, image detection/generation, and
handwriting removal endpoints.
Extracted from ocr_pipeline_api.py to keep the main module manageable.
Split into sub-modules:
ocr_pipeline_llm_review — LLM review + apply corrections
ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
ocr_pipeline_validation — image detection, generation, validation, handwriting removal
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from cv_vocab_pipeline import (
OLLAMA_REVIEW_MODEL,
llm_review_entries,
llm_review_entries_streaming,
)
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
get_sub_sessions,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
# Prompt suffixes for image generation, keyed by style name (the default of
# GenerateImageRequest.style is "educational"); presumably appended to the
# user's prompt at the generation endpoint — confirm at usage site.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
    """Request body for a validation endpoint; both fields are optional."""
    notes: Optional[str] = None  # free-text reviewer remarks
    score: Optional[int] = None  # numeric rating; range not enforced here — TODO confirm at the endpoint
class GenerateImageRequest(BaseModel):
    """Request body for generating an image for a region of the page."""
    region_index: int  # index of the target region — presumably into the session's detected regions; confirm at the endpoint
    prompt: str  # user prompt for the image generator
    style: str = "educational"  # style key; STYLE_SUFFIXES holds the matching prompt suffixes
# ---------------------------------------------------------------------------
# Step 8: LLM Review
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    The request body may optionally be a JSON object with a "model" key
    overriding the default review model.

    Raises:
        HTTPException: 404 for an unknown session, 400 when Step 5 has not
            produced vocab entries yet, 502 when the LLM call fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
    # Optional model override from request body
    # NOTE(review): a non-object JSON body (null/list) would make body.get()
    # raise AttributeError — confirm clients always send an object.
    body = {}
    try:
        body = await request.json()
    except Exception:
        pass
    model = body.get("model") or OLLAMA_REVIEW_MODEL
    if stream:
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )
    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
    # Store result inside word_result as a sub-key
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)
    # Keep the in-memory cache coherent with the persisted session.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")
    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])
    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Emits one ``data:`` frame per event produced by
    ``llm_review_entries_streaming``.  When the terminal "complete" event
    arrives, the review result is persisted to the session DB (and mirrored
    into the in-memory cache) *after* that frame has been yielded to the
    client.  Any exception is reported to the client as a final
    {"type": "error"} frame instead of propagating.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            # Stop producing as soon as the client goes away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return
            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
            # On complete: persist to DB
            if event.get("type") == "complete":
                word_result["llm_review"] = {
                    "changes": event["changes"],
                    "model_used": event["model_used"],
                    "duration_ms": event["duration_ms"],
                    "entries_corrected": event["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=9)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result
                logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                            f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")
    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply selected LLM corrections to vocab entries.

    The JSON body carries "accepted_indices": a list of indices into the
    previously stored ``llm_review.changes`` list.  Only accepted changes
    are written back into the entries.

    Raises:
        HTTPException: 404 for an unknown session, 400 when no word result
            or no LLM review exists yet.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
    # NOTE(review): a non-object JSON body (null/list) would make body.get()
    # raise AttributeError — confirm clients always send an object.
    body = await request.json()
    accepted_indices = set(body.get("accepted_indices", []))  # indices into changes[]
    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    # Build a lookup: (row_index, field) -> new_value for accepted changes
    corrections = {}
    applied_count = 0
    for idx, change in enumerate(changes):
        if idx in accepted_indices:
            key = (change["row_index"], change["field"])
            corrections[key] = change["new"]
            applied_count += 1
    # Apply corrections to entries
    for entry in entries:
        row_idx = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            key = (row_idx, field_name)
            if key in corrections:
                entry[field_name] = corrections[key]
                entry["llm_corrected"] = True
    # Update word_result — both entry-key spellings are kept in sync.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
    await update_session_db(session_id, word_result=word_result)
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result
    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
"""Save edited cell texts from reconstruction step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
body = await request.json()
cell_updates = body.get("cells", [])
if not cell_updates:
await update_session_db(session_id, current_step=10)
return {"session_id": session_id, "updated": 0}
# Build update map: cell_id -> new text
update_map = {c["cell_id"]: c["text"] for c in cell_updates}
# Separate sub-session updates (cell_ids prefixed with "box{N}_")
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
main_updates: Dict[str, str] = {}
for cell_id, text in update_map.items():
m = re.match(r'^box(\d+)_(.+)$', cell_id)
if m:
bi = int(m.group(1))
original_id = m.group(2)
sub_updates.setdefault(bi, {})[original_id] = text
else:
main_updates[cell_id] = text
# Update main session cells
cells = word_result.get("cells", [])
updated_count = 0
for cell in cells:
if cell["cell_id"] in main_updates:
cell["text"] = main_updates[cell["cell_id"]]
cell["status"] = "edited"
updated_count += 1
word_result["cells"] = cells
# Also update vocab_entries if present
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if entries:
# Map cell_id pattern "R{row}_C{col}" to entry fields
for entry in entries:
row_idx = entry.get("row_index", -1)
# Check each field's cell
for col_idx, field_name in enumerate(["english", "german", "example"]):
cell_id = f"R{row_idx:02d}_C{col_idx}"
# Also try without zero-padding
cell_id_alt = f"R{row_idx}_C{col_idx}"
new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
if new_text is not None:
entry[field_name] = new_text
word_result["vocab_entries"] = entries
if "entries" in word_result:
word_result["entries"] = entries
await update_session_db(session_id, word_result=word_result, current_step=10)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
# Route sub-session updates
sub_updated = 0
if sub_updates:
subs = await get_sub_sessions(session_id)
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
for bi, updates in sub_updates.items():
sub_id = sub_by_index.get(bi)
if not sub_id:
continue
sub_session = await get_session_db(sub_id)
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word:
continue
sub_cells = sub_word.get("cells", [])
for cell in sub_cells:
if cell["cell_id"] in updates:
cell["text"] = updates[cell["cell_id"]]
cell["status"] = "edited"
sub_updated += 1
sub_word["cells"] = sub_cells
await update_session_db(sub_id, word_result=sub_word)
if sub_id in _cache:
_cache[sub_id]["word_result"] = sub_word
total_updated = updated_count + sub_updated
logger.info(f"Reconstruction saved for session {session_id}: "
f"{updated_count} main + {sub_updated} sub-session cells updated")
return {
"session_id": session_id,
"updated": total_updated,
"main_updated": updated_count,
"sub_updated": sub_updated,
}
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor.
If the session has sub-sessions (box regions), their cells are merged
into the result at the correct Y positions.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = list(word_result.get("cells", []))
img_w = word_result.get("image_width", 800)
img_h = word_result.get("image_height", 600)
# Merge sub-session cells at box positions
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word or not sub_word.get("cells"):
continue
bi = sub.get("box_index", 0)
if bi < len(box_zones):
box = box_zones[bi]["box"]
box_y, box_x = box["y"], box["x"]
else:
box_y, box_x = 0, 0
# Offset sub-session cells to absolute page coordinates
for cell in sub_word["cells"]:
cell_copy = dict(cell)
# Prefix cell_id with box index
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
cell_copy["source"] = f"box_{bi}"
# Offset bbox_px
bbox = cell_copy.get("bbox_px", {})
if bbox:
bbox = dict(bbox)
bbox["x"] = bbox.get("x", 0) + box_x
bbox["y"] = bbox.get("y", 0) + box_y
cell_copy["bbox_px"] = bbox
cells.append(cell_copy)
from services.layout_reconstruction_service import cells_to_fabric_json
fabric_json = cells_to_fabric_json(cells, img_w, img_h)
return fabric_json
# ---------------------------------------------------------------------------
# Vocab entries merged + PDF/DOCX export
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
"""Return vocab entries from main session + all sub-sessions, sorted by Y position."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result") or {}
entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
# Tag main entries
for e in entries:
e.setdefault("source", "main")
# Merge sub-session entries
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result") or {}
sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
bi = sub.get("box_index", 0)
box_y = 0
if bi < len(box_zones):
box_y = box_zones[bi]["box"]["y"]
for e in sub_entries:
e_copy = dict(e)
e_copy["source"] = f"box_{bi}"
e_copy["source_y"] = box_y # for sorting
entries.append(e_copy)
# Sort by approximate Y position
def _sort_key(e):
if e.get("source", "main") == "main":
return e.get("row_index", 0) * 100 # main entries by row index
return e.get("source_y", 0) * 100 + e.get("row_index", 0)
entries.sort(key=_sort_key)
return {
"session_id": session_id,
"entries": entries,
"total": len(entries),
"sources": list(set(e.get("source", "main") for e in entries)),
}
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
"""Export the reconstructed cell grid as a PDF table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
# Build table data: rows x columns
table_data: list[list[str]] = []
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table_data.append(header)
for r in range(n_rows):
row_texts = []
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
row_texts.append(cell.get("text", "") if cell else "")
table_data.append(row_texts)
# Generate PDF with reportlab
try:
from reportlab.lib.pagesizes import A4
from reportlab.lib import colors
from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
import io as _io
buf = _io.BytesIO()
doc = SimpleDocTemplate(buf, pagesize=A4)
if not table_data or not table_data[0]:
raise HTTPException(status_code=400, detail="No data to export")
t = Table(table_data)
t.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTSIZE', (0, 0), (-1, -1), 9),
('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
('VALIGN', (0, 0), (-1, -1), 'TOP'),
('WORDWRAP', (0, 0), (-1, -1), True),
]))
doc.build([t])
buf.seek(0)
from fastapi.responses import StreamingResponse
return StreamingResponse(
buf,
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="reportlab not installed")
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
"""Export the reconstructed cell grid as a DOCX table."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = word_result.get("cells", [])
columns_used = word_result.get("columns_used", [])
grid_shape = word_result.get("grid_shape", {})
n_rows = grid_shape.get("rows", 0)
n_cols = grid_shape.get("cols", 0)
try:
from docx import Document
from docx.shared import Pt
import io as _io
doc = Document()
doc.add_heading(f'Rekonstruktion Session {session_id[:8]}', level=1)
# Build header
header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
if not header:
header = [f"Col {i}" for i in range(n_cols)]
table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1))
table.style = 'Table Grid'
# Header row
for ci, h in enumerate(header):
table.rows[0].cells[ci].text = h
# Data rows
for r in range(n_rows):
for ci in range(n_cols):
cell_id = f"R{r:02d}_C{ci}"
cell = next((c for c in cells if c.get("cell_id") == cell_id), None)
table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
buf = _io.BytesIO()
doc.save(buf)
buf.seek(0)
from fastapi.responses import StreamingResponse
return StreamingResponse(
buf,
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
)
except ImportError:
raise HTTPException(status_code=501, detail="python-docx not installed")
# ---------------------------------------------------------------------------
# Step 8: Validation — Original vs. Reconstruction
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
"""Detect illustration/image regions in the original scan using VLM.
Sends the original image to qwen2.5vl to find non-text, non-table
image areas, returning bounding boxes (in %) and descriptions.
"""
import base64
import httpx
import re
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# Get original image bytes
original_png = await get_session_image(session_id, "original")
if not original_png:
raise HTTPException(status_code=400, detail="No original image found")
# Build context from vocab entries for richer descriptions
word_result = session.get("word_result") or {}
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
vocab_context = ""
if entries:
sample = entries[:10]
words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
if words:
vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
prompt = (
"Analyze this scanned page. Find ALL illustration/image/picture regions "
"(NOT text, NOT table cells, NOT blank areas). "
"For each image region found, return its bounding box as percentage of page dimensions "
"and a short English description of what the image shows. "
"Reply with ONLY a JSON array like: "
'[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
"where x, y, w, h are percentages (0-100) of the page width/height. "
"If there are NO images on the page, return an empty array: []"
f"{vocab_context}"
)
img_b64 = base64.b64encode(original_png).decode("utf-8")
payload = {
"model": model,
"prompt": prompt,
"images": [img_b64],
"stream": False,
}
try:
async with httpx.AsyncClient(timeout=120.0) as client:
resp = await client.post(f"{ollama_base}/api/generate", json=payload)
resp.raise_for_status()
text = resp.json().get("response", "")
# Parse JSON array from response
match = re.search(r'\[.*?\]', text, re.DOTALL)
if match:
raw_regions = json.loads(match.group(0))
else:
raw_regions = []
# Normalize to ImageRegion format
regions = []
for r in raw_regions:
regions.append({
"bbox_pct": {
"x": max(0, min(100, float(r.get("x", 0)))),
"y": max(0, min(100, float(r.get("y", 0)))),
"w": max(1, min(100, float(r.get("w", 10)))),
"h": max(1, min(100, float(r.get("h", 10)))),
},
"description": r.get("description", ""),
"prompt": r.get("description", ""),
"image_b64": None,
"style": "educational",
})
# Enrich prompts with nearby vocab context
if entries:
for region in regions:
ry = region["bbox_pct"]["y"]
rh = region["bbox_pct"]["h"]
nearby = [
e for e in entries
if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
]
if nearby:
en_words = [e.get("english", "") for e in nearby if e.get("english")]
de_words = [e.get("german", "") for e in nearby if e.get("german")]
if en_words or de_words:
context = f" (vocabulary context: {', '.join(en_words[:5])}"
if de_words:
context += f" / {', '.join(de_words[:5])}"
context += ")"
region["prompt"] = region["description"] + context
# Save to ground_truth JSONB
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["image_regions"] = regions
validation["detected_at"] = datetime.utcnow().isoformat()
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Detected {len(regions)} image regions for session {session_id}")
return {"regions": regions, "count": len(regions)}
except httpx.ConnectError:
logger.warning(f"VLM not available at {ollama_base} for image detection")
return {"regions": [], "count": 0, "error": "VLM not available"}
except Exception as e:
logger.error(f"Image detection failed for {session_id}: {e}")
return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
"""Generate a replacement image for a detected region using mflux.
Sends the prompt (with style suffix) to the mflux-service running
natively on the Mac Mini (Metal GPU required).
"""
import httpx
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
regions = validation.get("image_regions") or []
if req.region_index < 0 or req.region_index >= len(regions):
raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
full_prompt = f"{req.prompt}, {style_suffix}"
# Determine image size from region aspect ratio (snap to multiples of 64)
region = regions[req.region_index]
bbox = region["bbox_pct"]
aspect = bbox["w"] / max(bbox["h"], 1)
if aspect > 1.3:
width, height = 768, 512
elif aspect < 0.7:
width, height = 512, 768
else:
width, height = 512, 512
try:
async with httpx.AsyncClient(timeout=300.0) as client:
resp = await client.post(f"{mflux_url}/generate", json={
"prompt": full_prompt,
"width": width,
"height": height,
"steps": 4,
})
resp.raise_for_status()
data = resp.json()
image_b64 = data.get("image_b64")
if not image_b64:
return {"image_b64": None, "success": False, "error": "No image returned"}
# Save to ground_truth
regions[req.region_index]["image_b64"] = image_b64
regions[req.region_index]["prompt"] = req.prompt
regions[req.region_index]["style"] = req.style
validation["image_regions"] = regions
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Generated image for session {session_id} region {req.region_index}")
return {"image_b64": image_b64, "success": True}
except httpx.ConnectError:
logger.warning(f"mflux-service not available at {mflux_url}")
return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
except Exception as e:
logger.error(f"Image generation failed for {session_id}: {e}")
return {"image_b64": None, "success": False, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
"""Save final validation results for step 8.
Stores notes, score, and preserves any detected/generated image regions.
Sets current_step = 10 to mark pipeline as complete.
"""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation") or {}
validation["validated_at"] = datetime.utcnow().isoformat()
validation["notes"] = req.notes
validation["score"] = req.score
ground_truth["validation"] = validation
await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
if session_id in _cache:
_cache[session_id]["ground_truth"] = ground_truth
logger.info(f"Validation saved for session {session_id}: score={req.score}")
return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
"""Retrieve saved validation data for step 8."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
ground_truth = session.get("ground_truth") or {}
validation = ground_truth.get("validation")
return {
"session_id": session_id,
"validation": validation,
"word_result": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
"""
Remove handwriting from a session image using inpainting.
Steps:
1. Load source image (auto -> deskewed if available, else original)
2. Detect handwriting mask (filtered by target_ink)
3. Dilate mask to cover stroke edges
4. Inpaint the image
5. Store result as clean_png in the session
Returns metadata including the URL to fetch the clean image.
"""
import time as _time
t0 = _time.monotonic()
from services.handwriting_detection import detect_handwriting
from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
# 1. Determine source image
source = req.use_source
if source == "auto":
deskewed = await get_session_image(session_id, "deskewed")
source = "deskewed" if deskewed else "original"
image_bytes = await get_session_image(session_id, source)
if not image_bytes:
raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
# 2. Detect handwriting mask
detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
# 3. Convert mask to PNG bytes and dilate
import io
from PIL import Image as _PILImage
mask_img = _PILImage.fromarray(detection.mask)
mask_buf = io.BytesIO()
mask_img.save(mask_buf, format="PNG")
mask_bytes = mask_buf.getvalue()
if req.dilation > 0:
mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
# 4. Inpaint
method_map = {
"telea": InpaintingMethod.OPENCV_TELEA,
"ns": InpaintingMethod.OPENCV_NS,
"auto": InpaintingMethod.AUTO,
}
inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
if not result.success:
raise HTTPException(status_code=500, detail="Inpainting failed")
elapsed_ms = int((_time.monotonic() - t0) * 1000)
meta = {
"method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
"handwriting_ratio": round(detection.handwriting_ratio, 4),
"detection_confidence": round(detection.confidence, 4),
"target_ink": req.target_ink,
"dilation": req.dilation,
"source_image": source,
"processing_time_ms": elapsed_ms,
}
# 5. Persist clean image (convert BGR ndarray -> PNG bytes)
clean_png_bytes = image_to_png(result.image)
await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
return {
**meta,
"image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
"session_id": session_id,
}
from fastapi import APIRouter
from ocr_pipeline_llm_review import router as _llm_review_router
from ocr_pipeline_reconstruction import router as _reconstruction_router
from ocr_pipeline_validation import router as _validation_router
# Composite router — drop-in replacement for the old monolithic router.
# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``.
# NOTE(review): this rebinds the module-level name ``router``; the endpoint
# registrations above this line were made on the previous router object and
# are no longer exported from this module — confirm every one of them also
# exists in the three split modules included below.
router = APIRouter()
router.include_router(_llm_review_router)
router.include_router(_reconstruction_router)
router.include_router(_validation_router)

View File

@@ -0,0 +1,362 @@
"""
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import logging
import re
from typing import Dict
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from ocr_pipeline_session_store import (
get_session_db,
get_sub_sessions,
update_session_db,
)
from ocr_pipeline_common import _cache
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
"""Save edited cell texts from reconstruction step."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
body = await request.json()
cell_updates = body.get("cells", [])
if not cell_updates:
await update_session_db(session_id, current_step=10)
return {"session_id": session_id, "updated": 0}
# Build update map: cell_id -> new text
update_map = {c["cell_id"]: c["text"] for c in cell_updates}
# Separate sub-session updates (cell_ids prefixed with "box{N}_")
sub_updates: Dict[int, Dict[str, str]] = {} # box_index -> {original_cell_id: text}
main_updates: Dict[str, str] = {}
for cell_id, text in update_map.items():
m = re.match(r'^box(\d+)_(.+)$', cell_id)
if m:
bi = int(m.group(1))
original_id = m.group(2)
sub_updates.setdefault(bi, {})[original_id] = text
else:
main_updates[cell_id] = text
# Update main session cells
cells = word_result.get("cells", [])
updated_count = 0
for cell in cells:
if cell["cell_id"] in main_updates:
cell["text"] = main_updates[cell["cell_id"]]
cell["status"] = "edited"
updated_count += 1
word_result["cells"] = cells
# Also update vocab_entries if present
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
if entries:
for entry in entries:
row_idx = entry.get("row_index", -1)
for col_idx, field_name in enumerate(["english", "german", "example"]):
cell_id = f"R{row_idx:02d}_C{col_idx}"
cell_id_alt = f"R{row_idx}_C{col_idx}"
new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
if new_text is not None:
entry[field_name] = new_text
word_result["vocab_entries"] = entries
if "entries" in word_result:
word_result["entries"] = entries
await update_session_db(session_id, word_result=word_result, current_step=10)
if session_id in _cache:
_cache[session_id]["word_result"] = word_result
# Route sub-session updates
sub_updated = 0
if sub_updates:
subs = await get_sub_sessions(session_id)
sub_by_index = {s.get("box_index"): s["id"] for s in subs}
for bi, updates in sub_updates.items():
sub_id = sub_by_index.get(bi)
if not sub_id:
continue
sub_session = await get_session_db(sub_id)
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word:
continue
sub_cells = sub_word.get("cells", [])
for cell in sub_cells:
if cell["cell_id"] in updates:
cell["text"] = updates[cell["cell_id"]]
cell["status"] = "edited"
sub_updated += 1
sub_word["cells"] = sub_cells
await update_session_db(sub_id, word_result=sub_word)
if sub_id in _cache:
_cache[sub_id]["word_result"] = sub_word
total_updated = updated_count + sub_updated
logger.info(f"Reconstruction saved for session {session_id}: "
f"{updated_count} main + {sub_updated} sub-session cells updated")
return {
"session_id": session_id,
"updated": total_updated,
"main_updated": updated_count,
"sub_updated": sub_updated,
}
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
"""Return cell grid as Fabric.js-compatible JSON for the canvas editor."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result")
if not word_result:
raise HTTPException(status_code=400, detail="No word result found")
cells = list(word_result.get("cells", []))
img_w = word_result.get("image_width", 800)
img_h = word_result.get("image_height", 600)
# Merge sub-session cells at box positions
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result")
if not sub_word or not sub_word.get("cells"):
continue
bi = sub.get("box_index", 0)
if bi < len(box_zones):
box = box_zones[bi]["box"]
box_y, box_x = box["y"], box["x"]
else:
box_y, box_x = 0, 0
for cell in sub_word["cells"]:
cell_copy = dict(cell)
cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
cell_copy["source"] = f"box_{bi}"
bbox = cell_copy.get("bbox_px", {})
if bbox:
bbox = dict(bbox)
bbox["x"] = bbox.get("x", 0) + box_x
bbox["y"] = bbox.get("y", 0) + box_y
cell_copy["bbox_px"] = bbox
cells.append(cell_copy)
from services.layout_reconstruction_service import cells_to_fabric_json
fabric_json = cells_to_fabric_json(cells, img_w, img_h)
return fabric_json
# ---------------------------------------------------------------------------
# Vocab entries merged + PDF/DOCX export
# ---------------------------------------------------------------------------
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
"""Return vocab entries from main session + all sub-sessions, sorted by Y position."""
session = await get_session_db(session_id)
if not session:
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
word_result = session.get("word_result") or {}
entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])
for e in entries:
e.setdefault("source", "main")
subs = await get_sub_sessions(session_id)
if subs:
column_result = session.get("column_result") or {}
zones = column_result.get("zones") or []
box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]
for sub in subs:
sub_session = await get_session_db(sub["id"])
if not sub_session:
continue
sub_word = sub_session.get("word_result") or {}
sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []
bi = sub.get("box_index", 0)
box_y = 0
if bi < len(box_zones):
box_y = box_zones[bi]["box"]["y"]
for e in sub_entries:
e_copy = dict(e)
e_copy["source"] = f"box_{bi}"
e_copy["source_y"] = box_y
entries.append(e_copy)
def _sort_key(e):
if e.get("source", "main") == "main":
return e.get("row_index", 0) * 100
return e.get("source_y", 0) * 100 + e.get("row_index", 0)
entries.sort(key=_sort_key)
return {
"session_id": session_id,
"entries": entries,
"total": len(entries),
"sources": list(set(e.get("source", "main") for e in entries)),
}
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    Builds a (header + n_rows) x n_cols text matrix from the session's
    ``word_result`` and renders it with reportlab.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: session has no word result / no data.
        HTTPException 501: reportlab is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)
    # Index cells by grid id once. The previous implementation did a linear
    # next(...) scan over all cells for every grid position — O(rows*cols*cells).
    cell_by_id = {c.get("cell_id"): c for c in cells}
    # Build table data: header row followed by n_rows content rows.
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data: list[list[str]] = [header]
    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
        import io as _io
        buf = _io.BytesIO()
        doc = SimpleDocTemplate(buf, pagesize=A4)
        if not table_data or not table_data[0]:
            raise HTTPException(status_code=400, detail="No data to export")
        t = Table(table_data)
        t.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        doc.build([t])
        buf.seek(0)
        return StreamingResponse(
            buf,
            media_type="application/pdf",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    Renders the session's ``word_result`` grid into a python-docx table
    with a heading naming the (shortened) session id.

    Raises:
        HTTPException 404: unknown session.
        HTTPException 400: session has no word result.
        HTTPException 501: python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")
    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)
    # Index cells by grid id once. The previous implementation did a linear
    # next(...) scan over all cells for every grid position — O(rows*cols*cells).
    cell_by_id = {c.get("cell_id"): c for c in cells}
    try:
        from docx import Document
        from docx.shared import Pt
        import io as _io
        doc = Document()
        doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)
        header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
        if not header:
            header = [f"Col {i}" for i in range(n_cols)]
        table = doc.add_table(rows=1 + n_rows, cols=max(n_cols, 1))
        table.style = 'Table Grid'
        for ci, h in enumerate(header):
            table.rows[0].cells[ci].text = h
        for r in range(n_rows):
            for ci in range(n_cols):
                cell = cell_by_id.get(f"R{r:02d}_C{ci}")
                table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""
        buf = _io.BytesIO()
        doc.save(buf)
        buf.seek(0)
        return StreamingResponse(
            buf,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")

View File

@@ -0,0 +1,362 @@
"""
OCR Pipeline Validation — image detection, generation, validation save,
and handwriting removal endpoints.
Extracted from ocr_pipeline_postprocess.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import os
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from ocr_pipeline_session_store import (
get_session_db,
get_session_image,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
RemoveHandwritingRequest,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Pydantic Models
# ---------------------------------------------------------------------------
# Style keyword → prompt suffix appended to user prompts when generating
# replacement images; keys match GenerateImageRequest.style.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
class ValidationRequest(BaseModel):
    """Payload for saving step-8 validation results."""
    notes: Optional[str] = None  # free-text reviewer notes
    score: Optional[int] = None  # numeric quality score (scale not enforced here)
class GenerateImageRequest(BaseModel):
    """Payload for generating a replacement image for one detected region."""
    region_index: int  # index into validation["image_regions"]
    prompt: str  # description fed to the image model
    style: str = "educational"  # key into STYLE_SUFFIXES
# ---------------------------------------------------------------------------
# Image detection + generation
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the session's original page image to an Ollama vision model and
    asks for bounding boxes of illustration regions as page percentages.
    Detected regions are enriched with nearby vocabulary words as prompt
    context and persisted under ground_truth["validation"]["image_regions"].

    Returns {"regions": [...], "count": N}; connection or parsing failures
    are reported via an "error" field instead of raising, so the UI can
    degrade gracefully.
    """
    import base64
    import httpx
    import re
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    # Give the model a few sample vocab pairs so it can tell illustrations
    # apart from the vocabulary text itself.
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"
    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")
    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )
    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")
        # Pull the first JSON array out of the model's free-form reply.
        match = re.search(r'\[.*?\]', text, re.DOTALL)
        if match:
            raw_regions = json.loads(match.group(0))
        else:
            raw_regions = []
        regions = []
        for r in raw_regions:
            # Clamp model output to sane percentage ranges (x/y: 0-100,
            # w/h: at least 1%) before trusting it.
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })
        # Enrich prompts with nearby vocab context.
        # NOTE(review): entry bboxes are compared directly against region
        # percentages here — assumes entry bbox "y" is also in page-percent
        # units; confirm against the word-detection output.
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context
        # Persist under ground_truth.validation and mirror into the in-memory cache.
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Detected {len(regions)} image regions for session {session_id}")
        return {"regions": regions, "count": len(regions)}
    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    Combines the request prompt with a style suffix, picks an output size
    matching the region's aspect ratio, calls the mflux service, and stores
    the returned base64 image back into the region record.

    Returns {"image_b64": ..., "success": bool} with an "error" field on
    failure (service unreachable or no image returned) instead of raising.
    """
    import httpx
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []
    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")
    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    # Unknown style keys silently fall back to the "educational" suffix.
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    # Choose landscape / portrait / square output based on the region's
    # width:height ratio (percent units; guard against h == 0).
    aspect = bbox["w"] / max(bbox["h"], 1)
    if aspect > 1.3:
        width, height = 768, 512
    elif aspect < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512
    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
        image_b64 = data.get("image_b64")
        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}
        # Record the generated image and the exact prompt/style used on the region.
        regions[req.region_index]["image_b64"] = image_b64
        regions[req.region_index]["prompt"] = req.prompt
        regions[req.region_index]["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)
        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth
        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}
    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
# ---------------------------------------------------------------------------
# Validation save/get
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Persist final validation results (notes + score) for step 8.

    Stamps the validation record with the current UTC time, advances the
    session to step 11, and keeps the in-memory cache in sync.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation.update(
        validated_at=datetime.utcnow().isoformat(),
        notes=req.notes,
        score=req.score,
    )
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth
    logger.info(f"Validation saved for session {session_id}: score={req.score}")
    return {"session_id": session_id, "validation": validation}
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return previously saved validation data (plus word_result) for step 8."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = session.get("ground_truth") or {}
    return {
        "session_id": session_id,
        "validation": gt.get("validation"),
        "word_result": session.get("word_result"),
    }
# ---------------------------------------------------------------------------
# Remove handwriting
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Pipeline: pick source image → detect handwriting mask → dilate mask →
    inpaint → persist the clean image plus processing metadata.

    Returns the metadata dict plus the URL of the stored clean image.
    Raises 404 if the session or the requested source image is missing,
    500 if inpainting fails.
    """
    import time as _time
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    t0 = _time.monotonic()
    # 1. Determine source image ("auto" prefers the deskewed version when present)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"
    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")
    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)
    # 3. Convert mask to PNG bytes and dilate (dilation widens the mask so
    #    inpainting also covers stroke edges)
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()
    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)
    # 4. Inpaint — map the request's method string to the service enum,
    #    defaulting unknown values to AUTO
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)
    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")
    elapsed_ms = int((_time.monotonic() - t0) * 1000)
    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }
    # 5. Persist clean image alongside the removal metadata
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)
    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }

View File

@@ -1,18 +1,18 @@
"""
OCR Pipeline Words - Word detection and ground truth endpoints.
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
and ground truth endpoints.
Extracted from ocr_pipeline_api.py.
Handles:
- POST /sessions/{session_id}/words — main SSE streaming word detection
- POST /sessions/{session_id}/paddle-direct — PaddleOCR direct endpoint
- POST /sessions/{session_id}/ground-truth/words — save ground truth
- GET /sessions/{session_id}/ground-truth/words — get ground truth
Split into sub-modules:
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
ocr_pipeline_words_stream — SSE streaming generators
This barrel module contains the PaddleOCR direct endpoint and ground truth
endpoints, and assembles all word-related routers.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from datetime import datetime
@@ -20,22 +20,9 @@ from typing import Any, Dict, List, Optional
import cv2
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
create_ocr_image,
detect_column_geometry,
)
from cv_words_first import build_grid_from_words
from ocr_pipeline_session_store import (
get_session_db,
@@ -44,15 +31,13 @@ from ocr_pipeline_session_store import (
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_get_base_image_png,
_append_pipeline_log,
)
from ocr_pipeline_words_detect import router as _detect_router
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
@@ -65,689 +50,13 @@ class WordGroundTruthRequest(BaseModel):
notes: Optional[str] = None
# ---------------------------------------------------------------------------
# Word Detection Endpoint (Step 7)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns × rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry
            positions without gap-healing expansion. Better for overlay rendering.
        grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
            'v2' uses pre-detected columns/rows (top-down).
            'words_first' clusters words bottom-up (no column/row detection needed).
    """
    # PaddleOCR is full-page remote OCR → force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"
    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection — synthesize a single full-page pseudo-column.
        # This enables the overlay pipeline which skips column detection.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")
    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]
    # Convert row dicts back to RowGeometry objects
    # NOTE(review): reached even when grid_method == "words_first" and
    # row_result is None — the 400 above only guards the non-words_first
    # path, so this would raise TypeError; confirm callers always have rows.
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in row_result["rows"]
    ]
    # Cell-First OCR (v2): no full-page word re-population needed.
    # Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
    # We still need word_count > 0 for row filtering in build_cell_grid_v2,
    # so populate from cached words if available (just for counting).
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    if word_dicts:
        # Words from detect_column_geometry are relative to the content ROI;
        # shift rows into the same frame before assigning words by Y-center.
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0
        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)
    # Exclude rows that fall within box zones.
    # Use inner box range (shrunk by border_thickness) so that rows at
    # the boundary (overlapping with the box border) are NOT excluded.
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
    if box_ranges_inner:
        def _row_in_box(r):
            # A row belongs to a box when its vertical center lies inside
            # any inner box range.
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")
    # --- Words-First path: bottom-up grid from word boxes ---
    if grid_method == "words_first":
        t0 = time.time()
        img_h, img_w = dewarped_bgr.shape[:2]
        # For paddle engine: run remote PaddleOCR full-page instead of Tesseract
        if engine == "paddle":
            from cv_ocr_engines import ocr_region_paddle
            wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
            # PaddleOCR returns absolute coordinates, no content_bounds offset needed
            cached["_paddle_word_dicts"] = wf_word_dicts
        else:
            # Get word_dicts from cache or run Tesseract full-page
            wf_word_dicts = cached.get("_word_dicts")
            if wf_word_dicts is None:
                ocr_img_tmp = create_ocr_image(dewarped_bgr)
                geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
                if geo_result is not None:
                    _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                    cached["_word_dicts"] = wf_word_dicts
                    cached["_inv"] = inv
                    cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
        if not wf_word_dicts:
            raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")
        # Convert word coordinates to absolute image coordinates if needed
        # (detect_column_geometry returns words relative to content ROI)
        # PaddleOCR already returns absolute coordinates — skip offset.
        if engine != "paddle":
            content_bounds = cached.get("_content_bounds")
            if content_bounds:
                lx, _rx, ty, _by = content_bounds
                abs_words = []
                for w in wf_word_dicts:
                    abs_words.append({
                        **w,
                        'left': w['left'] + lx,
                        'top': w['top'] + ty,
                    })
                wf_word_dicts = abs_words
        # Extract box rects for box-aware column clustering
        box_rects = []
        for zone in zones:
            if zone.get("zone_type") == "box" and zone.get("box"):
                box_rects.append(zone["box"])
        cells, columns_meta = build_grid_from_words(
            wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
        )
        duration = time.time() - t0
        # Apply IPA phonetic fixes
        fix_cell_phonetics(cells, pronunciation=pronunciation)
        # Add zone_index for backward compat
        for cell in cells:
            cell.setdefault("zone_index", 0)
        col_types = {c['type'] for c in columns_meta}
        is_vocab = bool(col_types & {'column_en', 'column_de'})
        n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
        n_cols = len(columns_meta)
        used_engine = "paddle" if engine == "paddle" else "words_first"
        word_result = {
            "cells": cells,
            "grid_shape": {
                "rows": n_rows,
                "cols": n_cols,
                "total_cells": len(cells),
            },
            "columns_used": columns_meta,
            "layout": "vocab" if is_vocab else "generic",
            "image_width": img_w,
            "image_height": img_h,
            "duration_seconds": round(duration, 2),
            "ocr_engine": used_engine,
            "grid_method": "words_first",
            "summary": {
                "total_cells": len(cells),
                "non_empty_cells": sum(1 for c in cells if c.get("text")),
                "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
            },
        }
        # Vocab or single-text-column layouts also get row→entry mapping.
        if is_vocab or 'column_text' in col_types:
            entries = _cells_to_vocab_entries(cells, columns_meta)
            entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
            word_result["vocab_entries"] = entries
            word_result["entries"] = entries
            word_result["entry_count"] = len(entries)
            word_result["summary"]["total_entries"] = len(entries)
            word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
            word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        await update_session_db(session_id, word_result=word_result, current_step=8)
        cached["word_result"] = word_result
        logger.info(f"OCR Pipeline: words-first session {session_id}: "
                    f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
        await _append_pipeline_log(session_id, "words", {
            "grid_method": "words_first",
            "total_cells": len(cells),
            "non_empty_cells": word_result["summary"]["non_empty_cells"],
            "ocr_engine": used_engine,
            "layout": word_result["layout"],
        }, duration_ms=int(duration * 1000))
        return {"session_id": session_id, **word_result}
    if stream:
        # Cell-First OCR v2: use batch-then-stream approach instead of
        # per-cell streaming. The parallel ThreadPoolExecutor in
        # build_cell_grid_v2 is much faster than sequential streaming.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    # --- Non-streaming path (grid_method=v2) ---
    t0 = time.time()
    # Create binarized OCR image (for Tesseract)
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]
    # Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    duration = time.time() - t0
    # Add zone_index to each cell (default 0 for backward compatibility)
    for cell in cells:
        cell.setdefault("zone_index", 0)
    # Layout detection
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    # Count content rows and columns for grid_shape
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len(columns_meta)
    # Determine which engine was actually used
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    # Grid result (always generic)
    word_result = {
        "cells": cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }
    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row→entry).
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")
    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(duration * 1000))
    return {
        "session_id": session_id,
        **word_result,
    }
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Unlike the old per-cell streaming, this uses build_cell_grid_v2 with
    ThreadPoolExecutor for parallel OCR, then emits all cells as SSE events.
    Event order: meta -> preparing -> keepalive* -> columns -> cell* -> complete.
    The 'preparing'/'keepalive' events keep the connection alive during OCR
    processing so reverse proxies do not drop the idle SSE stream.

    Args:
        session_id: Pipeline session identifier (used for persistence/logs).
        cached: In-memory session cache; receives the final 'word_result'.
        col_regions: Column regions defining the grid's columns.
        row_geoms: Row geometries defining the grid's rows.
        dewarped_bgr: Preprocessed BGR page image to OCR.
        engine: OCR engine selector ('auto', 'tesseract', 'rapid', 'paddle').
        pronunciation: 'british' or 'american' -- drives IPA phonetic fixes.
        request: FastAPI request, used to detect client disconnects.
        skip_heal_gaps: When True, cells keep exact row geometry.
    """
    import asyncio
    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols
    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"
    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"
    # 3. Run batch OCR in thread pool with periodic keepalive events.
    #    The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE
    #    connections after 30-60s. Send keepalive every 5s to prevent this.
    loop = asyncio.get_event_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )
    # Send keepalive events every 5 seconds while OCR runs.
    # shield() prevents the wait_for timeout from cancelling the OCR itself.
    while not ocr_future.done():
        try:
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                # NOTE: cancel() cannot interrupt a thread already running in
                # the executor; it only stops us from awaiting the result.
                ocr_future.cancel()
                return
    else:
        # Loop never ran a full iteration: future was already done.
        cells, columns_meta = ocr_future.result()
    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return
    # 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"
    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"
    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }
    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")
    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Event order: meta -> preparing -> columns -> one 'cell' per OCR'd cell
    -> complete. The final word_result is persisted to the session DB and
    mirrored into the in-memory cache.

    Args:
        session_id: Pipeline session identifier.
        cached: In-memory session cache; receives the final 'word_result'.
        col_regions: Column regions defining the grid's columns.
        row_geoms: Row geometries defining the grid's rows.
        dewarped_bgr: Preprocessed BGR page image to OCR.
        engine: OCR engine selector ('auto', 'tesseract', 'rapid', 'paddle').
        pronunciation: 'british' or 'american' -- drives IPA phonetic fixes.
        request: FastAPI request, used to detect client disconnects.
    """
    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]
    # Compute grid shape upfront for the meta event
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    # Determine layout
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    # Start streaming -- first event: meta
    columns_meta = None  # set from the first streamed cell's metadata
    total_cells = n_content_rows * n_cols
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"
    # Keepalive: send preparing event so proxy doesn't timeout during OCR init
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"
    # Stream cells one by one
    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0
    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return
        if columns_meta is None:
            columns_meta = cols_meta
            # Emit the columns metadata once, before the first cell event
            meta_update = {
                "type": "columns",
                "columns_used": cols_meta,
            }
            yield f"data: {json.dumps(meta_update)}\n\n"
        all_cells.append(cell)
        cell_idx += 1
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"
    # All cells done -- build final result
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []
    # Post-OCR: remove rows where ALL cells are empty (inter-row gaps
    # that had stray Tesseract artifacts giving word_count > 0).
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")
    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(all_cells, pronunciation=pronunciation)
    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }
    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries
    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")
    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
# ---------------------------------------------------------------------------
# PaddleOCR Direct Endpoint
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/paddle-direct")
@_local_router.post("/sessions/{session_id}/paddle-direct")
async def paddle_direct(session_id: str):
"""Run PaddleOCR on the preprocessed image and build a word grid directly.
Expects orientation/deskew/dewarp/crop to be done already.
Uses the cropped image (falls back to dewarped, then original).
The used image is stored as cropped_png so OverlayReconstruction
can display it as the background.
"""
# Try preprocessed images first (crop > dewarp > original)
"""Run PaddleOCR on the preprocessed image and build a word grid directly."""
img_png = await get_session_image(session_id, "cropped")
if not img_png:
img_png = await get_session_image(session_id, "dewarped")
@@ -770,13 +79,9 @@ async def paddle_direct(session_id: str):
if not word_dicts:
raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
# Reuse build_grid_from_words — same function that works in the regular
# pipeline with PaddleOCR (engine=paddle, grid_method=words_first).
# Handles phrase splitting, column clustering, and reading order.
cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
duration = time.time() - t0
# Tag cells as paddle_direct
for cell in cells:
cell["ocr_engine"] = "paddle_direct"
@@ -787,11 +92,7 @@ async def paddle_direct(session_id: str):
word_result = {
"cells": cells,
"grid_shape": {
"rows": n_rows,
"cols": n_cols,
"total_cells": len(cells),
},
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
"columns_used": columns_meta,
"layout": "vocab" if is_vocab else "generic",
"image_width": img_w,
@@ -806,7 +107,6 @@ async def paddle_direct(session_id: str):
},
}
# Store preprocessed image as cropped_png so OverlayReconstruction shows it
await update_session_db(
session_id,
word_result=word_result,
@@ -832,7 +132,7 @@ async def paddle_direct(session_id: str):
# Ground Truth Words Endpoints
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/ground-truth/words")
@_local_router.post("/sessions/{session_id}/ground-truth/words")
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
"""Save ground truth feedback for the word recognition step."""
session = await get_session_db(session_id)
@@ -857,7 +157,7 @@ async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
return {"session_id": session_id, "ground_truth": gt}
@router.get("/sessions/{session_id}/ground-truth/words")
@_local_router.get("/sessions/{session_id}/ground-truth/words")
async def get_word_ground_truth(session_id: str):
"""Retrieve saved ground truth for word recognition."""
session = await get_session_db(session_id)
@@ -874,3 +174,12 @@ async def get_word_ground_truth(session_id: str):
"words_gt": words_gt,
"words_auto": session.get("word_result"),
}
# ---------------------------------------------------------------------------
# Composite router
# ---------------------------------------------------------------------------
router = APIRouter()
router.include_router(_detect_router)
router.include_router(_local_router)

View File

@@ -0,0 +1,393 @@
"""
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
endpoint which handles both v2 and words_first grid methods.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
create_ocr_image,
detect_column_geometry,
)
from cv_words_first import build_grid_from_words
from ocr_pipeline_session_store import (
get_session_db,
update_session_db,
)
from ocr_pipeline_common import (
_cache,
_load_session_to_cache,
_get_cached,
_append_pipeline_log,
)
from ocr_pipeline_words_stream import (
_word_batch_stream_generator,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
# ---------------------------------------------------------------------------
# Word Detection Endpoint (Step 7)
# ---------------------------------------------------------------------------
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns x rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american'
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry.
        grid_method: 'v2' (default) or 'words_first'

    Raises:
        HTTPException 400: crop/dewarp missing, or (v2 only) row detection missing.
        HTTPException 404: unknown session.
    """
    # PaddleOCR is full-page remote OCR -> force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"
    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)
    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection yet: treat the whole page as one text column.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")
    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]
    # Convert row dicts back to RowGeometry objects.
    # BUGFIX: row_result may be absent on the words_first path (the guard
    # above only enforces it for v2, and engine=paddle forces words_first) --
    # fall back to an empty list instead of crashing on row_result["rows"].
    # All downstream consumers already handle an empty row_geoms list.
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in ((row_result or {}).get("rows") or [])
    ]
    # Populate word counts from cached words
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0
        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            # A word belongs to the row whose vertical span contains its center.
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)
    # Exclude rows that fall within box zones
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            # Shrink the span by the border thickness so border rows stay out.
            bt = max(box.get("border_thickness", 0), 5)
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
    if box_ranges_inner:
        def _row_in_box(r):
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            logger.info(f"detect_words: excluded {excluded} rows inside box zones")
    # --- Words-First path (ignores `stream`; always returns JSON) ---
    if grid_method == "words_first":
        return await _words_first_path(
            session_id, cached, dewarped_bgr, engine, pronunciation, zones,
        )
    if stream:
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    # --- Non-streaming path (grid_method=v2) ---
    return await _v2_path(
        session_id, cached, col_regions, row_geoms,
        dewarped_bgr, engine, pronunciation, skip_heal_gaps,
    )
async def _words_first_path(
    session_id: str,
    cached: Dict[str, Any],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    zones: list,
) -> dict:
    """Words-first grid construction path.

    Runs (or reuses cached) full-page word detection, converts word
    coordinates to absolute page space where needed, and builds the cell
    grid directly from word boxes instead of column/row geometry.

    Args:
        session_id: Pipeline session identifier.
        cached: In-memory session cache (word dicts and content bounds).
        dewarped_bgr: Preprocessed BGR page image.
        engine: OCR engine; 'paddle' triggers remote full-page OCR.
        pronunciation: 'british' or 'american' -- drives IPA phonetic fixes.
        zones: Zone dicts from column detection; box zones get own grid rows.

    Returns:
        The persisted word_result merged with the session id.

    Raises:
        HTTPException 400: when no words were detected.
    """
    t0 = time.time()
    img_h, img_w = dewarped_bgr.shape[:2]
    if engine == "paddle":
        # PaddleOCR runs on the full page and returns absolute coordinates.
        from cv_ocr_engines import ocr_region_paddle
        wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
        cached["_paddle_word_dicts"] = wf_word_dicts
    else:
        wf_word_dicts = cached.get("_word_dicts")
        if wf_word_dicts is None:
            ocr_img_tmp = create_ocr_image(dewarped_bgr)
            geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
            if geo_result is not None:
                _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                cached["_word_dicts"] = wf_word_dicts
                cached["_inv"] = inv
                cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
    if not wf_word_dicts:
        raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")
    # Convert word coordinates to absolute if needed: Tesseract-derived words
    # are relative to the detected content bounds, Paddle words already are
    # absolute.
    if engine != "paddle":
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            lx, _rx, ty, _by = content_bounds
            wf_word_dicts = [
                {**w, 'left': w['left'] + lx, 'top': w['top'] + ty}
                for w in wf_word_dicts
            ]
    # Box zones are forwarded so boxed regions form their own grid rows.
    box_rects = [
        zone["box"]
        for zone in zones
        if zone.get("zone_type") == "box" and zone.get("box")
    ]
    cells, columns_meta = build_grid_from_words(
        wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
    )
    duration = time.time() - t0
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    for cell in cells:
        cell.setdefault("zone_index", 0)
    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
    n_cols = len(columns_meta)
    used_engine = "paddle" if engine == "paddle" else "words_first"
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "grid_method": "words_first",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }
    # Vocab layout or single text column: derive row-wise vocab entries.
    if is_vocab or 'column_text' in col_types:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline: words-first session {session_id}: "
                f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
    await _append_pipeline_log(session_id, "words", {
        "grid_method": "words_first",
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
    }, duration_ms=int(duration * 1000))
    return {"session_id": session_id, **word_result}
async def _v2_path(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    skip_heal_gaps: bool,
) -> dict:
    """Cell-First OCR v2 non-streaming path.

    Builds the full cell grid in one batch call, applies phonetic fixes,
    persists the word_result to the session DB / cache, and returns it
    merged with the session id.
    """
    started = time.time()
    prepared_img = create_ocr_image(dewarped_bgr)
    page_h, page_w = dewarped_bgr.shape[:2]
    cells, columns_meta = build_cell_grid_v2(
        prepared_img, col_regions, row_geoms, page_w, page_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    elapsed = time.time() - started
    # Every cell belongs to the single default zone on this path.
    for cell in cells:
        cell.setdefault("zone_index", 0)
    column_kinds = {c['type'] for c in columns_meta}
    is_vocab = bool(column_kinds & {'column_en', 'column_de'})
    content_rows = sum(1 for r in row_geoms if r.row_type == 'content')
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    non_empty = sum(1 for c in cells if c.get("text"))
    low_conf = sum(1 for c in cells if 0 < c.get("confidence", 0) < 50)
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": content_rows, "cols": len(columns_meta), "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": page_w,
        "image_height": page_h,
        "duration_seconds": round(elapsed, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": non_empty,
            "low_confidence": low_conf,
        },
    }
    # Vocab layout or a single text column: derive row-wise vocab entries.
    if is_vocab or 'column_text' in column_kinds:
        entries = _fix_phonetic_brackets(
            _cells_to_vocab_entries(cells, columns_meta),
            pronunciation=pronunciation,
        )
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary = word_result["summary"]
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({elapsed:.2f}s), summary: {word_result['summary']}")
    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": non_empty,
        "low_confidence_count": low_conf,
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(elapsed * 1000))
    return {"session_id": session_id, **word_result}

View File

@@ -0,0 +1,303 @@
"""
OCR Pipeline Words Stream — SSE streaming generators for word detection.
Extracted from ocr_pipeline_words.py.
Lizenz: Apache 2.0
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
"""
import json
import logging
import time
from typing import Any, Dict, List
import numpy as np
from fastapi import Request
from cv_vocab_pipeline import (
PageRegion,
RowGeometry,
_cells_to_vocab_entries,
_fix_character_confusion,
_fix_phonetic_brackets,
fix_cell_phonetics,
build_cell_grid_v2,
build_cell_grid_v2_streaming,
create_ocr_image,
)
from ocr_pipeline_session_store import update_session_db
from ocr_pipeline_common import _cache
logger = logging.getLogger(__name__)
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR,
    then emits all cells as SSE events. Event order:
    meta -> preparing -> keepalive* -> columns -> cell* -> complete.
    Keepalive events prevent reverse proxies from dropping the idle stream
    while OCR runs.

    Args:
        session_id: Pipeline session identifier (used for persistence/logs).
        cached: In-memory session cache; receives the final 'word_result'.
        col_regions: Column regions defining the grid's columns.
        row_geoms: Row geometries defining the grid's rows.
        dewarped_bgr: Preprocessed BGR page image to OCR.
        engine: OCR engine selector ('auto', 'tesseract', 'rapid', 'paddle').
        pronunciation: 'british' or 'american' -- drives IPA phonetic fixes.
        request: FastAPI request, used to detect client disconnects.
        skip_heal_gaps: When True, cells keep exact row geometry.
    """
    import asyncio
    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols
    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"
    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"
    # 3. Run batch OCR in thread pool with periodic keepalive events.
    loop = asyncio.get_event_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )
    # Send keepalive events every 5 seconds while OCR runs.
    # shield() prevents the wait_for timeout from cancelling the OCR itself.
    while not ocr_future.done():
        try:
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                # NOTE: cancel() cannot interrupt a thread already running in
                # the executor; it only stops us from awaiting the result.
                ocr_future.cancel()
                return
    else:
        # Loop never ran a full iteration: future was already done.
        cells, columns_meta = ocr_future.result()
    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return
    # 4. Apply IPA phonetic fixes
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"
    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"
    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }
    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")
    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Emits Server-Sent Events in this order:
      1. ``meta``      - estimated grid shape and detected layout
      2. ``preparing`` - human-readable initialisation message
      3. ``columns``   - column metadata (once, alongside the first cell)
      4. ``cell``      - one event per OCR'd cell with a progress counter
      5. ``complete``  - summary, duration, engine and (for vocab) entries

    Side effects: persists the final ``word_result`` via ``update_session_db``
    (step 8) and mirrors it into the in-memory ``cached`` dict.

    Args:
        session_id: OCR pipeline session being processed.
        cached: In-memory session cache; ``word_result`` is written into it.
        col_regions: Detected column regions, including ignorable
            margins/headers/footers that never become cells.
        row_geoms: Row geometries; only rows of type ``content`` are OCR'd.
        dewarped_bgr: Dewarped page image (BGR) that OCR runs on.
        engine: Requested OCR engine name; used as a fallback label when the
            streamed cells carry no ``ocr_engine`` of their own.
        pronunciation: Pronunciation style passed to the phonetic fixers.
        request: FastAPI request, polled so OCR work stops on disconnect.
    """
    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Grid-size estimate for the initial meta event; ignored columns and
    # page furniture (headers, footers, margins) never produce cells.
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    # The page counts as a vocab sheet when an EN or DE column was detected.
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    columns_meta = None
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0
    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        # Abort OCR work as soon as the client goes away; nothing is
        # persisted in that case.
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return
        if columns_meta is None:
            # Column metadata only becomes available with the first cell.
            columns_meta = cols_meta
            meta_update = {"type": "columns", "columns_used": cols_meta}
            yield f"data: {json.dumps(meta_update)}\n\n"
        all_cells.append(cell)
        cell_idx += 1
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done
    duration = time.time() - t0
    if columns_meta is None:
        # No cells were streamed at all (empty grid) -> no columns event sent.
        columns_meta = []

    # Remove rows in which every cell came back empty.
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    # max() guards against a degenerate zero-column layout.
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine
    fix_cell_phonetics(all_cells, pronunciation=pronunciation)
    word_result = {
        "cells": all_cells,
        # NOTE(review): "rows" keeps the pre-filter content-row count even
        # after empty rows were dropped above — confirm consumers expect that.
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            # Confidence 0 means "no OCR result" and is not counted as low.
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocab post-processing: also applied to generic layouts that contain a
    # free-text column, so entries are always derivable when text columns exist.
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist (step 8) and mirror into the in-memory session cache.
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result
    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"