[split-required] Split 500-1000 LOC files across all services
backend-lehrer (5 files): - alerts_agent/db/repository.py (992 → 5), abitur_docs_api.py (956 → 3) - teacher_dashboard_api.py (951 → 3), services/pdf_service.py (916 → 3) - mail/mail_db.py (987 → 6) klausur-service (5 files): - legal_templates_ingestion.py (942 → 3), ocr_pipeline_postprocess.py (929 → 4) - ocr_pipeline_words.py (876 → 3), ocr_pipeline_ocr_merge.py (616 → 2) - KorrekturPage.tsx (956 → 6) website (5 pages): - mail (985 → 9), edu-search (958 → 8), mac-mini (950 → 7) - ocr-labeling (946 → 7), audit-workspace (871 → 4) studio-v2 (5 files + 1 deleted): - page.tsx (946 → 5), MessagesContext.tsx (925 → 4) - korrektur (914 → 6), worksheet-cleanup (899 → 6) - useVocabWorksheet.ts (888 → 3) - Deleted dead page-original.tsx (934 LOC) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
282
klausur-service/backend/legal_templates_chunking.py
Normal file
282
klausur-service/backend/legal_templates_chunking.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Legal Templates Chunking — text splitting, type inference, and chunk creation.
|
||||
|
||||
Extracted from legal_templates_ingestion.py to keep files under 500 LOC.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from template_sources import SourceConfig
|
||||
from github_crawler import ExtractedDocument
|
||||
|
||||
|
||||
# Chunking configuration defaults (can be overridden by env vars in ingestion module)
DEFAULT_CHUNK_SIZE = 1000      # target maximum characters per chunk (consumed by chunk_text)
DEFAULT_CHUNK_OVERLAP = 200    # nominal overlap in characters between consecutive chunks
|
||||
|
||||
|
||||
@dataclass
class TemplateChunk:
    """A chunk of template text ready for indexing.

    Bundles the chunk's content with classification, licensing,
    provenance, and AI-usage metadata so that every indexed point is
    self-describing and attribution can be reproduced without going
    back to the original source.
    """
    # --- content and classification ---
    text: str                        # the chunk's text content
    chunk_index: int                 # 0-based position of this chunk within its document
    document_title: str              # title of the originating document
    template_type: str               # document class, e.g. "privacy_policy" or generic "clause"
    clause_category: Optional[str]   # finer-grained clause label (e.g. "haftung"); None when undetected
    language: str                    # document language code
    jurisdiction: str                # jurisdiction taken from the source config (e.g. "DE")
    # --- license metadata (copied from the source's license_info) ---
    license_id: str                  # machine-readable license id
    license_name: str
    license_url: str
    attribution_required: bool
    share_alike: bool
    no_derivatives: bool
    commercial_use: bool
    # --- provenance ---
    source_name: str
    source_url: str
    source_repo: Optional[str]       # repository URL when crawled from a repo, else None
    source_commit: Optional[str]     # commit hash the file was extracted at, if known
    source_file: str                 # file path within the source
    source_hash: str                 # hash of the source content — presumably for dedup/change detection; TODO confirm
    attribution_text: Optional[str]  # pre-rendered attribution string (only set when required)
    copyright_notice: Optional[str]
    # --- document characteristics ---
    is_complete_document: bool       # True when the chunk is the whole (sizeable) document
    is_modular: bool                 # document is organised into sections
    requires_customization: bool     # document contains placeholders the user must fill in
    placeholders: List[str]          # placeholder tokens found in the document
    # --- AI-usage permissions derived from the license ---
    training_allowed: bool
    output_allowed: bool
    modification_allowed: bool
    distortion_prohibited: bool
|
||||
|
||||
|
||||
@dataclass
class IngestionStatus:
    """Progress and outcome record for ingesting a single source."""
    source_name: str                 # name of the source being ingested
    status: str  # "pending", "running", "completed", "failed"
    documents_found: int = 0         # documents extracted from the source
    chunks_created: int = 0          # chunks produced from those documents
    chunks_indexed: int = 0          # chunks successfully indexed
    errors: List[str] = field(default_factory=list)  # human-readable error messages
    started_at: Optional[datetime] = None            # set when ingestion begins
    completed_at: Optional[datetime] = None          # set when ingestion finishes (success or failure)
|
||||
|
||||
|
||||
def split_sentences(text: str) -> List[str]:
    """Split *text* into sentences, guarding abbreviations and decimals.

    Common German/English abbreviations and decimal points are masked
    with placeholder tokens before splitting on sentence-ending
    punctuation, then restored afterwards. Empty fragments are dropped.
    """
    masked = text

    # Mask known abbreviations so their trailing dot is not a sentence end.
    known_abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no']
    for token in known_abbreviations:
        rx = re.compile(r'\b' + re.escape(token) + r'\.', re.IGNORECASE)
        masked = rx.sub(token.replace('.', '<DOT>') + '<ABBR>', masked)

    # Mask decimal points so "3.14" is not treated as a sentence boundary.
    masked = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', masked)

    # Split after terminal punctuation followed by whitespace.
    pieces = re.split(r'(?<=[.!?])\s+', masked)

    # Unmask the placeholders and discard empty fragments.
    restored = [
        p.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.').strip()
        for p in pieces
    ]
    return [s for s in restored if s]
|
||||
|
||||
|
||||
def chunk_text(
    text: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[str]:
    """Split *text* into chunks of at most roughly ``chunk_size`` characters.

    Paragraph boundaries are preferred; a paragraph that exceeds the
    chunk size is broken at sentence boundaries, carrying roughly the
    trailing third of its sentences into the next chunk as context.

    NOTE(review): the ``overlap`` parameter is accepted for interface
    compatibility but is not referenced by the current algorithm — the
    sentence overlap is derived as a fraction of the pending chunk.
    """
    if not text:
        return []
    if len(text) <= chunk_size:
        return [text.strip()]

    chunks: List[str] = []
    pending: List[str] = []   # pieces accumulated for the chunk being built
    pending_len = 0           # running size estimate including separators

    for paragraph in (p.strip() for p in text.split('\n\n')):
        if not paragraph:
            continue

        size = len(paragraph)

        if size > chunk_size:
            # Oversized paragraph: flush whatever is pending, then walk
            # the paragraph sentence by sentence.
            if pending:
                chunks.append('\n\n'.join(pending))
                pending = []
                pending_len = 0

            for sentence in split_sentences(paragraph):
                if pending_len + len(sentence) + 1 > chunk_size:
                    if pending:
                        chunks.append(' '.join(pending))
                    # Keep roughly the last third of sentences as overlap.
                    keep = max(1, len(pending) // 3)
                    pending = pending[-keep:]
                    pending_len = sum(len(s) + 1 for s in pending)
                pending.append(sentence)
                pending_len += len(sentence) + 1

        elif pending_len + size + 2 > chunk_size:
            # Paragraph would overflow the pending chunk: flush and start
            # a fresh chunk containing just this paragraph.
            if pending:
                chunks.append('\n\n'.join(pending))
            pending = [paragraph]
            pending_len = size

        else:
            pending.append(paragraph)
            pending_len += size + 2

    # Emit the final partial chunk, if any.
    if pending:
        chunks.append('\n\n'.join(pending))

    return [c.strip() for c in chunks if c.strip()]
|
||||
|
||||
|
||||
def infer_template_type(doc: ExtractedDocument, source: SourceConfig) -> str:
    """Classify *doc* into a template type via keyword indicators.

    Scans the lower-cased document text and title for type-specific
    keywords; the first type (in declaration order) with any matching
    keyword wins. Falls back to the source's first declared template
    type, then to the generic "clause".
    """
    haystacks = (doc.text.lower(), doc.title.lower())

    # Keyword indicators per template type; declaration order is the
    # match priority.
    type_indicators = {
        "privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
        "terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
        "cookie_banner": ["cookie", "cookies", "tracking"],
        "impressum": ["impressum", "legal notice", "imprint"],
        "widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
        "dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
        "sla": ["service level", "availability", "uptime"],
        "nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
        "community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
        "acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
    }

    for template_type, keywords in type_indicators.items():
        if any(kw in hay for kw in keywords for hay in haystacks):
            return template_type

    # No indicator matched: prefer the source's own declared types.
    if source.template_types:
        return source.template_types[0]

    return "clause"  # Generic fallback
|
||||
|
||||
|
||||
def infer_clause_category(text: str) -> Optional[str]:
    """Return the first matching clause category for *text*, or ``None``.

    Categories are matched by case-insensitive keyword search, in a
    fixed priority order (dict declaration order).
    """
    haystack = text.lower()

    categories = {
        "haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
        "datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
        "widerruf": ["widerruf", "cancellation", "withdrawal"],
        "gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
        "kuendigung": ["kuendigung", "termination", "beendigung"],
        "zahlung": ["zahlung", "payment", "preis", "price"],
        "gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
        "aenderungen": ["aenderung", "modification", "amendment"],
        "schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
    }

    return next(
        (
            category
            for category, keywords in categories.items()
            if any(kw in haystack for kw in keywords)
        ),
        None,
    )
|
||||
|
||||
|
||||
def create_chunks(
    doc: ExtractedDocument,
    source: SourceConfig,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[TemplateChunk]:
    """Build indexable TemplateChunk records from an extracted document.

    The document text is chunked, each piece is classified, and the
    licensing plus provenance metadata from *source* is copied onto
    every chunk.
    """
    license_info = source.license_info
    template_type = infer_template_type(doc, source)

    # Split the document into chunk-sized pieces.
    pieces = chunk_text(doc.text, chunk_size, chunk_overlap)

    result: List[TemplateChunk] = []
    for index, piece in enumerate(pieces):
        # A single, sizeable piece counts as the complete document.
        is_complete = len(pieces) == 1 and len(piece) > 500
        is_modular = len(doc.sections) > 0 or '##' in doc.text
        requires_customization = len(doc.placeholders) > 0

        # Attribution text is rendered only when the license demands it.
        attribution_text = None
        if license_info.attribution_required:
            attribution_text = license_info.get_attribution_text(
                source.name,
                doc.source_url or source.get_source_url()
            )

        result.append(TemplateChunk(
            text=piece,
            chunk_index=index,
            document_title=doc.title,
            template_type=template_type,
            clause_category=infer_clause_category(piece),
            language=doc.language,
            jurisdiction=source.jurisdiction,
            license_id=license_info.id.value,
            license_name=license_info.name,
            license_url=license_info.url,
            attribution_required=license_info.attribution_required,
            share_alike=license_info.share_alike,
            no_derivatives=license_info.no_derivatives,
            commercial_use=license_info.commercial_use,
            source_name=source.name,
            source_url=doc.source_url or source.get_source_url(),
            source_repo=source.repo_url,
            source_commit=doc.source_commit,
            source_file=doc.file_path,
            source_hash=doc.source_hash,
            attribution_text=attribution_text,
            copyright_notice=None,
            is_complete_document=is_complete,
            is_modular=is_modular,
            requires_customization=requires_customization,
            placeholders=doc.placeholders,
            training_allowed=license_info.training_allowed,
            output_allowed=license_info.output_allowed,
            modification_allowed=license_info.modification_allowed,
            distortion_prohibited=license_info.distortion_prohibited,
        ))

    return result
|
||||
165
klausur-service/backend/legal_templates_cli.py
Normal file
165
klausur-service/backend/legal_templates_cli.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Legal Templates CLI — command-line entry point for ingestion and search.
|
||||
|
||||
Extracted from legal_templates_ingestion.py to keep files under 500 LOC.
|
||||
|
||||
Usage:
|
||||
python legal_templates_cli.py --ingest-all
|
||||
python legal_templates_cli.py --ingest-source github-site-policy
|
||||
python legal_templates_cli.py --status
|
||||
python legal_templates_cli.py --search "Datenschutzerklaerung"
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from template_sources import TEMPLATE_SOURCES, LicenseType
|
||||
from legal_templates_ingestion import LegalTemplatesIngestion
|
||||
|
||||
|
||||
async def main():
    """CLI entry point: parse arguments and dispatch to the ingestion API.

    Exactly one action is honoured per invocation, checked in the order:
    --reset, --delete-source, --status, --ingest-all, --ingest-source,
    --ingest-license, --search. With no action flag the help text is
    printed. The ingestion client is always closed before returning.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")

    # (flag, add_argument kwargs) — registered in a fixed order so the
    # generated --help output stays stable.
    options = [
        ("--ingest-all", dict(action="store_true", help="Ingest all enabled sources")),
        ("--ingest-source", dict(type=str, metavar="NAME", help="Ingest a specific source by name")),
        ("--ingest-license", dict(type=str, choices=["cc0", "mit", "cc_by_4", "public_domain"], help="Ingest all sources of a specific license type")),
        ("--max-priority", dict(type=int, default=3, help="Maximum priority level to ingest (1=highest, 5=lowest)")),
        ("--status", dict(action="store_true", help="Show collection status")),
        ("--search", dict(type=str, metavar="QUERY", help="Test search query")),
        ("--template-type", dict(type=str, help="Filter search by template type")),
        ("--language", dict(type=str, help="Filter search by language")),
        ("--reset", dict(action="store_true", help="Reset (delete and recreate) the collection")),
        ("--delete-source", dict(type=str, metavar="NAME", help="Delete all chunks from a source")),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    ingestion = LegalTemplatesIngestion()

    try:
        if args.reset:
            ingestion.reset_collection()
            print("Collection reset successfully")

        elif args.delete_source:
            deleted = ingestion.delete_source(args.delete_source)
            print(f"Deleted {deleted} chunks from {args.delete_source}")

        elif args.status:
            status = ingestion.get_status()
            print(json.dumps(status, indent=2, default=str))

        elif args.ingest_all:
            print(f"Ingesting all sources (max priority: {args.max_priority})...")
            results = await ingestion.ingest_all(max_priority=args.max_priority)
            print("\nResults:")
            for name, status in results.items():
                print(f"  {name}: {status.chunks_indexed} chunks ({status.status})")
                for error in status.errors:
                    print(f"    ERROR: {error}")
            total = sum(s.chunks_indexed for s in results.values())
            print(f"\nTotal: {total} chunks indexed")

        elif args.ingest_source:
            source = next(
                (s for s in TEMPLATE_SOURCES if s.name == args.ingest_source),
                None
            )
            if not source:
                # Unknown name: list what is available instead of failing.
                print(f"Unknown source: {args.ingest_source}")
                print("Available sources:")
                for s in TEMPLATE_SOURCES:
                    print(f"  - {s.name}")
                return

            print(f"Ingesting: {source.name}")
            status = await ingestion.ingest_source(source)
            print(f"\nResult: {status.chunks_indexed} chunks ({status.status})")
            for error in status.errors:
                print(f"  ERROR: {error}")

        elif args.ingest_license:
            license_type = LicenseType(args.ingest_license)
            print(f"Ingesting all {license_type.value} sources...")
            results = await ingestion.ingest_by_license(license_type)
            print("\nResults:")
            for name, status in results.items():
                print(f"  {name}: {status.chunks_indexed} chunks ({status.status})")

        elif args.search:
            print(f"Searching: {args.search}")
            results = await ingestion.search(
                args.search,
                template_type=args.template_type,
                language=args.language,
            )
            print(f"\nFound {len(results)} results:")
            for i, result in enumerate(results, 1):
                print(f"\n{i}. [{result['template_type']}] {result['document_title']}")
                print(f"   Score: {result['score']:.3f}")
                print(f"   License: {result['license_name']}")
                print(f"   Source: {result['source_name']}")
                print(f"   Language: {result['language']}")
                if result['attribution_required']:
                    print(f"   Attribution: {result['attribution_text']}")
                print(f"   Text: {result['text'][:200]}...")

        else:
            parser.print_help()

    finally:
        # Always release the underlying HTTP client.
        await ingestion.close()
|
||||
|
||||
|
||||
# Run the async CLI when executed directly as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -8,18 +8,16 @@ proper attribution tracking.
|
||||
Collection: bp_legal_templates
|
||||
|
||||
Usage:
|
||||
python legal_templates_ingestion.py --ingest-all
|
||||
python legal_templates_ingestion.py --ingest-source github-site-policy
|
||||
python legal_templates_ingestion.py --status
|
||||
python legal_templates_ingestion.py --search "Datenschutzerklaerung"
|
||||
python legal_templates_cli.py --ingest-all
|
||||
python legal_templates_cli.py --ingest-source github-site-policy
|
||||
python legal_templates_cli.py --status
|
||||
python legal_templates_cli.py --search "Datenschutzerklaerung"
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
@@ -50,6 +48,17 @@ from github_crawler import (
|
||||
RepositoryDownloader,
|
||||
)
|
||||
|
||||
# Re-export from chunking module for backward compatibility
|
||||
from legal_templates_chunking import ( # noqa: F401
|
||||
IngestionStatus,
|
||||
TemplateChunk,
|
||||
chunk_text,
|
||||
create_chunks,
|
||||
infer_clause_category,
|
||||
infer_template_type,
|
||||
split_sentences,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,54 +87,6 @@ MAX_RETRIES = 3
|
||||
RETRY_DELAY = 3.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class IngestionStatus:
|
||||
"""Status of a source ingestion."""
|
||||
source_name: str
|
||||
status: str # "pending", "running", "completed", "failed"
|
||||
documents_found: int = 0
|
||||
chunks_created: int = 0
|
||||
chunks_indexed: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
started_at: Optional[datetime] = None
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateChunk:
|
||||
"""A chunk of template text ready for indexing."""
|
||||
text: str
|
||||
chunk_index: int
|
||||
document_title: str
|
||||
template_type: str
|
||||
clause_category: Optional[str]
|
||||
language: str
|
||||
jurisdiction: str
|
||||
license_id: str
|
||||
license_name: str
|
||||
license_url: str
|
||||
attribution_required: bool
|
||||
share_alike: bool
|
||||
no_derivatives: bool
|
||||
commercial_use: bool
|
||||
source_name: str
|
||||
source_url: str
|
||||
source_repo: Optional[str]
|
||||
source_commit: Optional[str]
|
||||
source_file: str
|
||||
source_hash: str
|
||||
attribution_text: Optional[str]
|
||||
copyright_notice: Optional[str]
|
||||
is_complete_document: bool
|
||||
is_modular: bool
|
||||
requires_customization: bool
|
||||
placeholders: List[str]
|
||||
training_allowed: bool
|
||||
output_allowed: bool
|
||||
modification_allowed: bool
|
||||
distortion_prohibited: bool
|
||||
|
||||
|
||||
class LegalTemplatesIngestion:
|
||||
"""Handles ingestion of legal templates into Qdrant."""
|
||||
|
||||
@@ -168,212 +129,6 @@ class LegalTemplatesIngestion:
|
||||
logger.error(f"Embedding generation failed: {e}")
|
||||
raise
|
||||
|
||||
def _chunk_text(self, text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
|
||||
"""
|
||||
Split text into overlapping chunks.
|
||||
Respects paragraph and sentence boundaries where possible.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
if len(text) <= chunk_size:
|
||||
return [text.strip()]
|
||||
|
||||
# Split into paragraphs first
|
||||
paragraphs = text.split('\n\n')
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
|
||||
para_length = len(para)
|
||||
|
||||
if para_length > chunk_size:
|
||||
# Large paragraph: split by sentences
|
||||
if current_chunk:
|
||||
chunks.append('\n\n'.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
# Split long paragraph by sentences
|
||||
sentences = self._split_sentences(para)
|
||||
for sentence in sentences:
|
||||
if current_length + len(sentence) + 1 > chunk_size:
|
||||
if current_chunk:
|
||||
chunks.append(' '.join(current_chunk))
|
||||
# Keep overlap
|
||||
overlap_count = max(1, len(current_chunk) // 3)
|
||||
current_chunk = current_chunk[-overlap_count:]
|
||||
current_length = sum(len(s) + 1 for s in current_chunk)
|
||||
current_chunk.append(sentence)
|
||||
current_length += len(sentence) + 1
|
||||
|
||||
elif current_length + para_length + 2 > chunk_size:
|
||||
# Paragraph would exceed chunk size
|
||||
if current_chunk:
|
||||
chunks.append('\n\n'.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
current_chunk.append(para)
|
||||
current_length = para_length
|
||||
|
||||
else:
|
||||
current_chunk.append(para)
|
||||
current_length += para_length + 2
|
||||
|
||||
# Add final chunk
|
||||
if current_chunk:
|
||||
chunks.append('\n\n'.join(current_chunk))
|
||||
|
||||
return [c.strip() for c in chunks if c.strip()]
|
||||
|
||||
def _split_sentences(self, text: str) -> List[str]:
|
||||
"""Split text into sentences with basic abbreviation handling."""
|
||||
import re
|
||||
|
||||
# Protect common abbreviations
|
||||
abbreviations = ['bzw', 'ca', 'd.h', 'etc', 'ggf', 'inkl', 'u.a', 'usw', 'z.B', 'z.b', 'e.g', 'i.e', 'vs', 'no']
|
||||
protected = text
|
||||
for abbr in abbreviations:
|
||||
pattern = re.compile(r'\b' + re.escape(abbr) + r'\.', re.IGNORECASE)
|
||||
protected = pattern.sub(abbr.replace('.', '<DOT>') + '<ABBR>', protected)
|
||||
|
||||
# Protect decimal numbers
|
||||
protected = re.sub(r'(\d)\.(\d)', r'\1<DECIMAL>\2', protected)
|
||||
|
||||
# Split on sentence endings
|
||||
sentences = re.split(r'(?<=[.!?])\s+', protected)
|
||||
|
||||
# Restore protected characters
|
||||
result = []
|
||||
for s in sentences:
|
||||
s = s.replace('<DOT>', '.').replace('<ABBR>', '.').replace('<DECIMAL>', '.')
|
||||
s = s.strip()
|
||||
if s:
|
||||
result.append(s)
|
||||
|
||||
return result
|
||||
|
||||
def _infer_template_type(self, doc: ExtractedDocument, source: SourceConfig) -> str:
|
||||
"""Infer the template type from document content and metadata."""
|
||||
text_lower = doc.text.lower()
|
||||
title_lower = doc.title.lower()
|
||||
|
||||
# Check known indicators
|
||||
type_indicators = {
|
||||
"privacy_policy": ["datenschutz", "privacy", "personal data", "personenbezogen"],
|
||||
"terms_of_service": ["nutzungsbedingungen", "terms of service", "terms of use", "agb"],
|
||||
"cookie_banner": ["cookie", "cookies", "tracking"],
|
||||
"impressum": ["impressum", "legal notice", "imprint"],
|
||||
"widerruf": ["widerruf", "cancellation", "withdrawal", "right to cancel"],
|
||||
"dpa": ["auftragsverarbeitung", "data processing agreement", "dpa"],
|
||||
"sla": ["service level", "availability", "uptime"],
|
||||
"nda": ["confidential", "non-disclosure", "geheimhaltung", "vertraulich"],
|
||||
"community_guidelines": ["community", "guidelines", "conduct", "verhaltens"],
|
||||
"acceptable_use": ["acceptable use", "acceptable usage", "nutzungsrichtlinien"],
|
||||
}
|
||||
|
||||
for template_type, indicators in type_indicators.items():
|
||||
for indicator in indicators:
|
||||
if indicator in text_lower or indicator in title_lower:
|
||||
return template_type
|
||||
|
||||
# Fall back to source's first template type
|
||||
if source.template_types:
|
||||
return source.template_types[0]
|
||||
|
||||
return "clause" # Generic fallback
|
||||
|
||||
def _infer_clause_category(self, text: str) -> Optional[str]:
|
||||
"""Infer the clause category from text content."""
|
||||
text_lower = text.lower()
|
||||
|
||||
categories = {
|
||||
"haftung": ["haftung", "liability", "haftungsausschluss", "limitation"],
|
||||
"datenschutz": ["datenschutz", "privacy", "personal data", "personenbezogen"],
|
||||
"widerruf": ["widerruf", "cancellation", "withdrawal"],
|
||||
"gewaehrleistung": ["gewaehrleistung", "warranty", "garantie"],
|
||||
"kuendigung": ["kuendigung", "termination", "beendigung"],
|
||||
"zahlung": ["zahlung", "payment", "preis", "price"],
|
||||
"gerichtsstand": ["gerichtsstand", "jurisdiction", "governing law"],
|
||||
"aenderungen": ["aenderung", "modification", "amendment"],
|
||||
"schlussbestimmungen": ["schlussbestimmung", "miscellaneous", "final provisions"],
|
||||
}
|
||||
|
||||
for category, indicators in categories.items():
|
||||
for indicator in indicators:
|
||||
if indicator in text_lower:
|
||||
return category
|
||||
|
||||
return None
|
||||
|
||||
def _create_chunks(
|
||||
self,
|
||||
doc: ExtractedDocument,
|
||||
source: SourceConfig,
|
||||
) -> List[TemplateChunk]:
|
||||
"""Create template chunks from an extracted document."""
|
||||
license_info = source.license_info
|
||||
template_type = self._infer_template_type(doc, source)
|
||||
|
||||
# Chunk the text
|
||||
text_chunks = self._chunk_text(doc.text)
|
||||
|
||||
chunks = []
|
||||
for i, chunk_text in enumerate(text_chunks):
|
||||
# Determine if this is a complete document or a clause
|
||||
is_complete = len(text_chunks) == 1 and len(chunk_text) > 500
|
||||
is_modular = len(doc.sections) > 0 or '##' in doc.text
|
||||
requires_customization = len(doc.placeholders) > 0
|
||||
|
||||
# Generate attribution text
|
||||
attribution_text = None
|
||||
if license_info.attribution_required:
|
||||
attribution_text = license_info.get_attribution_text(
|
||||
source.name,
|
||||
doc.source_url or source.get_source_url()
|
||||
)
|
||||
|
||||
chunk = TemplateChunk(
|
||||
text=chunk_text,
|
||||
chunk_index=i,
|
||||
document_title=doc.title,
|
||||
template_type=template_type,
|
||||
clause_category=self._infer_clause_category(chunk_text),
|
||||
language=doc.language,
|
||||
jurisdiction=source.jurisdiction,
|
||||
license_id=license_info.id.value,
|
||||
license_name=license_info.name,
|
||||
license_url=license_info.url,
|
||||
attribution_required=license_info.attribution_required,
|
||||
share_alike=license_info.share_alike,
|
||||
no_derivatives=license_info.no_derivatives,
|
||||
commercial_use=license_info.commercial_use,
|
||||
source_name=source.name,
|
||||
source_url=doc.source_url or source.get_source_url(),
|
||||
source_repo=source.repo_url,
|
||||
source_commit=doc.source_commit,
|
||||
source_file=doc.file_path,
|
||||
source_hash=doc.source_hash,
|
||||
attribution_text=attribution_text,
|
||||
copyright_notice=None, # Could be extracted from doc if present
|
||||
is_complete_document=is_complete,
|
||||
is_modular=is_modular,
|
||||
requires_customization=requires_customization,
|
||||
placeholders=doc.placeholders,
|
||||
training_allowed=license_info.training_allowed,
|
||||
output_allowed=license_info.output_allowed,
|
||||
modification_allowed=license_info.modification_allowed,
|
||||
distortion_prohibited=license_info.distortion_prohibited,
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
async def ingest_source(self, source: SourceConfig) -> IngestionStatus:
|
||||
"""Ingest a single source into Qdrant."""
|
||||
status = IngestionStatus(
|
||||
@@ -405,7 +160,7 @@ class LegalTemplatesIngestion:
|
||||
# Create chunks from all documents
|
||||
all_chunks: List[TemplateChunk] = []
|
||||
for doc in documents:
|
||||
chunks = self._create_chunks(doc, source)
|
||||
chunks = create_chunks(doc, source, CHUNK_SIZE, CHUNK_OVERLAP)
|
||||
all_chunks.extend(chunks)
|
||||
status.chunks_created += len(chunks)
|
||||
|
||||
@@ -637,21 +392,7 @@ class LegalTemplatesIngestion:
|
||||
attribution_required: Optional[bool] = None,
|
||||
top_k: int = 10,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search the legal templates collection.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
template_type: Filter by template type (e.g., "privacy_policy")
|
||||
license_types: Filter by license types (e.g., ["cc0", "mit"])
|
||||
language: Filter by language (e.g., "de")
|
||||
jurisdiction: Filter by jurisdiction (e.g., "DE")
|
||||
attribution_required: Filter by attribution requirement
|
||||
top_k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of search results with full metadata
|
||||
"""
|
||||
"""Search the legal templates collection."""
|
||||
# Generate query embedding
|
||||
embeddings = await self._generate_embeddings([query])
|
||||
query_vector = embeddings[0]
|
||||
@@ -661,45 +402,27 @@ class LegalTemplatesIngestion:
|
||||
|
||||
if template_type:
|
||||
must_conditions.append(
|
||||
FieldCondition(
|
||||
key="template_type",
|
||||
match=MatchValue(value=template_type),
|
||||
)
|
||||
FieldCondition(key="template_type", match=MatchValue(value=template_type))
|
||||
)
|
||||
|
||||
if language:
|
||||
must_conditions.append(
|
||||
FieldCondition(
|
||||
key="language",
|
||||
match=MatchValue(value=language),
|
||||
)
|
||||
FieldCondition(key="language", match=MatchValue(value=language))
|
||||
)
|
||||
|
||||
if jurisdiction:
|
||||
must_conditions.append(
|
||||
FieldCondition(
|
||||
key="jurisdiction",
|
||||
match=MatchValue(value=jurisdiction),
|
||||
)
|
||||
FieldCondition(key="jurisdiction", match=MatchValue(value=jurisdiction))
|
||||
)
|
||||
|
||||
if attribution_required is not None:
|
||||
must_conditions.append(
|
||||
FieldCondition(
|
||||
key="attribution_required",
|
||||
match=MatchValue(value=attribution_required),
|
||||
)
|
||||
FieldCondition(key="attribution_required", match=MatchValue(value=attribution_required))
|
||||
)
|
||||
|
||||
# License type filter (OR condition)
|
||||
should_conditions = []
|
||||
if license_types:
|
||||
for license_type in license_types:
|
||||
for lt in license_types:
|
||||
should_conditions.append(
|
||||
FieldCondition(
|
||||
key="license_id",
|
||||
match=MatchValue(value=license_type),
|
||||
)
|
||||
FieldCondition(key="license_id", match=MatchValue(value=lt))
|
||||
)
|
||||
|
||||
# Construct filter
|
||||
@@ -747,196 +470,31 @@ class LegalTemplatesIngestion:
|
||||
|
||||
def delete_source(self, source_name: str) -> int:
|
||||
"""Delete all chunks from a specific source."""
|
||||
# First count how many we're deleting
|
||||
count_result = self.qdrant.count(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
count_filter=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="source_name",
|
||||
match=MatchValue(value=source_name),
|
||||
)
|
||||
]
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
),
|
||||
)
|
||||
|
||||
# Delete by filter
|
||||
self.qdrant.delete(
|
||||
collection_name=LEGAL_TEMPLATES_COLLECTION,
|
||||
points_selector=Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="source_name",
|
||||
match=MatchValue(value=source_name),
|
||||
)
|
||||
]
|
||||
must=[FieldCondition(key="source_name", match=MatchValue(value=source_name))]
|
||||
),
|
||||
)
|
||||
|
||||
return count_result.count
|
||||
|
||||
def reset_collection(self):
|
||||
"""Delete and recreate the collection."""
|
||||
logger.warning(f"Resetting collection: {LEGAL_TEMPLATES_COLLECTION}")
|
||||
|
||||
# Delete collection
|
||||
try:
|
||||
self.qdrant.delete_collection(LEGAL_TEMPLATES_COLLECTION)
|
||||
except Exception:
|
||||
pass # Collection might not exist
|
||||
|
||||
# Recreate
|
||||
pass
|
||||
self._ensure_collection()
|
||||
self._ingestion_status.clear()
|
||||
|
||||
logger.info(f"Collection {LEGAL_TEMPLATES_COLLECTION} reset")
|
||||
|
||||
async def close(self):
    """Release the underlying HTTP client's resources."""
    await self.http_client.aclose()
|
||||
|
||||
|
||||
async def main():
    """CLI entry point for ingesting, searching, and managing templates."""
    import argparse

    parser = argparse.ArgumentParser(description="Legal Templates Ingestion")
    parser.add_argument("--ingest-all", action="store_true",
                        help="Ingest all enabled sources")
    parser.add_argument("--ingest-source", type=str, metavar="NAME",
                        help="Ingest a specific source by name")
    parser.add_argument("--ingest-license", type=str,
                        choices=["cc0", "mit", "cc_by_4", "public_domain"],
                        help="Ingest all sources of a specific license type")
    parser.add_argument("--max-priority", type=int, default=3,
                        help="Maximum priority level to ingest (1=highest, 5=lowest)")
    parser.add_argument("--status", action="store_true",
                        help="Show collection status")
    parser.add_argument("--search", type=str, metavar="QUERY",
                        help="Test search query")
    parser.add_argument("--template-type", type=str,
                        help="Filter search by template type")
    parser.add_argument("--language", type=str,
                        help="Filter search by language")
    parser.add_argument("--reset", action="store_true",
                        help="Reset (delete and recreate) the collection")
    parser.add_argument("--delete-source", type=str, metavar="NAME",
                        help="Delete all chunks from a source")

    args = parser.parse_args()

    pipeline = LegalTemplatesIngestion()

    try:
        if args.reset:
            pipeline.reset_collection()
            print("Collection reset successfully")

        elif args.delete_source:
            removed = pipeline.delete_source(args.delete_source)
            print(f"Deleted {removed} chunks from {args.delete_source}")

        elif args.status:
            print(json.dumps(pipeline.get_status(), indent=2, default=str))

        elif args.ingest_all:
            print(f"Ingesting all sources (max priority: {args.max_priority})...")
            outcome = await pipeline.ingest_all(max_priority=args.max_priority)
            print("\nResults:")
            for source_name, st in outcome.items():
                print(f"  {source_name}: {st.chunks_indexed} chunks ({st.status})")
                for err in st.errors or []:
                    print(f"    ERROR: {err}")
            grand_total = sum(s.chunks_indexed for s in outcome.values())
            print(f"\nTotal: {grand_total} chunks indexed")

        elif args.ingest_source:
            # Linear scan over the registry; the source list is small.
            chosen = None
            for candidate in TEMPLATE_SOURCES:
                if candidate.name == args.ingest_source:
                    chosen = candidate
                    break
            if chosen is None:
                print(f"Unknown source: {args.ingest_source}")
                print("Available sources:")
                for candidate in TEMPLATE_SOURCES:
                    print(f"  - {candidate.name}")
                return

            print(f"Ingesting: {chosen.name}")
            st = await pipeline.ingest_source(chosen)
            print(f"\nResult: {st.chunks_indexed} chunks ({st.status})")
            for err in st.errors or []:
                print(f"  ERROR: {err}")

        elif args.ingest_license:
            wanted = LicenseType(args.ingest_license)
            print(f"Ingesting all {wanted.value} sources...")
            outcome = await pipeline.ingest_by_license(wanted)
            print("\nResults:")
            for source_name, st in outcome.items():
                print(f"  {source_name}: {st.chunks_indexed} chunks ({st.status})")

        elif args.search:
            print(f"Searching: {args.search}")
            hits = await pipeline.search(
                args.search,
                template_type=args.template_type,
                language=args.language,
            )
            print(f"\nFound {len(hits)} results:")
            for rank, hit in enumerate(hits, 1):
                print(f"\n{rank}. [{hit['template_type']}] {hit['document_title']}")
                print(f"   Score: {hit['score']:.3f}")
                print(f"   License: {hit['license_name']}")
                print(f"   Source: {hit['source_name']}")
                print(f"   Language: {hit['language']}")
                if hit['attribution_required']:
                    print(f"   Attribution: {hit['attribution_text']}")
                print(f"   Text: {hit['text'][:200]}...")

        else:
            parser.print_help()

    finally:
        await pipeline.close()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
156
klausur-service/backend/mail/mail_db_accounts.py
Normal file
156
klausur-service/backend/mail/mail_db_accounts.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Mail Database - Email Account Operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from .mail_db_pool import get_pool
|
||||
|
||||
|
||||
async def create_email_account(
    user_id: str,
    tenant_id: str,
    email: str,
    display_name: str,
    account_type: str,
    imap_host: str,
    imap_port: int,
    imap_ssl: bool,
    smtp_host: str,
    smtp_port: int,
    smtp_ssl: bool,
    vault_path: str,
) -> Optional[str]:
    """Create a new email account row.

    Credentials themselves live in Vault; only ``vault_path`` is stored.
    Returns the generated account ID, or None when the database is
    unavailable or the insert fails.
    """
    pool = await get_pool()
    if pool is None:
        return None

    new_id = str(uuid.uuid4())
    row_values = (
        new_id, user_id, tenant_id, email, display_name, account_type,
        imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path,
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO external_email_accounts
                (id, user_id, tenant_id, email, display_name, account_type,
                 imap_host, imap_port, imap_ssl, smtp_host, smtp_port, smtp_ssl, vault_path)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                """,
                *row_values,
            )
    except Exception as e:
        print(f"Failed to create email account: {e}")
        return None
    return new_id
|
||||
|
||||
|
||||
async def get_email_accounts(
    user_id: str,
    tenant_id: Optional[str] = None,
) -> List[Dict]:
    """Return all email accounts for a user, optionally scoped to a tenant.

    Rows are ordered by creation time; an empty list is returned when the
    database is unavailable or the query fails.
    """
    pool = await get_pool()
    if pool is None:
        return []

    try:
        async with pool.acquire() as conn:
            # A falsy tenant_id (None or "") means "all tenants".
            if not tenant_id:
                records = await conn.fetch(
                    """
                    SELECT * FROM external_email_accounts
                    WHERE user_id = $1
                    ORDER BY created_at
                    """,
                    user_id,
                )
            else:
                records = await conn.fetch(
                    """
                    SELECT * FROM external_email_accounts
                    WHERE user_id = $1 AND tenant_id = $2
                    ORDER BY created_at
                    """,
                    user_id, tenant_id,
                )
            return [dict(record) for record in records]
    except Exception as e:
        print(f"Failed to get email accounts: {e}")
        return []
|
||||
|
||||
|
||||
async def get_email_account(account_id: str, user_id: str) -> Optional[Dict]:
    """Fetch a single email account owned by the given user.

    Returns None when not found, on DB unavailability, or on error.
    """
    pool = await get_pool()
    if pool is None:
        return None

    try:
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                """
                SELECT * FROM external_email_accounts
                WHERE id = $1 AND user_id = $2
                """,
                account_id, user_id,
            )
            if record is None:
                return None
            return dict(record)
    except Exception as e:
        print(f"Failed to get email account: {e}")
        return None
|
||||
|
||||
|
||||
async def update_account_status(
    account_id: str,
    status: str,
    sync_error: Optional[str] = None,
    email_count: Optional[int] = None,
    unread_count: Optional[int] = None,
) -> bool:
    """Update an account's sync status, error text, and counters.

    ``email_count`` / ``unread_count`` are only overwritten when provided;
    COALESCE keeps the stored value for None. ``last_sync`` and
    ``updated_at`` are always bumped. Returns True on success.
    """
    pool = await get_pool()
    if pool is None:
        return False

    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE external_email_accounts SET
                    status = $2,
                    sync_error = $3,
                    email_count = COALESCE($4, email_count),
                    unread_count = COALESCE($5, unread_count),
                    last_sync = NOW(),
                    updated_at = NOW()
                WHERE id = $1
                """,
                account_id, status, sync_error, email_count, unread_count,
            )
    except Exception as e:
        print(f"Failed to update account status: {e}")
        return False
    return True
|
||||
|
||||
|
||||
async def delete_email_account(account_id: str, user_id: str) -> bool:
    """Delete an email account owned by the user.

    Emails referencing the account are removed by the ON DELETE CASCADE
    on ``aggregated_emails``. Returns True when the delete was issued.
    """
    pool = await get_pool()
    if pool is None:
        return False

    try:
        async with pool.acquire() as conn:
            outcome = await conn.execute(
                """
                DELETE FROM external_email_accounts
                WHERE id = $1 AND user_id = $2
                """,
                account_id, user_id,
            )
    except Exception as e:
        print(f"Failed to delete email account: {e}")
        return False
    # asyncpg returns a status tag such as "DELETE 1".
    return "DELETE" in outcome
|
||||
225
klausur-service/backend/mail/mail_db_emails.py
Normal file
225
klausur-service/backend/mail/mail_db_emails.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Mail Database - Aggregated Email Operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from .mail_db_pool import get_pool
|
||||
|
||||
|
||||
async def upsert_email(
    account_id: str,
    user_id: str,
    tenant_id: str,
    message_id: str,
    subject: str,
    sender_email: str,
    sender_name: Optional[str],
    recipients: List[str],
    cc: List[str],
    body_preview: Optional[str],
    body_text: Optional[str],
    body_html: Optional[str],
    has_attachments: bool,
    attachments: List[Dict],
    headers: Dict,
    folder: str,
    date_sent: datetime,
    date_received: datetime,
) -> Optional[str]:
    """Insert or update an email. Returns the email ID.

    Conflicts on (account_id, message_id) update subject and folder only.
    Returns None when the database is unavailable or the query fails.
    """
    pool = await get_pool()
    if pool is None:
        return None

    email_id = str(uuid.uuid4())
    try:
        async with pool.acquire() as conn:
            # NOTE: is_read is deliberately NOT updated on conflict. The
            # INSERT never supplies is_read, so EXCLUDED.is_read would
            # always be the column default (FALSE) and every re-sync
            # would wrongly mark already-read emails as unread.
            row = await conn.fetchrow(
                """
                INSERT INTO aggregated_emails
                (id, account_id, user_id, tenant_id, message_id, subject,
                 sender_email, sender_name, recipients, cc, body_preview,
                 body_text, body_html, has_attachments, attachments, headers,
                 folder, date_sent, date_received)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
                ON CONFLICT (account_id, message_id) DO UPDATE SET
                    subject = EXCLUDED.subject,
                    folder = EXCLUDED.folder
                RETURNING id
                """,
                email_id, account_id, user_id, tenant_id, message_id, subject,
                sender_email, sender_name, json.dumps(recipients), json.dumps(cc),
                body_preview, body_text, body_html, has_attachments,
                json.dumps(attachments), json.dumps(headers), folder,
                date_sent, date_received
            )
            return row['id'] if row else None
    except Exception as e:
        print(f"Failed to upsert email: {e}")
        return None
|
||||
|
||||
|
||||
async def get_unified_inbox(
    user_id: str,
    account_ids: Optional[List[str]] = None,
    categories: Optional[List[str]] = None,
    is_read: Optional[bool] = None,
    is_starred: Optional[bool] = None,
    limit: int = 50,
    offset: int = 0,
) -> List[Dict]:
    """Return the user's unified inbox with optional filters.

    Soft-deleted emails are always excluded; results are newest-first by
    ``date_received`` and joined with account email/display name.
    """
    pool = await get_pool()
    if pool is None:
        return []

    # Build the WHERE clause dynamically; $1 is always the user id.
    where_parts = ["user_id = $1", "is_deleted = FALSE"]
    args: list = [user_id]
    next_idx = 2

    if account_ids:
        where_parts.append(f"account_id = ANY(${next_idx})")
        args.append(account_ids)
        next_idx += 1

    if categories:
        where_parts.append(f"category = ANY(${next_idx})")
        args.append(categories)
        next_idx += 1

    # Booleans use an explicit None check so False is a valid filter.
    if is_read is not None:
        where_parts.append(f"is_read = ${next_idx}")
        args.append(is_read)
        next_idx += 1

    if is_starred is not None:
        where_parts.append(f"is_starred = ${next_idx}")
        args.append(is_starred)
        next_idx += 1

    args.extend([limit, offset])

    query = f"""
        SELECT e.*, a.email as account_email, a.display_name as account_name
        FROM aggregated_emails e
        JOIN external_email_accounts a ON e.account_id = a.id
        WHERE {" AND ".join(where_parts)}
        ORDER BY e.date_received DESC
        LIMIT ${next_idx} OFFSET ${next_idx + 1}
    """

    try:
        async with pool.acquire() as conn:
            records = await conn.fetch(query, *args)
            return [dict(record) for record in records]
    except Exception as e:
        print(f"Failed to get unified inbox: {e}")
        return []
|
||||
|
||||
|
||||
async def get_email(email_id: str, user_id: str) -> Optional[Dict]:
    """Fetch one email (joined with its account info) owned by the user.

    Returns None when not found, on DB unavailability, or on error.
    """
    pool = await get_pool()
    if pool is None:
        return None

    try:
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                """
                SELECT e.*, a.email as account_email, a.display_name as account_name
                FROM aggregated_emails e
                JOIN external_email_accounts a ON e.account_id = a.id
                WHERE e.id = $1 AND e.user_id = $2
                """,
                email_id, user_id,
            )
            if record is None:
                return None
            return dict(record)
    except Exception as e:
        print(f"Failed to get email: {e}")
        return None
|
||||
|
||||
|
||||
async def update_email_ai_analysis(
    email_id: str,
    category: str,
    sender_type: str,
    sender_authority_name: Optional[str],
    detected_deadlines: List[Dict],
    suggested_priority: str,
    ai_summary: Optional[str],
) -> bool:
    """Store AI analysis results on an email and stamp ``ai_analyzed_at``.

    Returns True on success, False when the DB is unavailable or on error.
    """
    pool = await get_pool()
    if pool is None:
        return False

    analysis_values = (
        email_id, category, sender_type, sender_authority_name,
        json.dumps(detected_deadlines), suggested_priority, ai_summary,
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET
                    category = $2,
                    sender_type = $3,
                    sender_authority_name = $4,
                    detected_deadlines = $5,
                    suggested_priority = $6,
                    ai_summary = $7,
                    ai_analyzed_at = NOW()
                WHERE id = $1
                """,
                *analysis_values,
            )
    except Exception as e:
        print(f"Failed to update email AI analysis: {e}")
        return False
    return True
|
||||
|
||||
|
||||
async def mark_email_read(email_id: str, user_id: str, is_read: bool = True) -> bool:
    """Set an email's read flag. Returns True on success."""
    pool = await get_pool()
    if pool is None:
        return False

    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET is_read = $3
                WHERE id = $1 AND user_id = $2
                """,
                email_id, user_id, is_read,
            )
    except Exception as e:
        print(f"Failed to mark email read: {e}")
        return False
    return True
|
||||
|
||||
|
||||
async def mark_email_starred(email_id: str, user_id: str, is_starred: bool = True) -> bool:
    """Set an email's starred flag. Returns True on success."""
    pool = await get_pool()
    if pool is None:
        return False

    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE aggregated_emails SET is_starred = $3
                WHERE id = $1 AND user_id = $2
                """,
                email_id, user_id, is_starred,
            )
    except Exception as e:
        print(f"Failed to mark email starred: {e}")
        return False
    return True
|
||||
253
klausur-service/backend/mail/mail_db_pool.py
Normal file
253
klausur-service/backend/mail/mail_db_pool.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Mail Database - Connection Pool and Schema Initialization.
|
||||
"""
|
||||
|
||||
import os

# Database Configuration - from Vault or environment (test default for CI).
# The fallback DSN points at a throwaway local database so importing this
# module never fails in CI environments without real credentials.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test")

# Flag to check if using test defaults: True only when DATABASE_URL was
# explicitly configured (i.e. differs from the CI fallback above).
_DB_CONFIGURED = DATABASE_URL != "postgresql://test:test@localhost:5432/test"

# Connection pool (shared with metrics_db). Created lazily by get_pool();
# stays None when asyncpg is missing or the database is unreachable.
_pool = None
|
||||
|
||||
|
||||
async def get_pool():
    """Return the shared asyncpg pool, creating it on first use.

    Returns None (and prints a warning) when asyncpg is not installed or
    the database cannot be reached; callers treat None as "DB disabled".
    """
    global _pool
    if _pool is not None:
        return _pool

    try:
        import asyncpg

        _pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
    except ImportError:
        print("Warning: asyncpg not installed. Mail database disabled.")
        return None
    except Exception as e:
        print(f"Warning: Failed to connect to PostgreSQL: {e}")
        return None
    return _pool
|
||||
|
||||
|
||||
async def init_mail_tables() -> bool:
    """Initialize mail tables in PostgreSQL.

    Creates (idempotently, via IF NOT EXISTS) the five mail tables —
    external_email_accounts, aggregated_emails, inbox_tasks,
    email_templates, mail_audit_log — plus mail_sync_status and their
    indexes. Returns True on success, False when the pool is unavailable
    or the DDL fails.
    """
    pool = await get_pool()
    if pool is None:
        return False

    # One DDL batch; every statement is IF NOT EXISTS so re-running is safe.
    create_tables_sql = """
    -- =============================================================================
    -- External Email Accounts
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS external_email_accounts (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,
        email VARCHAR(255) NOT NULL,
        display_name VARCHAR(255),
        account_type VARCHAR(50) DEFAULT 'personal',

        -- IMAP Settings (password stored in Vault)
        imap_host VARCHAR(255) NOT NULL,
        imap_port INTEGER DEFAULT 993,
        imap_ssl BOOLEAN DEFAULT TRUE,

        -- SMTP Settings
        smtp_host VARCHAR(255) NOT NULL,
        smtp_port INTEGER DEFAULT 465,
        smtp_ssl BOOLEAN DEFAULT TRUE,

        -- Vault path for credentials
        vault_path VARCHAR(500),

        -- Status tracking
        status VARCHAR(20) DEFAULT 'pending',
        last_sync TIMESTAMP,
        sync_error TEXT,
        email_count INTEGER DEFAULT 0,
        unread_count INTEGER DEFAULT 0,

        -- Timestamps
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW(),

        -- Constraints
        UNIQUE(user_id, email)
    );

    CREATE INDEX IF NOT EXISTS idx_mail_accounts_user ON external_email_accounts(user_id);
    CREATE INDEX IF NOT EXISTS idx_mail_accounts_tenant ON external_email_accounts(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_mail_accounts_status ON external_email_accounts(status);

    -- =============================================================================
    -- Aggregated Emails
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS aggregated_emails (
        id VARCHAR(36) PRIMARY KEY,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,

        -- Email identification
        message_id VARCHAR(500) NOT NULL,
        folder VARCHAR(100) DEFAULT 'INBOX',

        -- Email content
        subject TEXT,
        sender_email VARCHAR(255),
        sender_name VARCHAR(255),
        recipients JSONB DEFAULT '[]',
        cc JSONB DEFAULT '[]',
        body_preview TEXT,
        body_text TEXT,
        body_html TEXT,
        has_attachments BOOLEAN DEFAULT FALSE,
        attachments JSONB DEFAULT '[]',
        headers JSONB DEFAULT '{}',

        -- Status flags
        is_read BOOLEAN DEFAULT FALSE,
        is_starred BOOLEAN DEFAULT FALSE,
        is_deleted BOOLEAN DEFAULT FALSE,

        -- Dates
        date_sent TIMESTAMP,
        date_received TIMESTAMP,

        -- AI enrichment
        category VARCHAR(50),
        sender_type VARCHAR(50),
        sender_authority_name VARCHAR(255),
        detected_deadlines JSONB DEFAULT '[]',
        suggested_priority VARCHAR(20),
        ai_summary TEXT,
        ai_analyzed_at TIMESTAMP,

        created_at TIMESTAMP DEFAULT NOW(),

        -- Prevent duplicate imports
        UNIQUE(account_id, message_id)
    );

    CREATE INDEX IF NOT EXISTS idx_emails_account ON aggregated_emails(account_id);
    CREATE INDEX IF NOT EXISTS idx_emails_user ON aggregated_emails(user_id);
    CREATE INDEX IF NOT EXISTS idx_emails_tenant ON aggregated_emails(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_emails_date ON aggregated_emails(date_received DESC);
    CREATE INDEX IF NOT EXISTS idx_emails_category ON aggregated_emails(category);
    CREATE INDEX IF NOT EXISTS idx_emails_unread ON aggregated_emails(is_read) WHERE is_read = FALSE;
    CREATE INDEX IF NOT EXISTS idx_emails_starred ON aggregated_emails(is_starred) WHERE is_starred = TRUE;
    CREATE INDEX IF NOT EXISTS idx_emails_sender ON aggregated_emails(sender_email);

    -- =============================================================================
    -- Inbox Tasks (Arbeitsvorrat)
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS inbox_tasks (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36) NOT NULL,
        email_id VARCHAR(36) REFERENCES aggregated_emails(id) ON DELETE SET NULL,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE SET NULL,

        -- Task content
        title VARCHAR(500) NOT NULL,
        description TEXT,
        priority VARCHAR(20) DEFAULT 'medium',
        status VARCHAR(20) DEFAULT 'pending',
        deadline TIMESTAMP,

        -- Source information
        source_email_subject TEXT,
        source_sender VARCHAR(255),
        source_sender_type VARCHAR(50),

        -- AI extraction info
        ai_extracted BOOLEAN DEFAULT FALSE,
        confidence_score FLOAT,

        -- Completion tracking
        completed_at TIMESTAMP,
        reminder_at TIMESTAMP,

        -- Timestamps
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );

    CREATE INDEX IF NOT EXISTS idx_tasks_user ON inbox_tasks(user_id);
    CREATE INDEX IF NOT EXISTS idx_tasks_tenant ON inbox_tasks(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_tasks_status ON inbox_tasks(status);
    CREATE INDEX IF NOT EXISTS idx_tasks_deadline ON inbox_tasks(deadline) WHERE deadline IS NOT NULL;
    CREATE INDEX IF NOT EXISTS idx_tasks_priority ON inbox_tasks(priority);
    CREATE INDEX IF NOT EXISTS idx_tasks_email ON inbox_tasks(email_id) WHERE email_id IS NOT NULL;

    -- =============================================================================
    -- Email Templates
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS email_templates (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36), -- NULL for system templates
        tenant_id VARCHAR(36),

        name VARCHAR(255) NOT NULL,
        category VARCHAR(100),
        subject_template TEXT,
        body_template TEXT,
        variables JSONB DEFAULT '[]',

        is_system BOOLEAN DEFAULT FALSE,
        usage_count INTEGER DEFAULT 0,

        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW()
    );

    CREATE INDEX IF NOT EXISTS idx_templates_user ON email_templates(user_id);
    CREATE INDEX IF NOT EXISTS idx_templates_tenant ON email_templates(tenant_id);
    CREATE INDEX IF NOT EXISTS idx_templates_system ON email_templates(is_system);

    -- =============================================================================
    -- Mail Audit Log
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS mail_audit_log (
        id VARCHAR(36) PRIMARY KEY,
        user_id VARCHAR(36) NOT NULL,
        tenant_id VARCHAR(36),
        action VARCHAR(100) NOT NULL,
        entity_type VARCHAR(50), -- account, email, task
        entity_id VARCHAR(36),
        details JSONB,
        ip_address VARCHAR(45),
        user_agent TEXT,
        created_at TIMESTAMP DEFAULT NOW()
    );

    CREATE INDEX IF NOT EXISTS idx_mail_audit_user ON mail_audit_log(user_id);
    CREATE INDEX IF NOT EXISTS idx_mail_audit_created ON mail_audit_log(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_mail_audit_action ON mail_audit_log(action);

    -- =============================================================================
    -- Sync Status Tracking
    -- =============================================================================
    CREATE TABLE IF NOT EXISTS mail_sync_status (
        id VARCHAR(36) PRIMARY KEY,
        account_id VARCHAR(36) REFERENCES external_email_accounts(id) ON DELETE CASCADE,
        folder VARCHAR(100),
        last_uid INTEGER DEFAULT 0,
        last_sync TIMESTAMP,
        sync_errors INTEGER DEFAULT 0,
        created_at TIMESTAMP DEFAULT NOW(),
        updated_at TIMESTAMP DEFAULT NOW(),

        UNIQUE(account_id, folder)
    );
    """

    try:
        async with pool.acquire() as conn:
            await conn.execute(create_tables_sql)
            print("Mail tables initialized successfully")
            return True
    except Exception as e:
        print(f"Failed to initialize mail tables: {e}")
        return False
|
||||
118
klausur-service/backend/mail/mail_db_stats.py
Normal file
118
klausur-service/backend/mail/mail_db_stats.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Mail Database - Statistics and Audit Log Operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from .mail_db_pool import get_pool
|
||||
|
||||
|
||||
async def get_mail_stats(user_id: str) -> Dict:
    """Aggregate mail statistics for a user.

    Combines per-account info with email and task counters (totals,
    unread, today's activity, overdue tasks). Returns {} when the
    database is unavailable or a query fails.
    """
    pool = await get_pool()
    if pool is None:
        return {}

    try:
        async with pool.acquire() as conn:
            # Midnight today, used for the "today" counters.
            start_of_day = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)

            # Account stats
            account_rows = await conn.fetch(
                """
                SELECT id, email, display_name, status, email_count, unread_count, last_sync
                FROM external_email_accounts
                WHERE user_id = $1
                """,
                user_id,
            )

            # Email counts
            email_counts = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_emails,
                    COUNT(*) FILTER (WHERE is_read = FALSE) as unread_emails,
                    COUNT(*) FILTER (WHERE date_received >= $2) as emails_today,
                    COUNT(*) FILTER (WHERE ai_analyzed_at >= $2) as ai_analyses_today
                FROM aggregated_emails
                WHERE user_id = $1
                """,
                user_id, start_of_day,
            )

            # Task counts
            task_counts = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_tasks,
                    COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline < NOW()) as overdue_tasks
                FROM inbox_tasks
                WHERE user_id = $1
                """,
                user_id,
            )

            active = sum(1 for acct in account_rows if acct['status'] == 'active')
            errored = sum(1 for acct in account_rows if acct['status'] == 'error')

            return {
                "total_accounts": len(account_rows),
                "active_accounts": active,
                "error_accounts": errored,
                "total_emails": email_counts['total_emails'] or 0,
                "unread_emails": email_counts['unread_emails'] or 0,
                "total_tasks": task_counts['total_tasks'] or 0,
                "pending_tasks": task_counts['pending_tasks'] or 0,
                "overdue_tasks": task_counts['overdue_tasks'] or 0,
                "emails_today": email_counts['emails_today'] or 0,
                "ai_analyses_today": email_counts['ai_analyses_today'] or 0,
                "per_account": [
                    {
                        "id": acct['id'],
                        "email": acct['email'],
                        "display_name": acct['display_name'],
                        "status": acct['status'],
                        "email_count": acct['email_count'],
                        "unread_count": acct['unread_count'],
                        "last_sync": acct['last_sync'].isoformat() if acct['last_sync'] else None,
                    }
                    for acct in account_rows
                ],
            }
    except Exception as e:
        print(f"Failed to get mail stats: {e}")
        return {}
|
||||
|
||||
|
||||
async def log_mail_audit(
    user_id: str,
    action: str,
    entity_type: Optional[str] = None,
    entity_id: Optional[str] = None,
    details: Optional[Dict] = None,
    tenant_id: Optional[str] = None,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None,
) -> bool:
    """Append an entry to the mail audit trail. Returns True on success."""
    pool = await get_pool()
    if pool is None:
        return False

    # Details are stored as JSON text; None stays NULL in the database.
    details_json = json.dumps(details) if details else None
    audit_row = (
        str(uuid.uuid4()), user_id, tenant_id, action, entity_type, entity_id,
        details_json, ip_address, user_agent,
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO mail_audit_log
                (id, user_id, tenant_id, action, entity_type, entity_id, details, ip_address, user_agent)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                """,
                *audit_row,
            )
    except Exception as e:
        print(f"Failed to log mail audit: {e}")
        return False
    return True
|
||||
247
klausur-service/backend/mail/mail_db_tasks.py
Normal file
247
klausur-service/backend/mail/mail_db_tasks.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Mail Database - Inbox Task Operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from .mail_db_pool import get_pool
|
||||
|
||||
|
||||
async def create_task(
    user_id: str,
    tenant_id: str,
    title: str,
    description: Optional[str] = None,
    priority: str = "medium",
    deadline: Optional[datetime] = None,
    email_id: Optional[str] = None,
    account_id: Optional[str] = None,
    source_email_subject: Optional[str] = None,
    source_sender: Optional[str] = None,
    source_sender_type: Optional[str] = None,
    ai_extracted: bool = False,
    confidence_score: Optional[float] = None,
) -> Optional[str]:
    """Create a new inbox task, optionally linked to a source email.

    Returns the generated task ID, or None when the database is
    unavailable or the insert fails.
    """
    pool = await get_pool()
    if pool is None:
        return None

    task_id = str(uuid.uuid4())
    task_row = (
        task_id, user_id, tenant_id, title, description, priority, deadline,
        email_id, account_id, source_email_subject, source_sender,
        source_sender_type, ai_extracted, confidence_score,
    )
    try:
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO inbox_tasks
                (id, user_id, tenant_id, title, description, priority, deadline,
                 email_id, account_id, source_email_subject, source_sender,
                 source_sender_type, ai_extracted, confidence_score)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
                """,
                *task_row,
            )
    except Exception as e:
        print(f"Failed to create task: {e}")
        return None
    return task_id
|
||||
|
||||
|
||||
async def get_tasks(
    user_id: str,
    status: Optional[str] = None,
    priority: Optional[str] = None,
    include_completed: bool = False,
    limit: int = 50,
    offset: int = 0,
) -> List[Dict]:
    """Fetch a page of a user's tasks, most urgent first.

    Optional equality filters on status and priority; completed tasks
    are hidden unless include_completed is True.  Ordering is by
    priority (urgent > high > medium > low), then earliest deadline
    (NULLs last), then newest first.  Returns [] on any failure.
    """
    pool = await get_pool()
    if pool is None:
        return []

    try:
        async with pool.acquire() as conn:
            clauses = ["user_id = $1"]
            args: list = [user_id]

            if not include_completed:
                clauses.append("status != 'completed'")

            # Each optional filter claims the next free $n placeholder.
            for column, value in (("status", status), ("priority", priority)):
                if value:
                    clauses.append(f"{column} = ${len(args) + 1}")
                    args.append(value)

            limit_pos = len(args) + 1
            args.extend([limit, offset])

            sql = f"""
                SELECT * FROM inbox_tasks
                WHERE {" AND ".join(clauses)}
                ORDER BY
                    CASE priority
                        WHEN 'urgent' THEN 1
                        WHEN 'high' THEN 2
                        WHEN 'medium' THEN 3
                        WHEN 'low' THEN 4
                    END,
                    deadline ASC NULLS LAST,
                    created_at DESC
                LIMIT ${limit_pos} OFFSET ${limit_pos + 1}
            """

            rows = await conn.fetch(sql, *args)
            return [dict(r) for r in rows]
    except Exception as e:
        print(f"Failed to get tasks: {e}")
        return []
|
||||
|
||||
|
||||
async def get_task(task_id: str, user_id: str) -> Optional[Dict]:
    """Fetch one task owned by user_id; None if missing or on failure."""
    pool = await get_pool()
    if pool is None:
        return None

    try:
        async with pool.acquire() as conn:
            record = await conn.fetchrow(
                "SELECT * FROM inbox_tasks WHERE id = $1 AND user_id = $2",
                task_id, user_id,
            )
            return dict(record) if record else None
    except Exception as e:
        print(f"Failed to get task: {e}")
        return None
|
||||
|
||||
|
||||
async def update_task(
    task_id: str,
    user_id: str,
    title: Optional[str] = None,
    description: Optional[str] = None,
    priority: Optional[str] = None,
    status: Optional[str] = None,
    deadline: Optional[datetime] = None,
) -> bool:
    """Patch the given fields of a task; None means "leave unchanged".

    updated_at is always refreshed; moving status to "completed" also
    stamps completed_at.  Returns True on success, False when the pool
    is down or the UPDATE fails.
    """
    pool = await get_pool()
    if pool is None:
        return False

    try:
        async with pool.acquire() as conn:
            assignments = ["updated_at = NOW()"]
            args: list = [task_id, user_id]

            def bind(column: str, value) -> None:
                # Append "col = $n" using the next free placeholder index.
                args.append(value)
                assignments.append(f"{column} = ${len(args)}")

            if title is not None:
                bind("title", title)
            if description is not None:
                bind("description", description)
            if priority is not None:
                bind("priority", priority)
            if status is not None:
                bind("status", status)
                if status == "completed":
                    assignments.append("completed_at = NOW()")
            if deadline is not None:
                bind("deadline", deadline)

            await conn.execute(
                f"UPDATE inbox_tasks SET {', '.join(assignments)} WHERE id = $1 AND user_id = $2",
                *args,
            )
            return True
    except Exception as e:
        print(f"Failed to update task: {e}")
        return False
|
||||
|
||||
|
||||
async def get_task_dashboard_stats(user_id: str) -> Dict:
    """Aggregate task counts for the dashboard.

    Returns totals per status, overdue / due-today / due-this-week
    counts, plus open-task breakdowns by priority and by sender type.
    Returns {} when the pool is unavailable or a query fails.

    NOTE(review): uses naive datetime.now() — assumes DB timestamps
    share the server's local timezone; confirm against the schema.
    """
    pool = await get_pool()
    if pool is None:
        return {}

    try:
        async with pool.acquire() as conn:
            now = datetime.now()
            end_of_day = now.replace(hour=23, minute=59, second=59)
            end_of_week = now + timedelta(days=7)

            totals = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) as total_tasks,
                    COUNT(*) FILTER (WHERE status = 'pending') as pending_tasks,
                    COUNT(*) FILTER (WHERE status = 'in_progress') as in_progress_tasks,
                    COUNT(*) FILTER (WHERE status = 'completed') as completed_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline < $2) as overdue_tasks,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $3) as due_today,
                    COUNT(*) FILTER (WHERE status != 'completed' AND deadline <= $4) as due_this_week
                FROM inbox_tasks
                WHERE user_id = $1
                """,
                user_id, now, end_of_day, end_of_week,
            )

            priority_rows = await conn.fetch(
                """
                SELECT priority, COUNT(*) as count
                FROM inbox_tasks
                WHERE user_id = $1 AND status != 'completed'
                GROUP BY priority
                """,
                user_id,
            )

            sender_rows = await conn.fetch(
                """
                SELECT source_sender_type, COUNT(*) as count
                FROM inbox_tasks
                WHERE user_id = $1 AND status != 'completed' AND source_sender_type IS NOT NULL
                GROUP BY source_sender_type
                """,
                user_id,
            )

            counter_names = (
                "total_tasks", "pending_tasks", "in_progress_tasks",
                "completed_tasks", "overdue_tasks", "due_today",
                "due_this_week",
            )
            summary = {name: totals[name] or 0 for name in counter_names}
            summary["by_priority"] = {r['priority']: r['count'] for r in priority_rows}
            summary["by_sender_type"] = {r['source_sender_type']: r['count'] for r in sender_rows}
            return summary
    except Exception as e:
        print(f"Failed to get task stats: {e}")
        return {}
|
||||
272
klausur-service/backend/ocr_merge_helpers.py
Normal file
272
klausur-service/backend/ocr_merge_helpers.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
OCR Merge Helpers — functions for combining PaddleOCR/RapidOCR with Tesseract results.
|
||||
|
||||
Extracted from ocr_pipeline_ocr_merge.py.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _split_paddle_multi_words(words: list) -> list:
|
||||
"""Split PaddleOCR multi-word boxes into individual word boxes.
|
||||
|
||||
PaddleOCR often returns entire phrases as a single box, e.g.
|
||||
"More than 200 singers took part in the" with one bounding box.
|
||||
This splits them into individual words with proportional widths.
|
||||
Also handles leading "!" (e.g. "!Betonung" -> ["!", "Betonung"])
|
||||
and IPA brackets (e.g. "badge[bxd3]" -> ["badge", "[bxd3]"]).
|
||||
"""
|
||||
import re
|
||||
|
||||
result = []
|
||||
for w in words:
|
||||
raw_text = w.get("text", "").strip()
|
||||
if not raw_text:
|
||||
continue
|
||||
# Split on whitespace, before "[" (IPA), and after "!" before letter
|
||||
tokens = re.split(
|
||||
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
|
||||
)
|
||||
tokens = [t for t in tokens if t]
|
||||
|
||||
if len(tokens) <= 1:
|
||||
result.append(w)
|
||||
else:
|
||||
# Split proportionally by character count
|
||||
total_chars = sum(len(t) for t in tokens)
|
||||
if total_chars == 0:
|
||||
continue
|
||||
n_gaps = len(tokens) - 1
|
||||
gap_px = w["width"] * 0.02
|
||||
usable_w = w["width"] - gap_px * n_gaps
|
||||
cursor = w["left"]
|
||||
for t in tokens:
|
||||
token_w = max(1, usable_w * len(t) / total_chars)
|
||||
result.append({
|
||||
"text": t,
|
||||
"left": round(cursor),
|
||||
"top": w["top"],
|
||||
"width": round(token_w),
|
||||
"height": w["height"],
|
||||
"conf": w.get("conf", 0),
|
||||
})
|
||||
cursor += token_w + gap_px
|
||||
return result
|
||||
|
||||
|
||||
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
|
||||
"""Group words into rows by Y-position clustering.
|
||||
|
||||
Words whose vertical centers are within `row_gap` pixels are on the same row.
|
||||
Returns list of rows, each row is a list of words sorted left-to-right.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
# Sort by vertical center
|
||||
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
|
||||
rows: list = []
|
||||
current_row: list = [sorted_words[0]]
|
||||
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
|
||||
|
||||
for w in sorted_words[1:]:
|
||||
cy = w["top"] + w.get("height", 0) / 2
|
||||
if abs(cy - current_cy) <= row_gap:
|
||||
current_row.append(w)
|
||||
else:
|
||||
# Sort current row left-to-right before saving
|
||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||
current_row = [w]
|
||||
current_cy = cy
|
||||
if current_row:
|
||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||
return rows
|
||||
|
||||
|
||||
def _row_center_y(row: list) -> float:
|
||||
"""Average vertical center of a row of words."""
|
||||
if not row:
|
||||
return 0.0
|
||||
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
|
||||
|
||||
|
||||
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
||||
"""Merge two word sequences from the same row using sequence alignment.
|
||||
|
||||
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
||||
- If words match (same/similar text): take Paddle text with averaged coords
|
||||
- If they don't match: the extra word is unique to one engine, include it
|
||||
"""
|
||||
merged = []
|
||||
pi, ti = 0, 0
|
||||
|
||||
while pi < len(paddle_row) and ti < len(tess_row):
|
||||
pw = paddle_row[pi]
|
||||
tw = tess_row[ti]
|
||||
|
||||
pt = pw.get("text", "").lower().strip()
|
||||
tt = tw.get("text", "").lower().strip()
|
||||
|
||||
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
||||
|
||||
# Spatial overlap check
|
||||
spatial_match = False
|
||||
if not is_same:
|
||||
overlap_left = max(pw["left"], tw["left"])
|
||||
overlap_right = min(
|
||||
pw["left"] + pw.get("width", 0),
|
||||
tw["left"] + tw.get("width", 0),
|
||||
)
|
||||
overlap_w = max(0, overlap_right - overlap_left)
|
||||
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
||||
if min_w > 0 and overlap_w / min_w >= 0.4:
|
||||
is_same = True
|
||||
spatial_match = True
|
||||
|
||||
if is_same:
|
||||
pc = pw.get("conf", 80)
|
||||
tc = tw.get("conf", 50)
|
||||
total = pc + tc
|
||||
if total == 0:
|
||||
total = 1
|
||||
if spatial_match and pc < tc:
|
||||
best_text = tw["text"]
|
||||
else:
|
||||
best_text = pw["text"]
|
||||
merged.append({
|
||||
"text": best_text,
|
||||
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
||||
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
||||
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
||||
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
||||
"conf": max(pc, tc),
|
||||
})
|
||||
pi += 1
|
||||
ti += 1
|
||||
else:
|
||||
paddle_ahead = any(
|
||||
tess_row[t].get("text", "").lower().strip() == pt
|
||||
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
||||
)
|
||||
tess_ahead = any(
|
||||
paddle_row[p].get("text", "").lower().strip() == tt
|
||||
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
||||
)
|
||||
|
||||
if paddle_ahead and not tess_ahead:
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
elif tess_ahead and not paddle_ahead:
|
||||
merged.append(pw)
|
||||
pi += 1
|
||||
else:
|
||||
if pw["left"] <= tw["left"]:
|
||||
merged.append(pw)
|
||||
pi += 1
|
||||
else:
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
|
||||
while pi < len(paddle_row):
|
||||
merged.append(paddle_row[pi])
|
||||
pi += 1
|
||||
while ti < len(tess_row):
|
||||
tw = tess_row[ti]
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
|
||||
"""Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment."""
|
||||
if not paddle_words and not tess_words:
|
||||
return []
|
||||
if not paddle_words:
|
||||
return [w for w in tess_words if w.get("conf", 0) >= 40]
|
||||
if not tess_words:
|
||||
return list(paddle_words)
|
||||
|
||||
paddle_rows = _group_words_into_rows(paddle_words)
|
||||
tess_rows = _group_words_into_rows(tess_words)
|
||||
|
||||
used_tess_rows: set = set()
|
||||
merged_all: list = []
|
||||
|
||||
for pr in paddle_rows:
|
||||
pr_cy = _row_center_y(pr)
|
||||
best_dist, best_tri = float("inf"), -1
|
||||
for tri, tr in enumerate(tess_rows):
|
||||
if tri in used_tess_rows:
|
||||
continue
|
||||
tr_cy = _row_center_y(tr)
|
||||
dist = abs(pr_cy - tr_cy)
|
||||
if dist < best_dist:
|
||||
best_dist, best_tri = dist, tri
|
||||
|
||||
max_row_dist = max(
|
||||
max((w.get("height", 20) for w in pr), default=20),
|
||||
15,
|
||||
)
|
||||
|
||||
if best_tri >= 0 and best_dist <= max_row_dist:
|
||||
tr = tess_rows[best_tri]
|
||||
used_tess_rows.add(best_tri)
|
||||
merged_all.extend(_merge_row_sequences(pr, tr))
|
||||
else:
|
||||
merged_all.extend(pr)
|
||||
|
||||
for tri, tr in enumerate(tess_rows):
|
||||
if tri not in used_tess_rows:
|
||||
for tw in tr:
|
||||
if tw.get("conf", 0) >= 40:
|
||||
merged_all.append(tw)
|
||||
|
||||
return merged_all
|
||||
|
||||
|
||||
def _deduplicate_words(words: list) -> list:
|
||||
"""Remove duplicate words with same text at overlapping positions."""
|
||||
if not words:
|
||||
return words
|
||||
|
||||
result: list = []
|
||||
for w in words:
|
||||
wt = w.get("text", "").lower().strip()
|
||||
if not wt:
|
||||
continue
|
||||
is_dup = False
|
||||
w_right = w["left"] + w.get("width", 0)
|
||||
w_bottom = w["top"] + w.get("height", 0)
|
||||
for existing in result:
|
||||
et = existing.get("text", "").lower().strip()
|
||||
if wt != et:
|
||||
continue
|
||||
ox_l = max(w["left"], existing["left"])
|
||||
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
|
||||
ox = max(0, ox_r - ox_l)
|
||||
min_w = min(w.get("width", 1), existing.get("width", 1))
|
||||
if min_w <= 0 or ox / min_w < 0.5:
|
||||
continue
|
||||
oy_t = max(w["top"], existing["top"])
|
||||
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
|
||||
oy = max(0, oy_b - oy_t)
|
||||
min_h = min(w.get("height", 1), existing.get("height", 1))
|
||||
if min_h > 0 and oy / min_h >= 0.5:
|
||||
is_dup = True
|
||||
break
|
||||
if not is_dup:
|
||||
result.append(w)
|
||||
|
||||
removed = len(words) - len(result)
|
||||
if removed:
|
||||
logger.info("dedup: removed %d duplicate words", removed)
|
||||
return result
|
||||
209
klausur-service/backend/ocr_pipeline_llm_review.py
Normal file
209
klausur-service/backend/ocr_pipeline_llm_review.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
OCR Pipeline LLM Review — LLM-based correction endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_postprocess.py.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
llm_review_entries,
|
||||
llm_review_entries_streaming,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_append_pipeline_log,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 8: LLM Review
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/llm-review")
|
||||
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
|
||||
"""Run LLM-based correction on vocab entries from Step 5.
|
||||
|
||||
Query params:
|
||||
stream: false (default) for JSON response, true for SSE streaming
|
||||
"""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
word_result = session.get("word_result")
|
||||
if not word_result:
|
||||
raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")
|
||||
|
||||
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||
if not entries:
|
||||
raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")
|
||||
|
||||
# Optional model override from request body
|
||||
body = {}
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
pass
|
||||
model = body.get("model") or OLLAMA_REVIEW_MODEL
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
_llm_review_stream_generator(session_id, entries, word_result, model, request),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
# Non-streaming path
|
||||
try:
|
||||
result = await llm_review_entries(entries, model=model)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
|
||||
raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")
|
||||
|
||||
# Store result inside word_result as a sub-key
|
||||
word_result["llm_review"] = {
|
||||
"changes": result["changes"],
|
||||
"model_used": result["model_used"],
|
||||
"duration_ms": result["duration_ms"],
|
||||
"entries_corrected": result["entries_corrected"],
|
||||
}
|
||||
await update_session_db(session_id, word_result=word_result, current_step=9)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["word_result"] = word_result
|
||||
|
||||
logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
|
||||
f"{result['duration_ms']}ms, model={result['model_used']}")
|
||||
|
||||
await _append_pipeline_log(session_id, "correction", {
|
||||
"engine": "llm",
|
||||
"model": result["model_used"],
|
||||
"total_entries": len(entries),
|
||||
"corrections_proposed": len(result["changes"]),
|
||||
}, duration_ms=result["duration_ms"])
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"changes": result["changes"],
|
||||
"model_used": result["model_used"],
|
||||
"duration_ms": result["duration_ms"],
|
||||
"total_entries": len(entries),
|
||||
"corrections_found": len(result["changes"]),
|
||||
}
|
||||
|
||||
|
||||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator yielding batch-by-batch LLM review progress.

    Returns silently when the client disconnects.  When the final
    "complete" event arrives, the review result is persisted to the
    session store (and the in-process cache).  Failures surface as a
    terminal "error" event instead of raising.
    """
    try:
        async for chunk in llm_review_entries_streaming(entries, model=model):
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

            if chunk.get("type") == "complete":
                # Persist the finished review alongside the word result.
                word_result["llm_review"] = {
                    "changes": chunk["changes"],
                    "model_used": chunk["model_used"],
                    "duration_ms": chunk["duration_ms"],
                    "entries_corrected": chunk["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=9)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result

                logger.info(f"LLM review SSE session {session_id}: {chunk['corrections_found']} changes, "
                            f"{chunk['duration_ms']}ms, skipped={chunk['skipped']}, model={chunk['model_used']}")

    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        failure = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(failure)}\n\n"
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/llm-review/apply")
|
||||
async def apply_llm_corrections(session_id: str, request: Request):
|
||||
"""Apply selected LLM corrections to vocab entries."""
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
word_result = session.get("word_result")
|
||||
if not word_result:
|
||||
raise HTTPException(status_code=400, detail="No word result found")
|
||||
|
||||
llm_review = word_result.get("llm_review")
|
||||
if not llm_review:
|
||||
raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")
|
||||
|
||||
body = await request.json()
|
||||
accepted_indices = set(body.get("accepted_indices", [])) # indices into changes[]
|
||||
|
||||
changes = llm_review.get("changes", [])
|
||||
entries = word_result.get("vocab_entries") or word_result.get("entries") or []
|
||||
|
||||
# Build a lookup: (row_index, field) -> new_value for accepted changes
|
||||
corrections = {}
|
||||
applied_count = 0
|
||||
for idx, change in enumerate(changes):
|
||||
if idx in accepted_indices:
|
||||
key = (change["row_index"], change["field"])
|
||||
corrections[key] = change["new"]
|
||||
applied_count += 1
|
||||
|
||||
# Apply corrections to entries
|
||||
for entry in entries:
|
||||
row_idx = entry.get("row_index", -1)
|
||||
for field_name in ("english", "german", "example"):
|
||||
key = (row_idx, field_name)
|
||||
if key in corrections:
|
||||
entry[field_name] = corrections[key]
|
||||
entry["llm_corrected"] = True
|
||||
|
||||
# Update word_result
|
||||
word_result["vocab_entries"] = entries
|
||||
word_result["entries"] = entries
|
||||
word_result["llm_review"]["applied_count"] = applied_count
|
||||
word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()
|
||||
|
||||
await update_session_db(session_id, word_result=word_result)
|
||||
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["word_result"] = word_result
|
||||
|
||||
logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"applied_count": applied_count,
|
||||
"total_changes": len(changes),
|
||||
}
|
||||
@@ -1,10 +1,8 @@
|
||||
"""
|
||||
OCR Merge Helpers and Kombi Endpoints.
|
||||
OCR Merge Kombi Endpoints — paddle-kombi and rapid-kombi endpoints.
|
||||
|
||||
Contains merge helper functions for combining PaddleOCR/RapidOCR with Tesseract
|
||||
results, plus the paddle-kombi and rapid-kombi endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_api.py for modularity.
|
||||
Merge helper functions live in ocr_merge_helpers.py.
|
||||
This module re-exports them for backward compatibility.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
@@ -12,10 +10,8 @@ DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import cv2
|
||||
import httpx
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
@@ -23,356 +19,23 @@ from cv_words_first import build_grid_from_words
|
||||
from ocr_pipeline_common import _cache, _append_pipeline_log
|
||||
from ocr_pipeline_session_store import get_session_image, update_session_db
|
||||
|
||||
# Re-export merge helpers for backward compatibility
|
||||
from ocr_merge_helpers import ( # noqa: F401
|
||||
_split_paddle_multi_words,
|
||||
_group_words_into_rows,
|
||||
_row_center_y,
|
||||
_merge_row_sequences,
|
||||
_merge_paddle_tesseract,
|
||||
_deduplicate_words,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Merge helper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _split_paddle_multi_words(words: list) -> list:
|
||||
"""Split PaddleOCR multi-word boxes into individual word boxes.
|
||||
|
||||
PaddleOCR often returns entire phrases as a single box, e.g.
|
||||
"More than 200 singers took part in the" with one bounding box.
|
||||
This splits them into individual words with proportional widths.
|
||||
Also handles leading "!" (e.g. "!Betonung" → ["!", "Betonung"])
|
||||
and IPA brackets (e.g. "badge[bxd3]" → ["badge", "[bxd3]"]).
|
||||
"""
|
||||
import re
|
||||
|
||||
result = []
|
||||
for w in words:
|
||||
raw_text = w.get("text", "").strip()
|
||||
if not raw_text:
|
||||
continue
|
||||
# Split on whitespace, before "[" (IPA), and after "!" before letter
|
||||
tokens = re.split(
|
||||
r'\s+|(?=\[)|(?<=!)(?=[A-Za-z\u00c0-\u024f])', raw_text
|
||||
)
|
||||
tokens = [t for t in tokens if t]
|
||||
|
||||
if len(tokens) <= 1:
|
||||
result.append(w)
|
||||
else:
|
||||
# Split proportionally by character count
|
||||
total_chars = sum(len(t) for t in tokens)
|
||||
if total_chars == 0:
|
||||
continue
|
||||
n_gaps = len(tokens) - 1
|
||||
gap_px = w["width"] * 0.02
|
||||
usable_w = w["width"] - gap_px * n_gaps
|
||||
cursor = w["left"]
|
||||
for t in tokens:
|
||||
token_w = max(1, usable_w * len(t) / total_chars)
|
||||
result.append({
|
||||
"text": t,
|
||||
"left": round(cursor),
|
||||
"top": w["top"],
|
||||
"width": round(token_w),
|
||||
"height": w["height"],
|
||||
"conf": w.get("conf", 0),
|
||||
})
|
||||
cursor += token_w + gap_px
|
||||
return result
|
||||
|
||||
|
||||
def _group_words_into_rows(words: list, row_gap: int = 12) -> list:
|
||||
"""Group words into rows by Y-position clustering.
|
||||
|
||||
Words whose vertical centers are within `row_gap` pixels are on the same row.
|
||||
Returns list of rows, each row is a list of words sorted left-to-right.
|
||||
"""
|
||||
if not words:
|
||||
return []
|
||||
# Sort by vertical center
|
||||
sorted_words = sorted(words, key=lambda w: w["top"] + w.get("height", 0) / 2)
|
||||
rows: list = []
|
||||
current_row: list = [sorted_words[0]]
|
||||
current_cy = sorted_words[0]["top"] + sorted_words[0].get("height", 0) / 2
|
||||
|
||||
for w in sorted_words[1:]:
|
||||
cy = w["top"] + w.get("height", 0) / 2
|
||||
if abs(cy - current_cy) <= row_gap:
|
||||
current_row.append(w)
|
||||
else:
|
||||
# Sort current row left-to-right before saving
|
||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||
current_row = [w]
|
||||
current_cy = cy
|
||||
if current_row:
|
||||
rows.append(sorted(current_row, key=lambda w: w["left"]))
|
||||
return rows
|
||||
|
||||
|
||||
def _row_center_y(row: list) -> float:
|
||||
"""Average vertical center of a row of words."""
|
||||
if not row:
|
||||
return 0.0
|
||||
return sum(w["top"] + w.get("height", 0) / 2 for w in row) / len(row)
|
||||
|
||||
|
||||
def _merge_row_sequences(paddle_row: list, tess_row: list) -> list:
|
||||
"""Merge two word sequences from the same row using sequence alignment.
|
||||
|
||||
Both sequences are sorted left-to-right. Walk through both simultaneously:
|
||||
- If words match (same/similar text): take Paddle text with averaged coords
|
||||
- If they don't match: the extra word is unique to one engine, include it
|
||||
|
||||
This prevents duplicates because both engines produce words in the same order.
|
||||
"""
|
||||
merged = []
|
||||
pi, ti = 0, 0
|
||||
|
||||
while pi < len(paddle_row) and ti < len(tess_row):
|
||||
pw = paddle_row[pi]
|
||||
tw = tess_row[ti]
|
||||
|
||||
# Check if these are the same word
|
||||
pt = pw.get("text", "").lower().strip()
|
||||
tt = tw.get("text", "").lower().strip()
|
||||
|
||||
# Same text or one contains the other
|
||||
is_same = (pt == tt) or (len(pt) > 1 and len(tt) > 1 and (pt in tt or tt in pt))
|
||||
|
||||
# Spatial overlap check: if words overlap >= 40% horizontally,
|
||||
# they're the same physical word regardless of OCR text differences.
|
||||
# (40% catches borderline cases like "Stick"/"Stück" at 48% overlap)
|
||||
spatial_match = False
|
||||
if not is_same:
|
||||
overlap_left = max(pw["left"], tw["left"])
|
||||
overlap_right = min(
|
||||
pw["left"] + pw.get("width", 0),
|
||||
tw["left"] + tw.get("width", 0),
|
||||
)
|
||||
overlap_w = max(0, overlap_right - overlap_left)
|
||||
min_w = min(pw.get("width", 1), tw.get("width", 1))
|
||||
if min_w > 0 and overlap_w / min_w >= 0.4:
|
||||
is_same = True
|
||||
spatial_match = True
|
||||
|
||||
if is_same:
|
||||
# Matched — average coordinates weighted by confidence
|
||||
pc = pw.get("conf", 80)
|
||||
tc = tw.get("conf", 50)
|
||||
total = pc + tc
|
||||
if total == 0:
|
||||
total = 1
|
||||
# Text: prefer higher-confidence engine when texts differ
|
||||
# (e.g. Tesseract "Stück" conf=98 vs PaddleOCR "Stick" conf=80)
|
||||
if spatial_match and pc < tc:
|
||||
best_text = tw["text"]
|
||||
else:
|
||||
best_text = pw["text"]
|
||||
merged.append({
|
||||
"text": best_text,
|
||||
"left": round((pw["left"] * pc + tw["left"] * tc) / total),
|
||||
"top": round((pw["top"] * pc + tw["top"] * tc) / total),
|
||||
"width": round((pw["width"] * pc + tw["width"] * tc) / total),
|
||||
"height": round((pw["height"] * pc + tw["height"] * tc) / total),
|
||||
"conf": max(pc, tc),
|
||||
})
|
||||
pi += 1
|
||||
ti += 1
|
||||
else:
|
||||
# Different text — one engine found something extra
|
||||
# Look ahead: is the current Paddle word somewhere in Tesseract ahead?
|
||||
paddle_ahead = any(
|
||||
tess_row[t].get("text", "").lower().strip() == pt
|
||||
for t in range(ti + 1, min(ti + 4, len(tess_row)))
|
||||
)
|
||||
# Is the current Tesseract word somewhere in Paddle ahead?
|
||||
tess_ahead = any(
|
||||
paddle_row[p].get("text", "").lower().strip() == tt
|
||||
for p in range(pi + 1, min(pi + 4, len(paddle_row)))
|
||||
)
|
||||
|
||||
if paddle_ahead and not tess_ahead:
|
||||
# Tesseract has an extra word (e.g. "!" or bullet) → include it
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
elif tess_ahead and not paddle_ahead:
|
||||
# Paddle has an extra word → include it
|
||||
merged.append(pw)
|
||||
pi += 1
|
||||
else:
|
||||
# Both have unique words or neither found ahead → take leftmost first
|
||||
if pw["left"] <= tw["left"]:
|
||||
merged.append(pw)
|
||||
pi += 1
|
||||
else:
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
|
||||
# Remaining words from either engine
|
||||
while pi < len(paddle_row):
|
||||
merged.append(paddle_row[pi])
|
||||
pi += 1
|
||||
while ti < len(tess_row):
|
||||
tw = tess_row[ti]
|
||||
if tw.get("conf", 0) >= 30:
|
||||
merged.append(tw)
|
||||
ti += 1
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _merge_paddle_tesseract(paddle_words: list, tess_words: list) -> list:
    """Merge word boxes from PaddleOCR and Tesseract using row-based sequence alignment.

    Each word is a dict with at least "text", "left", "top", "width",
    "height" and "conf" keys — TODO confirm against the OCR engine helpers.

    Strategy:
    1. Group each engine's words into rows (by Y-position clustering)
    2. Match rows between engines (by vertical center proximity)
    3. Within each matched row: merge sequences left-to-right, deduplicating
       words that appear in both engines at the same sequence position
    4. Unmatched rows from either engine: keep as-is

    This prevents:
    - Cross-line averaging (words from different lines being merged)
    - Duplicate words (same word from both engines shown twice)
    """
    if not paddle_words and not tess_words:
        return []
    if not paddle_words:
        # Tesseract-only fallback: keep only reasonably confident words.
        return [w for w in tess_words if w.get("conf", 0) >= 40]
    if not tess_words:
        # Paddle-only fallback: shallow copy of the Paddle word list.
        return list(paddle_words)

    # Step 1: Group into rows
    paddle_rows = _group_words_into_rows(paddle_words)
    tess_rows = _group_words_into_rows(tess_words)

    # Step 2: Match rows between engines by vertical center proximity.
    # Greedy 1:1 matching — each Tesseract row is consumed at most once.
    used_tess_rows: set = set()
    merged_all: list = []

    for pr in paddle_rows:
        pr_cy = _row_center_y(pr)
        # Find the nearest still-unused Tesseract row by vertical center.
        best_dist, best_tri = float("inf"), -1
        for tri, tr in enumerate(tess_rows):
            if tri in used_tess_rows:
                continue
            tr_cy = _row_center_y(tr)
            dist = abs(pr_cy - tr_cy)
            if dist < best_dist:
                best_dist, best_tri = dist, tri

        # Row distance threshold: the tallest word in this Paddle row
        # (default 20px when heights are missing), floored at 15px.
        max_row_dist = max(
            max((w.get("height", 20) for w in pr), default=20),
            15,
        )

        if best_tri >= 0 and best_dist <= max_row_dist:
            # Matched row — merge the two word sequences left-to-right.
            tr = tess_rows[best_tri]
            used_tess_rows.add(best_tri)
            merged_all.extend(_merge_row_sequences(pr, tr))
        else:
            # No matching Tesseract row — keep Paddle row as-is
            merged_all.extend(pr)

    # Add unmatched Tesseract rows (confidence-filtered, same floor as above).
    for tri, tr in enumerate(tess_rows):
        if tri not in used_tess_rows:
            for tw in tr:
                if tw.get("conf", 0) >= 40:
                    merged_all.append(tw)

    return merged_all
|
||||
|
||||
|
||||
def _deduplicate_words(words: list) -> list:
|
||||
"""Remove duplicate words with same text at overlapping positions.
|
||||
|
||||
PaddleOCR can return overlapping phrases (e.g. "von jm." and "jm. =")
|
||||
that produce duplicate words after splitting. This pass removes them.
|
||||
|
||||
A word is a duplicate only when BOTH horizontal AND vertical overlap
|
||||
exceed 50% — same text on the same visual line at the same position.
|
||||
"""
|
||||
if not words:
|
||||
return words
|
||||
|
||||
result: list = []
|
||||
for w in words:
|
||||
wt = w.get("text", "").lower().strip()
|
||||
if not wt:
|
||||
continue
|
||||
is_dup = False
|
||||
w_right = w["left"] + w.get("width", 0)
|
||||
w_bottom = w["top"] + w.get("height", 0)
|
||||
for existing in result:
|
||||
et = existing.get("text", "").lower().strip()
|
||||
if wt != et:
|
||||
continue
|
||||
# Horizontal overlap
|
||||
ox_l = max(w["left"], existing["left"])
|
||||
ox_r = min(w_right, existing["left"] + existing.get("width", 0))
|
||||
ox = max(0, ox_r - ox_l)
|
||||
min_w = min(w.get("width", 1), existing.get("width", 1))
|
||||
if min_w <= 0 or ox / min_w < 0.5:
|
||||
continue
|
||||
# Vertical overlap — must also be on the same line
|
||||
oy_t = max(w["top"], existing["top"])
|
||||
oy_b = min(w_bottom, existing["top"] + existing.get("height", 0))
|
||||
oy = max(0, oy_b - oy_t)
|
||||
min_h = min(w.get("height", 1), existing.get("height", 1))
|
||||
if min_h > 0 and oy / min_h >= 0.5:
|
||||
is_dup = True
|
||||
break
|
||||
if not is_dup:
|
||||
result.append(w)
|
||||
|
||||
removed = len(words) - len(result)
|
||||
if removed:
|
||||
logger.info("dedup: removed %d duplicate words", removed)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Kombi endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/paddle-kombi")
|
||||
async def paddle_kombi(session_id: str):
|
||||
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results.
|
||||
|
||||
Both engines run on the same preprocessed (cropped/dewarped) image.
|
||||
Word boxes are matched by IoU and coordinates are averaged weighted by
|
||||
confidence. Unmatched Tesseract words (bullets, symbols) are added.
|
||||
"""
|
||||
img_png = await get_session_image(session_id, "cropped")
|
||||
if not img_png:
|
||||
img_png = await get_session_image(session_id, "dewarped")
|
||||
if not img_png:
|
||||
img_png = await get_session_image(session_id, "original")
|
||||
if not img_png:
|
||||
raise HTTPException(status_code=404, detail="No image found for this session")
|
||||
|
||||
img_arr = np.frombuffer(img_png, dtype=np.uint8)
|
||||
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Failed to decode image")
|
||||
|
||||
img_h, img_w = img_bgr.shape[:2]
|
||||
|
||||
from cv_ocr_engines import ocr_region_paddle
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# --- PaddleOCR ---
|
||||
paddle_words = await ocr_region_paddle(img_bgr, region=None)
|
||||
if not paddle_words:
|
||||
paddle_words = []
|
||||
|
||||
# --- Tesseract ---
|
||||
def _run_tesseract_words(img_bgr) -> list:
|
||||
"""Run Tesseract OCR on an image and return word dicts."""
|
||||
from PIL import Image
|
||||
import pytesseract
|
||||
|
||||
@@ -397,15 +60,98 @@ async def paddle_kombi(session_id: str):
|
||||
"height": data["height"][i],
|
||||
"conf": conf,
|
||||
})
|
||||
return tess_words
|
||||
|
||||
|
||||
def _build_kombi_word_result(
|
||||
cells: list,
|
||||
columns_meta: list,
|
||||
img_w: int,
|
||||
img_h: int,
|
||||
duration: float,
|
||||
engine_name: str,
|
||||
raw_engine_words: list,
|
||||
raw_engine_words_split: list,
|
||||
tess_words: list,
|
||||
merged_words: list,
|
||||
raw_engine_key: str = "raw_paddle_words",
|
||||
raw_split_key: str = "raw_paddle_words_split",
|
||||
) -> dict:
|
||||
"""Build the word_result dict for kombi endpoints."""
|
||||
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||
n_cols = len(columns_meta)
|
||||
col_types = {c.get("type") for c in columns_meta}
|
||||
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||
|
||||
return {
|
||||
"cells": cells,
|
||||
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": engine_name,
|
||||
"grid_method": engine_name,
|
||||
raw_engine_key: raw_engine_words,
|
||||
raw_split_key: raw_engine_words_split,
|
||||
"raw_tesseract_words": tess_words,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
raw_engine_key.replace("raw_", "").replace("_words", "_words"): len(raw_engine_words),
|
||||
raw_split_key.replace("raw_", "").replace("_words_split", "_words_split"): len(raw_engine_words_split),
|
||||
"tesseract_words": len(tess_words),
|
||||
"merged_words": len(merged_words),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def _load_session_image(session_id: str):
    """Fetch the best available session image and decode it to BGR.

    Preference order: cropped, then dewarped, then the original upload.
    Returns a (png_bytes, bgr_array) pair.  Raises HTTP 404 when the
    session has no stored image and HTTP 400 when decoding fails.
    """
    png_bytes = None
    for variant in ("cropped", "dewarped", "original"):
        png_bytes = await get_session_image(session_id, variant)
        if png_bytes:
            break
    if not png_bytes:
        raise HTTPException(status_code=404, detail="No image found for this session")

    decoded = cv2.imdecode(np.frombuffer(png_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")

    return png_bytes, decoded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Kombi endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/paddle-kombi")
|
||||
async def paddle_kombi(session_id: str):
|
||||
"""Run PaddleOCR + Tesseract on the preprocessed image and merge results."""
|
||||
img_png, img_bgr = await _load_session_image(session_id)
|
||||
img_h, img_w = img_bgr.shape[:2]
|
||||
|
||||
from cv_ocr_engines import ocr_region_paddle
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
paddle_words = await ocr_region_paddle(img_bgr, region=None)
|
||||
if not paddle_words:
|
||||
paddle_words = []
|
||||
|
||||
tess_words = _run_tesseract_words(img_bgr)
|
||||
|
||||
# --- Split multi-word Paddle boxes into individual words ---
|
||||
paddle_words_split = _split_paddle_multi_words(paddle_words)
|
||||
logger.info(
|
||||
"paddle_kombi: split %d paddle boxes → %d individual words",
|
||||
"paddle_kombi: split %d paddle boxes -> %d individual words",
|
||||
len(paddle_words), len(paddle_words_split),
|
||||
)
|
||||
|
||||
# --- Merge ---
|
||||
if not paddle_words_split and not tess_words:
|
||||
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
|
||||
|
||||
@@ -418,49 +164,23 @@ async def paddle_kombi(session_id: str):
|
||||
for cell in cells:
|
||||
cell["ocr_engine"] = "kombi"
|
||||
|
||||
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||
n_cols = len(columns_meta)
|
||||
col_types = {c.get("type") for c in columns_meta}
|
||||
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||
|
||||
word_result = {
|
||||
"cells": cells,
|
||||
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": "kombi",
|
||||
"grid_method": "kombi",
|
||||
"raw_paddle_words": paddle_words,
|
||||
"raw_paddle_words_split": paddle_words_split,
|
||||
"raw_tesseract_words": tess_words,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
"paddle_words": len(paddle_words),
|
||||
"paddle_words_split": len(paddle_words_split),
|
||||
"tesseract_words": len(tess_words),
|
||||
"merged_words": len(merged_words),
|
||||
},
|
||||
}
|
||||
word_result = _build_kombi_word_result(
|
||||
cells, columns_meta, img_w, img_h, duration, "kombi",
|
||||
paddle_words, paddle_words_split, tess_words, merged_words,
|
||||
"raw_paddle_words", "raw_paddle_words_split",
|
||||
)
|
||||
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
cropped_png=img_png,
|
||||
current_step=8,
|
||||
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
|
||||
)
|
||||
# Update in-memory cache so detect-structure can access word_result
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["word_result"] = word_result
|
||||
|
||||
logger.info(
|
||||
"paddle_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
|
||||
"[paddle=%d, tess=%d, merged=%d]",
|
||||
session_id, len(cells), n_rows, n_cols, duration,
|
||||
session_id, len(cells), word_result["grid_shape"]["rows"],
|
||||
word_result["grid_shape"]["cols"], duration,
|
||||
len(paddle_words), len(tess_words), len(merged_words),
|
||||
)
|
||||
|
||||
@@ -478,24 +198,8 @@ async def paddle_kombi(session_id: str):
|
||||
|
||||
@router.post("/sessions/{session_id}/rapid-kombi")
|
||||
async def rapid_kombi(session_id: str):
|
||||
"""Run RapidOCR + Tesseract on the preprocessed image and merge results.
|
||||
|
||||
Same merge logic as paddle-kombi, but uses local RapidOCR (ONNX Runtime)
|
||||
instead of remote PaddleOCR service.
|
||||
"""
|
||||
img_png = await get_session_image(session_id, "cropped")
|
||||
if not img_png:
|
||||
img_png = await get_session_image(session_id, "dewarped")
|
||||
if not img_png:
|
||||
img_png = await get_session_image(session_id, "original")
|
||||
if not img_png:
|
||||
raise HTTPException(status_code=404, detail="No image found for this session")
|
||||
|
||||
img_arr = np.frombuffer(img_png, dtype=np.uint8)
|
||||
img_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
|
||||
if img_bgr is None:
|
||||
raise HTTPException(status_code=400, detail="Failed to decode image")
|
||||
|
||||
"""Run RapidOCR + Tesseract on the preprocessed image and merge results."""
|
||||
img_png, img_bgr = await _load_session_image(session_id)
|
||||
img_h, img_w = img_bgr.shape[:2]
|
||||
|
||||
from cv_ocr_engines import ocr_region_rapid
|
||||
@@ -503,7 +207,6 @@ async def rapid_kombi(session_id: str):
|
||||
|
||||
t0 = time.time()
|
||||
|
||||
# --- RapidOCR (local, synchronous) ---
|
||||
full_region = PageRegion(
|
||||
type="full_page", x=0, y=0, width=img_w, height=img_h,
|
||||
)
|
||||
@@ -511,40 +214,14 @@ async def rapid_kombi(session_id: str):
|
||||
if not rapid_words:
|
||||
rapid_words = []
|
||||
|
||||
# --- Tesseract ---
|
||||
from PIL import Image
|
||||
import pytesseract
|
||||
tess_words = _run_tesseract_words(img_bgr)
|
||||
|
||||
pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
|
||||
data = pytesseract.image_to_data(
|
||||
pil_img, lang="eng+deu",
|
||||
config="--psm 6 --oem 3",
|
||||
output_type=pytesseract.Output.DICT,
|
||||
)
|
||||
tess_words = []
|
||||
for i in range(len(data["text"])):
|
||||
text = str(data["text"][i]).strip()
|
||||
conf_raw = str(data["conf"][i])
|
||||
conf = int(conf_raw) if conf_raw.lstrip("-").isdigit() else -1
|
||||
if not text or conf < 20:
|
||||
continue
|
||||
tess_words.append({
|
||||
"text": text,
|
||||
"left": data["left"][i],
|
||||
"top": data["top"][i],
|
||||
"width": data["width"][i],
|
||||
"height": data["height"][i],
|
||||
"conf": conf,
|
||||
})
|
||||
|
||||
# --- Split multi-word RapidOCR boxes into individual words ---
|
||||
rapid_words_split = _split_paddle_multi_words(rapid_words)
|
||||
logger.info(
|
||||
"rapid_kombi: split %d rapid boxes → %d individual words",
|
||||
"rapid_kombi: split %d rapid boxes -> %d individual words",
|
||||
len(rapid_words), len(rapid_words_split),
|
||||
)
|
||||
|
||||
# --- Merge ---
|
||||
if not rapid_words_split and not tess_words:
|
||||
raise HTTPException(status_code=400, detail="Both OCR engines returned no words")
|
||||
|
||||
@@ -557,49 +234,23 @@ async def rapid_kombi(session_id: str):
|
||||
for cell in cells:
|
||||
cell["ocr_engine"] = "rapid_kombi"
|
||||
|
||||
n_rows = len(set(c["row_index"] for c in cells)) if cells else 0
|
||||
n_cols = len(columns_meta)
|
||||
col_types = {c.get("type") for c in columns_meta}
|
||||
is_vocab = bool(col_types & {"column_en", "column_de"})
|
||||
|
||||
word_result = {
|
||||
"cells": cells,
|
||||
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": "rapid_kombi",
|
||||
"grid_method": "rapid_kombi",
|
||||
"raw_rapid_words": rapid_words,
|
||||
"raw_rapid_words_split": rapid_words_split,
|
||||
"raw_tesseract_words": tess_words,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
"rapid_words": len(rapid_words),
|
||||
"rapid_words_split": len(rapid_words_split),
|
||||
"tesseract_words": len(tess_words),
|
||||
"merged_words": len(merged_words),
|
||||
},
|
||||
}
|
||||
word_result = _build_kombi_word_result(
|
||||
cells, columns_meta, img_w, img_h, duration, "rapid_kombi",
|
||||
rapid_words, rapid_words_split, tess_words, merged_words,
|
||||
"raw_rapid_words", "raw_rapid_words_split",
|
||||
)
|
||||
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
cropped_png=img_png,
|
||||
current_step=8,
|
||||
session_id, word_result=word_result, cropped_png=img_png, current_step=8,
|
||||
)
|
||||
# Update in-memory cache so detect-structure can access word_result
|
||||
if session_id in _cache:
|
||||
_cache[session_id]["word_result"] = word_result
|
||||
|
||||
logger.info(
|
||||
"rapid_kombi session %s: %d cells (%d rows, %d cols) in %.2fs "
|
||||
"[rapid=%d, tess=%d, merged=%d]",
|
||||
session_id, len(cells), n_rows, n_cols, duration,
|
||||
session_id, len(cells), word_result["grid_shape"]["rows"],
|
||||
word_result["grid_shape"]["cols"], duration,
|
||||
len(rapid_words), len(tess_words), len(merged_words),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,929 +1,26 @@
|
||||
"""
|
||||
OCR Pipeline Postprocessing API — LLM review, reconstruction, export, validation,
|
||||
image detection/generation, and handwriting removal endpoints.
|
||||
OCR Pipeline Postprocessing API — composite router assembling LLM review,
|
||||
reconstruction, export, validation, image detection/generation, and
|
||||
handwriting removal endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_api.py to keep the main module manageable.
|
||||
Split into sub-modules:
|
||||
ocr_pipeline_llm_review — LLM review + apply corrections
|
||||
ocr_pipeline_reconstruction — reconstruction save, Fabric JSON, merged entries, PDF/DOCX
|
||||
ocr_pipeline_validation — image detection, generation, validation, handwriting removal
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
OLLAMA_REVIEW_MODEL,
|
||||
llm_review_entries,
|
||||
llm_review_entries_streaming,
|
||||
)
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
get_session_image,
|
||||
get_sub_sessions,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
_get_base_image_png,
|
||||
_append_pipeline_log,
|
||||
RemoveHandwritingRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Prompt suffixes appended when generating replacement images, keyed by the
# style name — presumably selected via GenerateImageRequest.style (the default
# there is "educational"); verify against the generation endpoint.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
|
||||
|
||||
|
||||
class ValidationRequest(BaseModel):
    """Request body for marking a session as validated."""
    # Optional free-text reviewer notes.
    notes: Optional[str] = None
    # Optional quality score — range/scale not enforced here; TODO confirm.
    score: Optional[int] = None
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
    """Request body for generating a replacement image for a detected region."""
    # Index of the target image region — presumably into the session's
    # detected regions list; verify against the generation endpoint.
    region_index: int
    # Free-text description of the desired image.
    prompt: str
    # Style name; see STYLE_SUFFIXES for the known values.
    style: str = "educational"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 8: LLM Review
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/llm-review")
async def run_llm_review(session_id: str, request: Request, stream: bool = False):
    """Run LLM-based correction on vocab entries from Step 5.

    Query params:
        stream: false (default) for JSON response, true for SSE streaming

    Raises:
        HTTPException 404 if the session does not exist, 400 if Step 5 has
        not produced entries yet, 502 if the (non-streaming) LLM call fails.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found — run Step 5 first")

    # Entries may live under either key depending on pipeline version.
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if not entries:
        raise HTTPException(status_code=400, detail="No vocab entries found — run Step 5 first")

    # Optional model override from request body; body parsing is best-effort
    # (requests without a JSON body fall back to the configured default model).
    body = {}
    try:
        body = await request.json()
    except Exception:
        pass
    model = body.get("model") or OLLAMA_REVIEW_MODEL

    if stream:
        # SSE path: persistence happens inside the generator on "complete".
        return StreamingResponse(
            _llm_review_stream_generator(session_id, entries, word_result, model, request),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )

    # Non-streaming path
    try:
        result = await llm_review_entries(entries, model=model)
    except Exception as e:
        import traceback
        logger.error(f"LLM review failed for session {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"LLM review failed ({type(e).__name__}): {e}")

    # Store result inside word_result as a sub-key and advance to step 9.
    word_result["llm_review"] = {
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "entries_corrected": result["entries_corrected"],
    }
    await update_session_db(session_id, word_result=word_result, current_step=9)

    # Keep the in-memory cache consistent with the DB.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"LLM review session {session_id}: {len(result['changes'])} changes, "
                f"{result['duration_ms']}ms, model={result['model_used']}")

    await _append_pipeline_log(session_id, "correction", {
        "engine": "llm",
        "model": result["model_used"],
        "total_entries": len(entries),
        "corrections_proposed": len(result["changes"]),
    }, duration_ms=result["duration_ms"])

    return {
        "session_id": session_id,
        "changes": result["changes"],
        "model_used": result["model_used"],
        "duration_ms": result["duration_ms"],
        "total_entries": len(entries),
        "corrections_found": len(result["changes"]),
    }
|
||||
|
||||
|
||||
async def _llm_review_stream_generator(
    session_id: str,
    entries: List[Dict],
    word_result: Dict,
    model: str,
    request: Request,
):
    """SSE generator that yields batch-by-batch LLM review progress.

    Each event from the streaming reviewer is forwarded to the client as an
    SSE "data:" frame.  When the terminal "complete" event arrives, the
    review result is persisted into word_result (DB + in-memory cache).
    Any exception is reported to the client as a final "error" event rather
    than raised, so the SSE stream closes cleanly.
    """
    try:
        async for event in llm_review_entries_streaming(entries, model=model):
            # Stop streaming (and skip persistence) if the client went away.
            if await request.is_disconnected():
                logger.info(f"SSE: client disconnected during LLM review for {session_id}")
                return

            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

            # On complete: persist to DB and advance the session to step 9.
            if event.get("type") == "complete":
                word_result["llm_review"] = {
                    "changes": event["changes"],
                    "model_used": event["model_used"],
                    "duration_ms": event["duration_ms"],
                    "entries_corrected": event["entries_corrected"],
                }
                await update_session_db(session_id, word_result=word_result, current_step=9)
                if session_id in _cache:
                    _cache[session_id]["word_result"] = word_result

                logger.info(f"LLM review SSE session {session_id}: {event['corrections_found']} changes, "
                            f"{event['duration_ms']}ms, skipped={event['skipped']}, model={event['model_used']}")

    except Exception as e:
        import traceback
        logger.error(f"LLM review SSE failed for {session_id}: {type(e).__name__}: {e}\n{traceback.format_exc()}")
        error_event = {"type": "error", "detail": f"{type(e).__name__}: {e}"}
        yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/llm-review/apply")
async def apply_llm_corrections(session_id: str, request: Request):
    """Apply selected LLM corrections to vocab entries.

    The request body carries "accepted_indices": positions into the stored
    llm_review "changes" list.  Accepted values are written back into the
    matching entry fields and the session is persisted.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    llm_review = word_result.get("llm_review")
    if not llm_review:
        raise HTTPException(status_code=400, detail="No LLM review found — run /llm-review first")

    payload = await request.json()
    accepted = set(payload.get("accepted_indices", []))  # indices into changes[]

    changes = llm_review.get("changes", [])
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []

    # (row_index, field) -> corrected value, for accepted changes only.
    corrections = {}
    applied_count = 0
    for position, change in enumerate(changes):
        if position not in accepted:
            continue
        corrections[(change["row_index"], change["field"])] = change["new"]
        applied_count += 1

    # Write accepted values back into the matching entry fields.
    for entry in entries:
        row = entry.get("row_index", -1)
        for field_name in ("english", "german", "example"):
            lookup = (row, field_name)
            if lookup in corrections:
                entry[field_name] = corrections[lookup]
                entry["llm_corrected"] = True

    # Persist under both entry keys and record what was applied when.
    word_result["vocab_entries"] = entries
    word_result["entries"] = entries
    word_result["llm_review"]["applied_count"] = applied_count
    word_result["llm_review"]["applied_at"] = datetime.utcnow().isoformat()

    await update_session_db(session_id, word_result=word_result)

    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    logger.info(f"Applied {applied_count}/{len(changes)} LLM corrections for session {session_id}")

    return {
        "session_id": session_id,
        "applied_count": applied_count,
        "total_changes": len(changes),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 9: Reconstruction + Fabric JSON export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from reconstruction step.

    Request body: {"cells": [{"cell_id": ..., "text": ...}, ...]}.
    Cell ids prefixed "box{N}_" are routed to the N-th sub-session (box
    region); all others update the main session's grid.  Advances the
    session to step 10 in either case.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    # No edits: just advance the step marker.
    if not cell_updates:
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            bi = int(m.group(1))
            original_id = m.group(2)
            sub_updates.setdefault(bi, {})[original_id] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        # Map cell_id pattern "R{row}_C{col}" to entry fields
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            # Check each field's cell
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                cell_id = f"R{row_idx:02d}_C{col_idx}"
                # Also try without zero-padding
                cell_id_alt = f"R{row_idx}_C{col_idx}"
                # NOTE(review): an empty-string edit under the zero-padded id
                # falls through to the alt id because of the `or` — confirm
                # that dropping such edits here is intended.
                new_text = main_updates.get(cell_id) or main_updates.get(cell_id_alt)
                if new_text is not None:
                    entry[field_name] = new_text

        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    # Keep the in-memory cache consistent with the DB.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates: each box index maps to its own session row.
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return cell grid as Fabric.js-compatible JSON for the canvas editor.

    If the session has sub-sessions (box regions), their cells are merged
    into the result at the correct Y positions.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: session has no word_result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Copy so sub-session cells appended below do not mutate the stored word_result.
    cells = list(word_result.get("cells", []))
    # Fallback canvas dimensions when the word_result carries no image size.
    img_w = word_result.get("image_width", 800)
    img_h = word_result.get("image_height", 600)

    # Merge sub-session cells at box positions
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        # Only zones that are boxes with a concrete bbox can anchor a sub-session.
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue

            # box_index links the sub-session to its zone; fall back to (0, 0)
            # offset when the index is out of range for the detected zones.
            bi = sub.get("box_index", 0)
            if bi < len(box_zones):
                box = box_zones[bi]["box"]
                box_y, box_x = box["y"], box["x"]
            else:
                box_y, box_x = 0, 0

            # Offset sub-session cells to absolute page coordinates
            for cell in sub_word["cells"]:
                cell_copy = dict(cell)
                # Prefix cell_id with box index so ids stay unique across sources.
                cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
                cell_copy["source"] = f"box_{bi}"
                # Offset bbox_px (copied first so the stored sub-session is untouched).
                bbox = cell_copy.get("bbox_px", {})
                if bbox:
                    bbox = dict(bbox)
                    bbox["x"] = bbox.get("x", 0) + box_x
                    bbox["y"] = bbox.get("y", 0) + box_y
                    cell_copy["bbox_px"] = bbox
                cells.append(cell_copy)

    # Local import keeps the service dependency out of module import time.
    from services.layout_reconstruction_service import cells_to_fabric_json
    fabric_json = cells_to_fabric_json(cells, img_w, img_h)

    return fabric_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocab entries merged + PDF/DOCX export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries from main session + all sub-sessions, sorted by Y position."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    main_word = session.get("word_result") or {}
    merged = list(main_word.get("vocab_entries") or main_word.get("entries") or [])

    # Main-session entries carry source="main" unless already tagged.
    for entry in merged:
        entry.setdefault("source", "main")

    # Append entries from every sub-session, tagged with their box of origin.
    sub_sessions = await get_sub_sessions(session_id)
    if sub_sessions:
        zones = (session.get("column_result") or {}).get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in sub_sessions:
            sub_data = await get_session_db(sub["id"])
            if not sub_data:
                continue
            sub_word = sub_data.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            box_index = sub.get("box_index", 0)
            y_offset = box_zones[box_index]["box"]["y"] if box_index < len(box_zones) else 0

            for entry in sub_entries:
                tagged = dict(entry)
                tagged["source"] = f"box_{box_index}"
                tagged["source_y"] = y_offset  # used for sorting below
                merged.append(tagged)

    def _position_key(entry):
        # Main entries order by row index; sub-session entries by box Y, then row.
        if entry.get("source", "main") == "main":
            return entry.get("row_index", 0) * 100
        return entry.get("source_y", 0) * 100 + entry.get("row_index", 0)

    merged.sort(key=_position_key)

    return {
        "session_id": session_id,
        "entries": merged,
        "total": len(merged),
        "sources": list({e.get("source", "main") for e in merged}),
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    Builds a (header + rows x cols) text matrix from ``word_result`` — the
    header comes from ``columns_used`` labels — and renders it with reportlab.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: no word result / no data to export.
        HTTPException 501: reportlab is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cell texts by id once instead of scanning the cell list for every
    # grid position (was O(rows * cols * len(cells))). setdefault keeps the
    # first occurrence, matching the previous next()-based lookup.
    cell_text_by_id: dict = {}
    for c in cells:
        cell_text_by_id.setdefault(c.get("cell_id"), c.get("text", ""))

    # Build table data: header row + one row of texts per grid row.
    table_data: list[list[str]] = []
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data.append(header)

    for r in range(n_rows):
        # Cell ids follow the pipeline's "R{row:02d}_C{col}" convention.
        table_data.append(
            [cell_text_by_id.get(f"R{r:02d}_C{ci}", "") for ci in range(n_cols)]
        )

    # Generate PDF with reportlab (optional dependency — 501 when missing).
    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
        import io as _io

        buf = _io.BytesIO()
        doc = SimpleDocTemplate(buf, pagesize=A4)
        if not table_data or not table_data[0]:
            # Empty header means n_cols == 0 — nothing meaningful to render.
            raise HTTPException(status_code=400, detail="No data to export")

        t = Table(table_data)
        t.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        doc.build([t])
        buf.seek(0)

        from fastapi.responses import StreamingResponse
        return StreamingResponse(
            buf,
            media_type="application/pdf",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    Builds a (1 + n_rows) x max(n_cols, 1) table from ``word_result`` with
    the ``columns_used`` labels as the header row.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: no word result.
        HTTPException 501: python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    try:
        from docx import Document
        import io as _io

        doc = Document()
        doc.add_heading(f'Rekonstruktion – Session {session_id[:8]}', level=1)

        # Build header from column metadata; fall back to generic labels.
        header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
        if not header:
            header = [f"Col {i}" for i in range(n_cols)]

        table_cols = max(n_cols, 1)
        table = doc.add_table(rows=1 + n_rows, cols=table_cols)
        table.style = 'Table Grid'

        # Header row — clip to the actual table width so a columns_used list
        # longer than grid_shape.cols cannot index past the created cells.
        for ci, h in enumerate(header[:table_cols]):
            table.rows[0].cells[ci].text = h

        # Index cell texts by id once instead of a linear scan per grid
        # position (was O(rows * cols * len(cells))). setdefault keeps the
        # first occurrence, matching the previous next()-based lookup.
        cell_text_by_id: dict = {}
        for c in cells:
            cell_text_by_id.setdefault(c.get("cell_id"), c.get("text", ""))

        # Data rows — ids follow the pipeline's "R{row:02d}_C{col}" convention.
        for r in range(n_rows):
            row_cells = table.rows[r + 1].cells
            for ci in range(n_cols):
                row_cells[ci].text = cell_text_by_id.get(f"R{r:02d}_C{ci}", "")

        buf = _io.BytesIO()
        doc.save(buf)
        buf.seek(0)

        from fastapi.responses import StreamingResponse
        return StreamingResponse(
            buf,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 8: Validation — Original vs. Reconstruction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the original image to qwen2.5vl to find non-text, non-table
    image areas, returning bounding boxes (in %) and descriptions.

    The detected regions are persisted under ground_truth["validation"]
    ["image_regions"] and also returned in the response. VLM/network
    failures are reported in-band via an "error" key, never raised.
    """
    import base64
    import httpx
    import re

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Get original image bytes
    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Build context from vocab entries for richer descriptions
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        # Only a small sample is sent to keep the prompt short.
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

            # Parse JSON array from response (first bracketed span; tolerant
            # of surrounding prose the model may emit despite the prompt).
            match = re.search(r'\[.*?\]', text, re.DOTALL)
            if match:
                raw_regions = json.loads(match.group(0))
            else:
                raw_regions = []

            # Normalize to ImageRegion format; coordinates clamped to 0-100,
            # sizes to at least 1% so downstream aspect math never divides by 0.
            regions = []
            for r in raw_regions:
                regions.append({
                    "bbox_pct": {
                        "x": max(0, min(100, float(r.get("x", 0)))),
                        "y": max(0, min(100, float(r.get("y", 0)))),
                        "w": max(1, min(100, float(r.get("w", 10)))),
                        "h": max(1, min(100, float(r.get("h", 10)))),
                    },
                    "description": r.get("description", ""),
                    "prompt": r.get("description", ""),
                    "image_b64": None,  # filled later by generate_image
                    "style": "educational",
                })

            # Enrich prompts with nearby vocab context
            if entries:
                for region in regions:
                    ry = region["bbox_pct"]["y"]
                    rh = region["bbox_pct"]["h"]
                    # "Nearby" = entry bbox Y within region height + 10% slack.
                    # NOTE(review): assumes entry bbox y is also in page-% — confirm.
                    nearby = [
                        e for e in entries
                        if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                    ]
                    if nearby:
                        en_words = [e.get("english", "") for e in nearby if e.get("english")]
                        de_words = [e.get("german", "") for e in nearby if e.get("german")]
                        if en_words or de_words:
                            context = f" (vocabulary context: {', '.join(en_words[:5])}"
                            if de_words:
                                context += f" / {', '.join(de_words[:5])}"
                            context += ")"
                            region["prompt"] = region["description"] + context

            # Save to ground_truth JSONB
            ground_truth = session.get("ground_truth") or {}
            validation = ground_truth.get("validation") or {}
            validation["image_regions"] = regions
            # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12+.
            validation["detected_at"] = datetime.utcnow().isoformat()
            ground_truth["validation"] = validation
            await update_session_db(session_id, ground_truth=ground_truth)

            # Keep the in-memory cache consistent with the DB write.
            if session_id in _cache:
                _cache[session_id]["ground_truth"] = ground_truth

            logger.info(f"Detected {len(regions)} image regions for session {session_id}")

            return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        # VLM host unreachable — degrade gracefully with an empty result.
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        # Best-effort endpoint: any other failure is reported in-band.
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    Sends the prompt (with style suffix) to the mflux-service running
    natively on the Mac Mini (Metal GPU required).

    On success the base64 image, prompt, and style are written back into
    the region entry under ground_truth["validation"]["image_regions"].
    Service failures are reported in-band ({"success": False, ...}).
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    # Validate the requested region index against the detected regions.
    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    # Unknown styles fall back to the "educational" suffix.
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    # Determine image size from region aspect ratio (snap to multiples of 64)
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    aspect = bbox["w"] / max(bbox["h"], 1)  # guard against zero height
    if aspect > 1.3:
        width, height = 768, 512   # landscape
    elif aspect < 0.7:
        width, height = 512, 768   # portrait
    else:
        width, height = 512, 512   # roughly square

    try:
        # Long timeout: diffusion generation can take minutes.
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
            image_b64 = data.get("image_b64")

            if not image_b64:
                return {"image_b64": None, "success": False, "error": "No image returned"}

            # Save to ground_truth (store the raw prompt, not the suffixed one)
            regions[req.region_index]["image_b64"] = image_b64
            regions[req.region_index]["prompt"] = req.prompt
            regions[req.region_index]["style"] = req.style
            validation["image_regions"] = regions
            ground_truth["validation"] = validation
            await update_session_db(session_id, ground_truth=ground_truth)

            # Keep the in-memory cache consistent with the DB write.
            if session_id in _cache:
                _cache[session_id]["ground_truth"] = ground_truth

            logger.info(f"Generated image for session {session_id} region {req.region_index}")
            return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Save final validation results for step 8.

    Stores notes, score, and a validated_at timestamp under
    ground_truth["validation"], preserving any detected/generated image
    regions already stored there.

    Sets current_step = 11 to mark the pipeline as complete.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Merge into the existing validation dict so image_regions survive.
    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    validation["validated_at"] = datetime.utcnow().isoformat()
    validation["notes"] = req.notes
    validation["score"] = req.score
    ground_truth["validation"] = validation

    await update_session_db(session_id, ground_truth=ground_truth, current_step=11)

    # Keep the in-memory cache consistent with the DB write.
    if session_id in _cache:
        _cache[session_id]["ground_truth"] = ground_truth

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": validation}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Retrieve saved validation data for step 8."""
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # Validation (if any) lives under the ground_truth JSONB blob.
    stored = session.get("ground_truth") or {}

    return {
        "session_id": session_id,
        "validation": stored.get("validation"),
        "word_result": session.get("word_result"),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Remove handwriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """
    Remove handwriting from a session image using inpainting.

    Steps:
    1. Load source image (auto -> deskewed if available, else original)
    2. Detect handwriting mask (filtered by target_ink)
    3. Dilate mask to cover stroke edges
    4. Inpaint the image
    5. Store result as clean_png in the session

    Returns metadata including the URL to fetch the clean image.

    Raises:
        HTTPException 404: session or source image not found.
        HTTPException 500: inpainting reported failure.
    """
    import time as _time
    t0 = _time.monotonic()  # wall-clock for processing_time_ms in the response

    # Local imports keep heavy image deps out of module import time.
    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    # 1. Determine source image ("auto" prefers the deskewed variant)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"

    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask ndarray to PNG bytes, then dilate if requested
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        # Grow the mask so inpainting covers anti-aliased stroke edges.
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint — map the request's string method onto the service enum.
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    # Metadata returned to the client and stored alongside the clean image.
    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image (convert BGR ndarray -> PNG bytes)
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
|
||||
from fastapi import APIRouter

from ocr_pipeline_llm_review import router as _llm_review_router
from ocr_pipeline_reconstruction import router as _reconstruction_router
from ocr_pipeline_validation import router as _validation_router

# Composite router — drop-in replacement for the old monolithic router.
# ocr_pipeline_api.py imports ``from ocr_pipeline_postprocess import router``.
# Each sub-router carries its own prefix/tags, so no prefix is set here.
router = APIRouter()
router.include_router(_llm_review_router)
router.include_router(_reconstruction_router)
router.include_router(_validation_router)
|
||||
|
||||
362
klausur-service/backend/ocr_pipeline_reconstruction.py
Normal file
362
klausur-service/backend/ocr_pipeline_reconstruction.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
OCR Pipeline Reconstruction — save edits, Fabric JSON export, merged entries, PDF/DOCX export.

Extracted from ocr_pipeline_postprocess.py.

License: Apache 2.0
PRIVACY: All processing happens locally.
"""

import logging
import re
from typing import Dict

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse

from ocr_pipeline_session_store import (
    get_session_db,
    get_sub_sessions,
    update_session_db,
)
# Shared in-memory session cache, kept in sync with DB writes below.
from ocr_pipeline_common import _cache

logger = logging.getLogger(__name__)

# All endpoints in this module mount under the shared OCR-pipeline prefix.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])


# ---------------------------------------------------------------------------
# Step 9: Reconstruction + Fabric JSON export
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction")
async def save_reconstruction(session_id: str, request: Request):
    """Save edited cell texts from reconstruction step.

    Request body: ``{"cells": [{"cell_id": ..., "text": ...}, ...]}``.
    Cell ids prefixed with ``box{N}_`` are routed to the sub-session for
    box N; all other updates go to the main session's word_result. Vocab
    entries are kept in sync by row index, and current_step advances to 10.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: session has no word_result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    body = await request.json()
    cell_updates = body.get("cells", [])

    if not cell_updates:
        # Nothing to write — still advance the pipeline step.
        await update_session_db(session_id, current_step=10)
        return {"session_id": session_id, "updated": 0}

    # Build update map: cell_id -> new text
    update_map = {c["cell_id"]: c["text"] for c in cell_updates}

    # Separate sub-session updates (cell_ids prefixed with "box{N}_")
    sub_updates: Dict[int, Dict[str, str]] = {}  # box_index -> {original_cell_id: text}
    main_updates: Dict[str, str] = {}
    for cell_id, text in update_map.items():
        m = re.match(r'^box(\d+)_(.+)$', cell_id)
        if m:
            bi = int(m.group(1))
            original_id = m.group(2)
            sub_updates.setdefault(bi, {})[original_id] = text
        else:
            main_updates[cell_id] = text

    # Update main session cells
    cells = word_result.get("cells", [])
    updated_count = 0
    for cell in cells:
        if cell["cell_id"] in main_updates:
            cell["text"] = main_updates[cell["cell_id"]]
            cell["status"] = "edited"  # mark as human-edited
            updated_count += 1

    word_result["cells"] = cells

    # Also update vocab_entries if present (columns map positionally).
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    if entries:
        for entry in entries:
            row_idx = entry.get("row_index", -1)
            for col_idx, field_name in enumerate(["english", "german", "example"]):
                cell_id = f"R{row_idx:02d}_C{col_idx}"
                cell_id_alt = f"R{row_idx}_C{col_idx}"  # id without zero-padding
                # BUG FIX: use an explicit None check instead of `or` —
                # otherwise an update that clears a cell to "" is falsy and
                # silently never propagates to the vocab entry.
                new_text = main_updates.get(cell_id)
                if new_text is None:
                    new_text = main_updates.get(cell_id_alt)
                if new_text is not None:
                    entry[field_name] = new_text

        # Mirror into both keys so older readers of "entries" stay in sync.
        word_result["vocab_entries"] = entries
        if "entries" in word_result:
            word_result["entries"] = entries

    await update_session_db(session_id, word_result=word_result, current_step=10)

    # Keep the in-memory cache consistent with the DB write.
    if session_id in _cache:
        _cache[session_id]["word_result"] = word_result

    # Route sub-session updates
    sub_updated = 0
    if sub_updates:
        subs = await get_sub_sessions(session_id)
        sub_by_index = {s.get("box_index"): s["id"] for s in subs}
        for bi, updates in sub_updates.items():
            sub_id = sub_by_index.get(bi)
            if not sub_id:
                continue
            sub_session = await get_session_db(sub_id)
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word:
                continue
            sub_cells = sub_word.get("cells", [])
            for cell in sub_cells:
                if cell["cell_id"] in updates:
                    cell["text"] = updates[cell["cell_id"]]
                    cell["status"] = "edited"
                    sub_updated += 1
            sub_word["cells"] = sub_cells
            await update_session_db(sub_id, word_result=sub_word)
            if sub_id in _cache:
                _cache[sub_id]["word_result"] = sub_word

    total_updated = updated_count + sub_updated
    logger.info(f"Reconstruction saved for session {session_id}: "
                f"{updated_count} main + {sub_updated} sub-session cells updated")

    return {
        "session_id": session_id,
        "updated": total_updated,
        "main_updated": updated_count,
        "sub_updated": sub_updated,
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/fabric-json")
async def get_fabric_json(session_id: str):
    """Return cell grid as Fabric.js-compatible JSON for the canvas editor.

    Sub-session (box) cells, when present, are merged in with their ids
    prefixed ``box{N}_`` and their pixel bboxes offset to page coordinates.

    Raises:
        HTTPException 404: session not found.
        HTTPException 400: session has no word_result.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    # Copy so appended sub-session cells don't mutate the stored word_result.
    cells = list(word_result.get("cells", []))
    # Fallback canvas dimensions when the word_result carries no image size.
    img_w = word_result.get("image_width", 800)
    img_h = word_result.get("image_height", 600)

    # Merge sub-session cells at box positions
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        # Only zones that are boxes with a concrete bbox can anchor a sub-session.
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result")
            if not sub_word or not sub_word.get("cells"):
                continue

            # box_index links the sub-session to its zone; fall back to a
            # (0, 0) offset when the index exceeds the detected box zones.
            bi = sub.get("box_index", 0)
            if bi < len(box_zones):
                box = box_zones[bi]["box"]
                box_y, box_x = box["y"], box["x"]
            else:
                box_y, box_x = 0, 0

            for cell in sub_word["cells"]:
                # Work on a copy so the stored sub-session stays untouched.
                cell_copy = dict(cell)
                cell_copy["cell_id"] = f"box{bi}_{cell_copy.get('cell_id', '')}"
                cell_copy["source"] = f"box_{bi}"
                # Shift the cell's pixel bbox to absolute page coordinates.
                bbox = cell_copy.get("bbox_px", {})
                if bbox:
                    bbox = dict(bbox)
                    bbox["x"] = bbox.get("x", 0) + box_x
                    bbox["y"] = bbox.get("y", 0) + box_y
                    cell_copy["bbox_px"] = bbox
                cells.append(cell_copy)

    # Local import keeps the service dependency out of module import time.
    from services.layout_reconstruction_service import cells_to_fabric_json
    fabric_json = cells_to_fabric_json(cells, img_w, img_h)

    return fabric_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocab entries merged + PDF/DOCX export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/sessions/{session_id}/vocab-entries/merged")
async def get_merged_vocab_entries(session_id: str):
    """Return vocab entries from main session + all sub-sessions, sorted by Y position.

    Main entries are tagged source="main"; sub-session entries get
    source="box_{N}" plus a source_y used for ordering.

    Raises:
        HTTPException 404: session not found.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result") or {}
    # "vocab_entries" is the current key; "entries" is the legacy fallback.
    entries = list(word_result.get("vocab_entries") or word_result.get("entries") or [])

    # Tag main-session entries unless a source is already present.
    for e in entries:
        e.setdefault("source", "main")

    # Merge sub-session entries, tagged with their box of origin.
    subs = await get_sub_sessions(session_id)
    if subs:
        column_result = session.get("column_result") or {}
        zones = column_result.get("zones") or []
        box_zones = [z for z in zones if z.get("zone_type") == "box" and z.get("box")]

        for sub in subs:
            sub_session = await get_session_db(sub["id"])
            if not sub_session:
                continue
            sub_word = sub_session.get("word_result") or {}
            sub_entries = sub_word.get("vocab_entries") or sub_word.get("entries") or []

            # Resolve the Y position of the sub-session's box (0 if unknown).
            bi = sub.get("box_index", 0)
            box_y = 0
            if bi < len(box_zones):
                box_y = box_zones[bi]["box"]["y"]

            for e in sub_entries:
                # Copy so the stored sub-session entries stay untouched.
                e_copy = dict(e)
                e_copy["source"] = f"box_{bi}"
                e_copy["source_y"] = box_y  # used by the sort key below
                entries.append(e_copy)

    def _sort_key(e):
        # Main entries order by row index; sub entries by box Y, then row.
        if e.get("source", "main") == "main":
            return e.get("row_index", 0) * 100
        return e.get("source_y", 0) * 100 + e.get("row_index", 0)

    entries.sort(key=_sort_key)

    return {
        "session_id": session_id,
        "entries": entries,
        "total": len(entries),
        "sources": list(set(e.get("source", "main") for e in entries)),
    }
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/export/pdf")
async def export_reconstruction_pdf(session_id: str):
    """Export the reconstructed cell grid as a PDF table.

    Builds an (n_rows x n_cols) table from the session's word_result cells
    and renders it with reportlab.

    Raises:
        HTTPException: 404 if the session is missing, 400 if it has no word
            result, 501 if reportlab is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells by cell_id once; the previous per-cell linear scan was
    # O(rows * cols * len(cells)).
    cell_by_id = {c.get("cell_id"): c for c in cells}

    # Build table data: header row followed by n_rows content rows.
    table_data: list[list[str]] = []
    header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
    if not header:
        header = [f"Col {i}" for i in range(n_cols)]
    table_data.append(header)

    for r in range(n_rows):
        row_texts = []
        for ci in range(n_cols):
            cell = cell_by_id.get(f"R{r:02d}_C{ci}")
            row_texts.append(cell.get("text", "") if cell else "")
        table_data.append(row_texts)

    try:
        from reportlab.lib.pagesizes import A4
        from reportlab.lib import colors
        from reportlab.platypus import SimpleDocTemplate, Table, TableStyle
        import io as _io

        buf = _io.BytesIO()
        doc = SimpleDocTemplate(buf, pagesize=A4)
        if not table_data or not table_data[0]:
            raise HTTPException(status_code=400, detail="No data to export")

        t = Table(table_data)
        t.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#0d9488')),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
            ('WORDWRAP', (0, 0), (-1, -1), True),
        ]))
        doc.build([t])
        buf.seek(0)

        return StreamingResponse(
            buf,
            media_type="application/pdf",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.pdf"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="reportlab not installed")
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/export/docx")
async def export_reconstruction_docx(session_id: str):
    """Export the reconstructed cell grid as a DOCX table.

    Builds an (n_rows x n_cols) table from the session's word_result cells
    and renders it with python-docx.

    Raises:
        HTTPException: 404 if the session is missing, 400 if it has no word
            result, 501 if python-docx is not installed.
    """
    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    word_result = session.get("word_result")
    if not word_result:
        raise HTTPException(status_code=400, detail="No word result found")

    cells = word_result.get("cells", [])
    columns_used = word_result.get("columns_used", [])
    grid_shape = word_result.get("grid_shape", {})
    n_rows = grid_shape.get("rows", 0)
    n_cols = grid_shape.get("cols", 0)

    # Index cells by cell_id once; the previous per-cell linear scan was
    # O(rows * cols * len(cells)).
    cell_by_id = {c.get("cell_id"): c for c in cells}

    try:
        from docx import Document
        import io as _io

        doc = Document()
        doc.add_heading(f'Rekonstruktion -- Session {session_id[:8]}', level=1)

        header = [c.get("label", c.get("type", f"Col {i}")) for i, c in enumerate(columns_used)]
        if not header:
            header = [f"Col {i}" for i in range(n_cols)]

        table_cols = max(n_cols, 1)
        table = doc.add_table(rows=1 + n_rows, cols=table_cols)
        table.style = 'Table Grid'

        # Clamp header labels to the table width: columns_used may contain
        # more labels than the grid has columns, which would IndexError.
        for ci, h in enumerate(header[:table_cols]):
            table.rows[0].cells[ci].text = h

        for r in range(n_rows):
            for ci in range(n_cols):
                cell = cell_by_id.get(f"R{r:02d}_C{ci}")
                table.rows[r + 1].cells[ci].text = cell.get("text", "") if cell else ""

        buf = _io.BytesIO()
        doc.save(buf)
        buf.seek(0)

        return StreamingResponse(
            buf,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f'attachment; filename="reconstruction_{session_id}.docx"'},
        )
    except ImportError:
        raise HTTPException(status_code=501, detail="python-docx not installed")
|
||||
362
klausur-service/backend/ocr_pipeline_validation.py
Normal file
362
klausur-service/backend/ocr_pipeline_validation.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
OCR Pipeline Validation — image detection, generation, validation save,
|
||||
and handwriting removal endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_postprocess.py.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
import logging
import os
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from ocr_pipeline_session_store import (
    get_session_db,
    get_session_image,
    update_session_db,
)
from ocr_pipeline_common import (
    _cache,
    RemoveHandwritingRequest,
)

logger = logging.getLogger(__name__)

# All endpoints in this module are mounted under the shared OCR pipeline prefix.
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Style keywords appended to the user's prompt to steer image generation.
# Keys match GenerateImageRequest.style; "educational" is the fallback.
STYLE_SUFFIXES = {
    "educational": "educational illustration, textbook style, clear, colorful",
    "cartoon": "cartoon, child-friendly, simple shapes",
    "sketch": "pencil sketch, hand-drawn, black and white",
    "clipart": "clipart, flat vector style, simple",
    "realistic": "photorealistic, high detail",
}
|
||||
|
||||
|
||||
class ValidationRequest(BaseModel):
    """Payload for saving a reviewer's final validation of a reconstruction."""
    notes: Optional[str] = None  # free-text reviewer notes
    score: Optional[int] = None  # reviewer quality score
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
    """Payload for generating a replacement image for a detected region."""
    region_index: int  # index into validation["image_regions"]
    prompt: str  # base prompt describing the desired image content
    style: str = "educational"  # key into STYLE_SUFFIXES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image detection + generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/detect-images")
async def detect_image_regions(session_id: str):
    """Detect illustration/image regions in the original scan using VLM.

    Sends the original page image to an Ollama vision model and asks it for
    bounding boxes (as page percentages) of all illustrations. Detected
    regions are persisted under ground_truth["validation"]["image_regions"]
    and mirrored into the in-memory cache.

    Returns:
        Dict with ``regions`` and ``count``; on VLM/parse failure, empty
        regions plus an ``error`` key (best-effort — detection failures do
        not raise).

    Raises:
        HTTPException: 404 if the session is missing, 400 if there is no
            original image.
    """
    import base64
    import httpx
    import re

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    original_png = await get_session_image(session_id, "original")
    if not original_png:
        raise HTTPException(status_code=400, detail="No original image found")

    # Give the VLM some vocabulary context from already-detected entries to
    # help it describe the illustrations.
    word_result = session.get("word_result") or {}
    entries = word_result.get("vocab_entries") or word_result.get("entries") or []
    vocab_context = ""
    if entries:
        sample = entries[:10]
        words = [f"{e.get('english', '')} / {e.get('german', '')}" for e in sample if e.get('english')]
        if words:
            vocab_context = f"\nContext: This is a vocabulary page with words like: {', '.join(words)}"

    ollama_base = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
    model = os.getenv("OLLAMA_HTR_MODEL", "qwen2.5vl:32b")

    prompt = (
        "Analyze this scanned page. Find ALL illustration/image/picture regions "
        "(NOT text, NOT table cells, NOT blank areas). "
        "For each image region found, return its bounding box as percentage of page dimensions "
        "and a short English description of what the image shows. "
        "Reply with ONLY a JSON array like: "
        '[{"x": 10, "y": 20, "w": 30, "h": 25, "description": "drawing of a cat"}] '
        "where x, y, w, h are percentages (0-100) of the page width/height. "
        "If there are NO images on the page, return an empty array: []"
        f"{vocab_context}"
    )

    img_b64 = base64.b64encode(original_png).decode("utf-8")
    payload = {
        "model": model,
        "prompt": prompt,
        "images": [img_b64],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            resp = await client.post(f"{ollama_base}/api/generate", json=payload)
            resp.raise_for_status()
            text = resp.json().get("response", "")

        # Extract the first JSON array from the model's free-form reply.
        # NOTE(review): the non-greedy match stops at the first ']' — a
        # description containing ']' would truncate the array. Acceptable for
        # this heuristic parse, but worth confirming against real outputs.
        match = re.search(r'\[.*?\]', text, re.DOTALL)
        if match:
            raw_regions = json.loads(match.group(0))
        else:
            raw_regions = []

        # Normalize boxes: clamp coordinates to 0-100% and enforce a minimum
        # 1% width/height so downstream aspect math never divides by zero.
        regions = []
        for r in raw_regions:
            regions.append({
                "bbox_pct": {
                    "x": max(0, min(100, float(r.get("x", 0)))),
                    "y": max(0, min(100, float(r.get("y", 0)))),
                    "w": max(1, min(100, float(r.get("w", 10)))),
                    "h": max(1, min(100, float(r.get("h", 10)))),
                },
                "description": r.get("description", ""),
                "prompt": r.get("description", ""),
                "image_b64": None,
                "style": "educational",
            })

        # Enrich prompts with nearby vocab context
        if entries:
            for region in regions:
                ry = region["bbox_pct"]["y"]
                rh = region["bbox_pct"]["h"]
                # NOTE(review): entry bbox "y" is compared directly against
                # the region's percentage coordinates — assumes entry bboxes
                # are in the same percentage space; confirm against the word
                # detection output.
                nearby = [
                    e for e in entries
                    if e.get("bbox") and abs(e["bbox"].get("y", 0) - ry) < rh + 10
                ]
                if nearby:
                    en_words = [e.get("english", "") for e in nearby if e.get("english")]
                    de_words = [e.get("german", "") for e in nearby if e.get("german")]
                    if en_words or de_words:
                        context = f" (vocabulary context: {', '.join(en_words[:5])}"
                        if de_words:
                            context += f" / {', '.join(de_words[:5])}"
                        context += ")"
                        region["prompt"] = region["description"] + context

        # Persist detections on the session and mirror them into the cache.
        ground_truth = session.get("ground_truth") or {}
        validation = ground_truth.get("validation") or {}
        validation["image_regions"] = regions
        validation["detected_at"] = datetime.utcnow().isoformat()
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Detected {len(regions)} image regions for session {session_id}")

        return {"regions": regions, "count": len(regions)}

    except httpx.ConnectError:
        logger.warning(f"VLM not available at {ollama_base} for image detection")
        return {"regions": [], "count": 0, "error": "VLM not available"}
    except Exception as e:
        logger.error(f"Image detection failed for {session_id}: {e}")
        return {"regions": [], "count": 0, "error": str(e)}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/generate-image")
async def generate_image_for_region(session_id: str, req: GenerateImageRequest):
    """Generate a replacement image for a detected region using mflux.

    Picks an output resolution matching the region's aspect ratio, calls the
    mflux image-generation service, and persists the returned base64 image on
    the region inside ground_truth["validation"].

    Returns:
        Dict with ``image_b64`` and ``success``; on failure ``success`` is
        False with an ``error`` message (generation failures do not raise).

    Raises:
        HTTPException: 404 if the session is missing, 400 for an invalid
            region index.
    """
    import httpx

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    ground_truth = session.get("ground_truth") or {}
    validation = ground_truth.get("validation") or {}
    regions = validation.get("image_regions") or []

    if req.region_index < 0 or req.region_index >= len(regions):
        raise HTTPException(status_code=400, detail=f"Invalid region_index {req.region_index}, have {len(regions)} regions")

    mflux_url = os.getenv("MFLUX_URL", "http://host.docker.internal:8095")
    style_suffix = STYLE_SUFFIXES.get(req.style, STYLE_SUFFIXES["educational"])
    full_prompt = f"{req.prompt}, {style_suffix}"

    # Choose a generation resolution that roughly matches the region's
    # aspect ratio (landscape / portrait / square).
    region = regions[req.region_index]
    bbox = region["bbox_pct"]
    aspect = bbox["w"] / max(bbox["h"], 1)
    if aspect > 1.3:
        width, height = 768, 512
    elif aspect < 0.7:
        width, height = 512, 768
    else:
        width, height = 512, 512

    try:
        # Image generation can be slow — allow up to 5 minutes.
        async with httpx.AsyncClient(timeout=300.0) as client:
            resp = await client.post(f"{mflux_url}/generate", json={
                "prompt": full_prompt,
                "width": width,
                "height": height,
                "steps": 4,
            })
            resp.raise_for_status()
            data = resp.json()
            image_b64 = data.get("image_b64")

        if not image_b64:
            return {"image_b64": None, "success": False, "error": "No image returned"}

        # Persist the generated image on its region and mirror into cache.
        regions[req.region_index]["image_b64"] = image_b64
        regions[req.region_index]["prompt"] = req.prompt
        regions[req.region_index]["style"] = req.style
        validation["image_regions"] = regions
        ground_truth["validation"] = validation
        await update_session_db(session_id, ground_truth=ground_truth)

        if session_id in _cache:
            _cache[session_id]["ground_truth"] = ground_truth

        logger.info(f"Generated image for session {session_id} region {req.region_index}")
        return {"image_b64": image_b64, "success": True}

    except httpx.ConnectError:
        logger.warning(f"mflux-service not available at {mflux_url}")
        return {"image_b64": None, "success": False, "error": f"mflux-service not available at {mflux_url}"}
    except Exception as e:
        logger.error(f"Image generation failed for {session_id}: {e}")
        return {"image_b64": None, "success": False, "error": str(e)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation save/get
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/reconstruction/validate")
async def save_validation(session_id: str, req: ValidationRequest):
    """Persist the reviewer's final validation results for step 8.

    Records timestamp, notes and score under ground_truth["validation"],
    advances the session to step 11, and refreshes the in-memory cache.
    """
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = record.get("ground_truth") or {}
    val = gt.get("validation") or {}
    val.update(
        validated_at=datetime.utcnow().isoformat(),
        notes=req.notes,
        score=req.score,
    )
    gt["validation"] = val

    await update_session_db(session_id, ground_truth=gt, current_step=11)

    if session_id in _cache:
        _cache[session_id]["ground_truth"] = gt

    logger.info(f"Validation saved for session {session_id}: score={req.score}")

    return {"session_id": session_id, "validation": val}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/reconstruction/validation")
async def get_validation(session_id: str):
    """Return the previously saved step-8 validation data for a session."""
    record = await get_session_db(session_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    gt = record.get("ground_truth") or {}
    return {
        "session_id": session_id,
        "validation": gt.get("validation"),
        "word_result": record.get("word_result"),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Remove handwriting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/remove-handwriting")
async def remove_handwriting_endpoint(session_id: str, req: RemoveHandwritingRequest):
    """Remove handwriting from a session image using inpainting.

    Pipeline: pick the source image (deskewed when available, unless the
    request pins one) -> detect a handwriting mask -> optionally dilate the
    mask -> inpaint the masked pixels -> persist the cleaned PNG plus
    removal metadata on the session.

    Returns:
        The removal metadata (method, ratios, timing, source) plus the URL
        of the cleaned image.

    Raises:
        HTTPException: 404 if the session or the chosen source image is
            missing, 500 if inpainting fails.
    """
    import time as _time

    from services.handwriting_detection import detect_handwriting
    from services.inpainting_service import inpaint_image, dilate_mask as _dilate_mask, InpaintingMethod, image_to_png

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    t0 = _time.monotonic()

    # 1. Determine source image ("auto" prefers the deskewed scan if present)
    source = req.use_source
    if source == "auto":
        deskewed = await get_session_image(session_id, "deskewed")
        source = "deskewed" if deskewed else "original"

    image_bytes = await get_session_image(session_id, source)
    if not image_bytes:
        raise HTTPException(status_code=404, detail=f"Source image '{source}' not available")

    # 2. Detect handwriting mask
    detection = detect_handwriting(image_bytes, target_ink=req.target_ink)

    # 3. Convert mask to PNG bytes and dilate so the inpainting covers the
    #    full stroke width plus a small margin
    import io
    from PIL import Image as _PILImage
    mask_img = _PILImage.fromarray(detection.mask)
    mask_buf = io.BytesIO()
    mask_img.save(mask_buf, format="PNG")
    mask_bytes = mask_buf.getvalue()

    if req.dilation > 0:
        mask_bytes = _dilate_mask(mask_bytes, iterations=req.dilation)

    # 4. Inpaint — map the request's method string onto the service enum;
    #    unknown values fall back to AUTO
    method_map = {
        "telea": InpaintingMethod.OPENCV_TELEA,
        "ns": InpaintingMethod.OPENCV_NS,
        "auto": InpaintingMethod.AUTO,
    }
    inpaint_method = method_map.get(req.method, InpaintingMethod.AUTO)

    result = inpaint_image(image_bytes, mask_bytes, method=inpaint_method)
    if not result.success:
        raise HTTPException(status_code=500, detail="Inpainting failed")

    elapsed_ms = int((_time.monotonic() - t0) * 1000)

    # method_used may be an enum or a plain value depending on the service
    meta = {
        "method_used": result.method_used.value if hasattr(result.method_used, "value") else str(result.method_used),
        "handwriting_ratio": round(detection.handwriting_ratio, 4),
        "detection_confidence": round(detection.confidence, 4),
        "target_ink": req.target_ink,
        "dilation": req.dilation,
        "source_image": source,
        "processing_time_ms": elapsed_ms,
    }

    # 5. Persist clean image
    clean_png_bytes = image_to_png(result.image)
    await update_session_db(session_id, clean_png=clean_png_bytes, handwriting_removal_meta=meta)

    return {
        **meta,
        "image_url": f"/api/v1/ocr-pipeline/sessions/{session_id}/image/clean",
        "session_id": session_id,
    }
|
||||
@@ -1,18 +1,18 @@
|
||||
"""
|
||||
OCR Pipeline Words - Word detection and ground truth endpoints.
|
||||
OCR Pipeline Words — composite router for word detection, PaddleOCR direct,
|
||||
and ground truth endpoints.
|
||||
|
||||
Extracted from ocr_pipeline_api.py.
|
||||
Handles:
|
||||
- POST /sessions/{session_id}/words — main SSE streaming word detection
|
||||
- POST /sessions/{session_id}/paddle-direct — PaddleOCR direct endpoint
|
||||
- POST /sessions/{session_id}/ground-truth/words — save ground truth
|
||||
- GET /sessions/{session_id}/ground-truth/words — get ground truth
|
||||
Split into sub-modules:
|
||||
ocr_pipeline_words_detect — main detect_words endpoint (Step 7)
|
||||
ocr_pipeline_words_stream — SSE streaming generators
|
||||
|
||||
This barrel module contains the PaddleOCR direct endpoint and ground truth
|
||||
endpoints, and assembles all word-related routers.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -20,22 +20,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
build_cell_grid_v2,
|
||||
build_cell_grid_v2_streaming,
|
||||
create_ocr_image,
|
||||
detect_column_geometry,
|
||||
)
|
||||
from cv_words_first import build_grid_from_words
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
@@ -44,15 +31,13 @@ from ocr_pipeline_session_store import (
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
_get_base_image_png,
|
||||
_append_pipeline_log,
|
||||
)
|
||||
from ocr_pipeline_words_detect import router as _detect_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
_local_router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -65,689 +50,13 @@ class WordGroundTruthRequest(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Word Detection Endpoint (Step 7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/words")
|
||||
async def detect_words(
|
||||
session_id: str,
|
||||
request: Request,
|
||||
engine: str = "auto",
|
||||
pronunciation: str = "british",
|
||||
stream: bool = False,
|
||||
skip_heal_gaps: bool = False,
|
||||
grid_method: str = "v2",
|
||||
):
|
||||
"""Build word grid from columns × rows, OCR each cell.
|
||||
|
||||
Query params:
|
||||
engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
|
||||
pronunciation: 'british' (default) or 'american' — for IPA dictionary lookup
|
||||
stream: false (default) for JSON response, true for SSE streaming
|
||||
skip_heal_gaps: false (default). When true, cells keep exact row geometry
|
||||
positions without gap-healing expansion. Better for overlay rendering.
|
||||
grid_method: 'v2' (default) or 'words_first' — grid construction strategy.
|
||||
'v2' uses pre-detected columns/rows (top-down).
|
||||
'words_first' clusters words bottom-up (no column/row detection needed).
|
||||
"""
|
||||
# PaddleOCR is full-page remote OCR → force words_first grid method
|
||||
if engine == "paddle" and grid_method != "words_first":
|
||||
logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
|
||||
grid_method = "words_first"
|
||||
|
||||
if session_id not in _cache:
|
||||
logger.info("detect_words: session %s not in cache, loading from DB", session_id)
|
||||
await _load_session_to_cache(session_id)
|
||||
cached = _get_cached(session_id)
|
||||
|
||||
dewarped_bgr = cached.get("cropped_bgr") if cached.get("cropped_bgr") is not None else cached.get("dewarped_bgr")
|
||||
if dewarped_bgr is None:
|
||||
logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
|
||||
session_id, [k for k in cached.keys() if k.endswith('_bgr')])
|
||||
raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")
|
||||
|
||||
session = await get_session_db(session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
||||
|
||||
column_result = session.get("column_result")
|
||||
row_result = session.get("row_result")
|
||||
if not column_result or not column_result.get("columns"):
|
||||
# No column detection — synthesize a single full-page pseudo-column.
|
||||
# This enables the overlay pipeline which skips column detection.
|
||||
img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
|
||||
column_result = {
|
||||
"columns": [{
|
||||
"type": "column_text",
|
||||
"x": 0, "y": 0,
|
||||
"width": img_w_tmp, "height": img_h_tmp,
|
||||
"classification_confidence": 1.0,
|
||||
"classification_method": "full_page_fallback",
|
||||
}],
|
||||
"zones": [],
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
logger.info("detect_words: no column_result — using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
|
||||
if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
|
||||
raise HTTPException(status_code=400, detail="Row detection must be completed first")
|
||||
|
||||
# Convert column dicts back to PageRegion objects
|
||||
col_regions = [
|
||||
PageRegion(
|
||||
type=c["type"],
|
||||
x=c["x"], y=c["y"],
|
||||
width=c["width"], height=c["height"],
|
||||
classification_confidence=c.get("classification_confidence", 1.0),
|
||||
classification_method=c.get("classification_method", ""),
|
||||
)
|
||||
for c in column_result["columns"]
|
||||
]
|
||||
|
||||
# Convert row dicts back to RowGeometry objects
|
||||
row_geoms = [
|
||||
RowGeometry(
|
||||
index=r["index"],
|
||||
x=r["x"], y=r["y"],
|
||||
width=r["width"], height=r["height"],
|
||||
word_count=r.get("word_count", 0),
|
||||
words=[],
|
||||
row_type=r.get("row_type", "content"),
|
||||
gap_before=r.get("gap_before", 0),
|
||||
)
|
||||
for r in row_result["rows"]
|
||||
]
|
||||
|
||||
# Cell-First OCR (v2): no full-page word re-population needed.
|
||||
# Each cell is cropped and OCR'd in isolation → no neighbour bleeding.
|
||||
# We still need word_count > 0 for row filtering in build_cell_grid_v2,
|
||||
# so populate from cached words if available (just for counting).
|
||||
word_dicts = cached.get("_word_dicts")
|
||||
if word_dicts is None:
|
||||
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
|
||||
if geo_result is not None:
|
||||
_geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
if word_dicts:
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
if content_bounds:
|
||||
_lx, _rx, top_y, _by = content_bounds
|
||||
else:
|
||||
top_y = min(r.y for r in row_geoms) if row_geoms else 0
|
||||
|
||||
for row in row_geoms:
|
||||
row_y_rel = row.y - top_y
|
||||
row_bottom_rel = row_y_rel + row.height
|
||||
row.words = [
|
||||
w for w in word_dicts
|
||||
if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
|
||||
]
|
||||
row.word_count = len(row.words)
|
||||
|
||||
# Exclude rows that fall within box zones.
|
||||
# Use inner box range (shrunk by border_thickness) so that rows at
|
||||
# the boundary (overlapping with the box border) are NOT excluded.
|
||||
zones = column_result.get("zones") or []
|
||||
box_ranges_inner = []
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||
box = zone["box"]
|
||||
bt = max(box.get("border_thickness", 0), 5) # minimum 5px margin
|
||||
box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))
|
||||
|
||||
if box_ranges_inner:
|
||||
def _row_in_box(r):
|
||||
center_y = r.y + r.height / 2
|
||||
return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)
|
||||
|
||||
before_count = len(row_geoms)
|
||||
row_geoms = [r for r in row_geoms if not _row_in_box(r)]
|
||||
excluded = before_count - len(row_geoms)
|
||||
if excluded:
|
||||
logger.info(f"detect_words: excluded {excluded} rows inside box zones")
|
||||
|
||||
# --- Words-First path: bottom-up grid from word boxes ---
|
||||
if grid_method == "words_first":
|
||||
t0 = time.time()
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# For paddle engine: run remote PaddleOCR full-page instead of Tesseract
|
||||
if engine == "paddle":
|
||||
from cv_ocr_engines import ocr_region_paddle
|
||||
|
||||
wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
|
||||
# PaddleOCR returns absolute coordinates, no content_bounds offset needed
|
||||
cached["_paddle_word_dicts"] = wf_word_dicts
|
||||
else:
|
||||
# Get word_dicts from cache or run Tesseract full-page
|
||||
wf_word_dicts = cached.get("_word_dicts")
|
||||
if wf_word_dicts is None:
|
||||
ocr_img_tmp = create_ocr_image(dewarped_bgr)
|
||||
geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
|
||||
if geo_result is not None:
|
||||
_geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
|
||||
cached["_word_dicts"] = wf_word_dicts
|
||||
cached["_inv"] = inv
|
||||
cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)
|
||||
|
||||
if not wf_word_dicts:
|
||||
raise HTTPException(status_code=400, detail="No words detected — cannot build words-first grid")
|
||||
|
||||
# Convert word coordinates to absolute image coordinates if needed
|
||||
# (detect_column_geometry returns words relative to content ROI)
|
||||
# PaddleOCR already returns absolute coordinates — skip offset.
|
||||
if engine != "paddle":
|
||||
content_bounds = cached.get("_content_bounds")
|
||||
if content_bounds:
|
||||
lx, _rx, ty, _by = content_bounds
|
||||
abs_words = []
|
||||
for w in wf_word_dicts:
|
||||
abs_words.append({
|
||||
**w,
|
||||
'left': w['left'] + lx,
|
||||
'top': w['top'] + ty,
|
||||
})
|
||||
wf_word_dicts = abs_words
|
||||
|
||||
# Extract box rects for box-aware column clustering
|
||||
box_rects = []
|
||||
for zone in zones:
|
||||
if zone.get("zone_type") == "box" and zone.get("box"):
|
||||
box_rects.append(zone["box"])
|
||||
|
||||
cells, columns_meta = build_grid_from_words(
|
||||
wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Apply IPA phonetic fixes
|
||||
fix_cell_phonetics(cells, pronunciation=pronunciation)
|
||||
|
||||
# Add zone_index for backward compat
|
||||
for cell in cells:
|
||||
cell.setdefault("zone_index", 0)
|
||||
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
|
||||
n_cols = len(columns_meta)
|
||||
used_engine = "paddle" if engine == "paddle" else "words_first"
|
||||
|
||||
word_result = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_rows,
|
||||
"cols": n_cols,
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"grid_method": "words_first",
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
if is_vocab or 'column_text' in col_types:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||
word_result["vocab_entries"] = entries
|
||||
word_result["entries"] = entries
|
||||
word_result["entry_count"] = len(entries)
|
||||
word_result["summary"]["total_entries"] = len(entries)
|
||||
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||
|
||||
await update_session_db(session_id, word_result=word_result, current_step=8)
|
||||
cached["word_result"] = word_result
|
||||
|
||||
logger.info(f"OCR Pipeline: words-first session {session_id}: "
|
||||
f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")
|
||||
|
||||
await _append_pipeline_log(session_id, "words", {
|
||||
"grid_method": "words_first",
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||
"ocr_engine": used_engine,
|
||||
"layout": word_result["layout"],
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
return {"session_id": session_id, **word_result}
|
||||
|
||||
if stream:
|
||||
# Cell-First OCR v2: use batch-then-stream approach instead of
|
||||
# per-cell streaming. The parallel ThreadPoolExecutor in
|
||||
# build_cell_grid_v2 is much faster than sequential streaming.
|
||||
return StreamingResponse(
|
||||
_word_batch_stream_generator(
|
||||
session_id, cached, col_regions, row_geoms,
|
||||
dewarped_bgr, engine, pronunciation, request,
|
||||
skip_heal_gaps=skip_heal_gaps,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
# --- Non-streaming path (grid_method=v2) ---
|
||||
t0 = time.time()
|
||||
|
||||
# Create binarized OCR image (for Tesseract)
|
||||
ocr_img = create_ocr_image(dewarped_bgr)
|
||||
img_h, img_w = dewarped_bgr.shape[:2]
|
||||
|
||||
# Build cell grid using Cell-First OCR (v2) — each cell cropped in isolation
|
||||
cells, columns_meta = build_cell_grid_v2(
|
||||
ocr_img, col_regions, row_geoms, img_w, img_h,
|
||||
ocr_engine=engine, img_bgr=dewarped_bgr,
|
||||
skip_heal_gaps=skip_heal_gaps,
|
||||
)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Add zone_index to each cell (default 0 for backward compatibility)
|
||||
for cell in cells:
|
||||
cell.setdefault("zone_index", 0)
|
||||
|
||||
# Layout detection
|
||||
col_types = {c['type'] for c in columns_meta}
|
||||
is_vocab = bool(col_types & {'column_en', 'column_de'})
|
||||
|
||||
# Count content rows and columns for grid_shape
|
||||
n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
|
||||
n_cols = len(columns_meta)
|
||||
|
||||
# Determine which engine was actually used
|
||||
used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine
|
||||
|
||||
# Apply IPA phonetic fixes directly to cell texts (for overlay mode)
|
||||
fix_cell_phonetics(cells, pronunciation=pronunciation)
|
||||
|
||||
# Grid result (always generic)
|
||||
word_result = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_content_rows,
|
||||
"cols": n_cols,
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
"image_height": img_h,
|
||||
"duration_seconds": round(duration, 2),
|
||||
"ocr_engine": used_engine,
|
||||
"summary": {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": sum(1 for c in cells if c.get("text")),
|
||||
"low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
|
||||
},
|
||||
}
|
||||
|
||||
# For vocab layout or single-column (box sub-sessions): map cells 1:1
|
||||
# to vocab entries (row→entry).
|
||||
has_text_col = 'column_text' in col_types
|
||||
if is_vocab or has_text_col:
|
||||
entries = _cells_to_vocab_entries(cells, columns_meta)
|
||||
entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
|
||||
word_result["vocab_entries"] = entries
|
||||
word_result["entries"] = entries
|
||||
word_result["entry_count"] = len(entries)
|
||||
word_result["summary"]["total_entries"] = len(entries)
|
||||
word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
|
||||
word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
|
||||
|
||||
# Persist to DB
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
current_step=8,
|
||||
)
|
||||
|
||||
cached["word_result"] = word_result
|
||||
|
||||
logger.info(f"OCR Pipeline: words session {session_id}: "
|
||||
f"layout={word_result['layout']}, "
|
||||
f"{len(cells)} cells ({duration:.2f}s), summary: {word_result['summary']}")
|
||||
|
||||
await _append_pipeline_log(session_id, "words", {
|
||||
"total_cells": len(cells),
|
||||
"non_empty_cells": word_result["summary"]["non_empty_cells"],
|
||||
"low_confidence_count": word_result["summary"]["low_confidence"],
|
||||
"ocr_engine": used_engine,
|
||||
"layout": word_result["layout"],
|
||||
"entry_count": word_result.get("entry_count", 0),
|
||||
}, duration_ms=int(duration * 1000))
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
**word_result,
|
||||
}
|
||||
|
||||
|
||||
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Unlike the old per-cell streaming, this uses ``build_cell_grid_v2`` with
    a ThreadPoolExecutor for parallel OCR, then emits all cells as SSE
    events. Periodic keepalive events prevent proxies from dropping the
    idle connection while the long-running OCR batch executes.

    Args:
        session_id: OCR pipeline session identifier.
        cached: Per-session in-memory cache; the final ``word_result`` is
            stored under ``cached["word_result"]``.
        col_regions: Column regions from the column-detection step.
        row_geoms: Row geometries from the row-detection step.
        dewarped_bgr: Preprocessed BGR page image (cropped or dewarped).
        engine: OCR engine selector ('auto', 'tesseract', ...).
        pronunciation: 'british' or 'american'; forwarded to IPA fix-ups.
        request: FastAPI request; polled to detect client disconnects.
        skip_heal_gaps: When True, cells keep the exact row geometry.

    Yields:
        SSE ``data:`` frames in order: meta, preparing, keepalive*,
        columns, cell*, complete.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in thread pool with periodic keepalive events.
    #    The OCR takes 30-60s and proxy servers (Nginx) may drop idle SSE
    #    connections after 30-60s. Send keepalive every 5s to prevent this.
    # FIX: use get_running_loop() — get_event_loop() is deprecated inside a
    # coroutine; the running loop is always available here.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Send keepalive events every 5 seconds while OCR runs
    while not ocr_future.done():
        try:
            # shield() so the 5s timeout does not cancel the OCR work itself
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                ocr_future.cancel()
                return
    else:
        # Future completed before the first wait — fetch the result directly.
        cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
|
||||
|
||||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Legacy per-cell streaming path: OCRs the grid one cell at a time via
    ``build_cell_grid_v2_streaming`` and emits one 'cell' SSE event per
    cell. After the loop it filters all-empty rows, applies IPA fixes,
    persists the final ``word_result`` and emits a 'complete' event.

    Args:
        session_id: OCR pipeline session identifier.
        cached: Per-session in-memory cache; final result stored here.
        col_regions: Column regions from the column-detection step.
        row_geoms: Row geometries from the row-detection step.
        dewarped_bgr: Preprocessed BGR page image.
        engine: OCR engine selector.
        pronunciation: 'british' or 'american'; forwarded to IPA fix-ups.
        request: FastAPI request; polled to detect client disconnects.

    Yields:
        SSE ``data:`` frames: meta, preparing, columns, cell*, complete.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Compute grid shape upfront for the meta event
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    # Determine layout
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    # columns_meta is filled from the first streamed cell's metadata
    columns_meta = None
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # Keepalive: send preparing event so proxy doesn't timeout during OCR init
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    # Stream cells one by one
    # FIX: removed dead local `last_keepalive = time.time()` — it was
    # assigned but never read anywhere in this generator.
    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0

    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        if columns_meta is None:
            columns_meta = cols_meta
            # Announce the column metadata once, before the first cell event
            meta_update = {
                "type": "columns",
                "columns_used": cols_meta,
            }
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done — build final result
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []

    # Post-OCR: remove rows where ALL cells are empty (inter-row gaps
    # that had stray Tesseract artifacts giving word_count > 0).
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    # Apply IPA phonetic fixes directly to cell texts (for overlay mode)
    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    word_result = {
        "cells": all_cells,
        "grid_shape": {
            "rows": n_content_rows,
            "cols": n_cols,
            "total_cells": len(all_cells),
        },
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # For vocab layout or single-column (box sub-sessions): map cells 1:1
    # to vocab entries (row -> entry).
    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    # Persist to DB
    await update_session_db(
        session_id,
        word_result=word_result,
        current_step=8,
    )
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    # Final complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PaddleOCR Direct Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/paddle-direct")
|
||||
@_local_router.post("/sessions/{session_id}/paddle-direct")
|
||||
async def paddle_direct(session_id: str):
|
||||
"""Run PaddleOCR on the preprocessed image and build a word grid directly.
|
||||
|
||||
Expects orientation/deskew/dewarp/crop to be done already.
|
||||
Uses the cropped image (falls back to dewarped, then original).
|
||||
The used image is stored as cropped_png so OverlayReconstruction
|
||||
can display it as the background.
|
||||
"""
|
||||
# Try preprocessed images first (crop > dewarp > original)
|
||||
"""Run PaddleOCR on the preprocessed image and build a word grid directly."""
|
||||
img_png = await get_session_image(session_id, "cropped")
|
||||
if not img_png:
|
||||
img_png = await get_session_image(session_id, "dewarped")
|
||||
@@ -770,13 +79,9 @@ async def paddle_direct(session_id: str):
|
||||
if not word_dicts:
|
||||
raise HTTPException(status_code=400, detail="PaddleOCR returned no words")
|
||||
|
||||
# Reuse build_grid_from_words — same function that works in the regular
|
||||
# pipeline with PaddleOCR (engine=paddle, grid_method=words_first).
|
||||
# Handles phrase splitting, column clustering, and reading order.
|
||||
cells, columns_meta = build_grid_from_words(word_dicts, img_w, img_h)
|
||||
duration = time.time() - t0
|
||||
|
||||
# Tag cells as paddle_direct
|
||||
for cell in cells:
|
||||
cell["ocr_engine"] = "paddle_direct"
|
||||
|
||||
@@ -787,11 +92,7 @@ async def paddle_direct(session_id: str):
|
||||
|
||||
word_result = {
|
||||
"cells": cells,
|
||||
"grid_shape": {
|
||||
"rows": n_rows,
|
||||
"cols": n_cols,
|
||||
"total_cells": len(cells),
|
||||
},
|
||||
"grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
|
||||
"columns_used": columns_meta,
|
||||
"layout": "vocab" if is_vocab else "generic",
|
||||
"image_width": img_w,
|
||||
@@ -806,7 +107,6 @@ async def paddle_direct(session_id: str):
|
||||
},
|
||||
}
|
||||
|
||||
# Store preprocessed image as cropped_png so OverlayReconstruction shows it
|
||||
await update_session_db(
|
||||
session_id,
|
||||
word_result=word_result,
|
||||
@@ -832,7 +132,7 @@ async def paddle_direct(session_id: str):
|
||||
# Ground Truth Words Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/ground-truth/words")
|
||||
@_local_router.post("/sessions/{session_id}/ground-truth/words")
|
||||
async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
|
||||
"""Save ground truth feedback for the word recognition step."""
|
||||
session = await get_session_db(session_id)
|
||||
@@ -857,7 +157,7 @@ async def save_word_ground_truth(session_id: str, req: WordGroundTruthRequest):
|
||||
return {"session_id": session_id, "ground_truth": gt}
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/ground-truth/words")
|
||||
@_local_router.get("/sessions/{session_id}/ground-truth/words")
|
||||
async def get_word_ground_truth(session_id: str):
|
||||
"""Retrieve saved ground truth for word recognition."""
|
||||
session = await get_session_db(session_id)
|
||||
@@ -874,3 +174,12 @@ async def get_word_ground_truth(session_id: str):
|
||||
"words_gt": words_gt,
|
||||
"words_auto": session.get("word_result"),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Composite router
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Composite router: merges the detect sub-router and the local-only
# sub-router into the single object the application mounts.
router = APIRouter()
|
||||
router.include_router(_detect_router)
|
||||
router.include_router(_local_router)
|
||||
|
||||
393
klausur-service/backend/ocr_pipeline_words_detect.py
Normal file
393
klausur-service/backend/ocr_pipeline_words_detect.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
OCR Pipeline Words Detect — main word detection endpoint (Step 7).
|
||||
|
||||
Extracted from ocr_pipeline_words.py. Contains the ``detect_words``
|
||||
endpoint which handles both v2 and words_first grid methods.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
build_cell_grid_v2,
|
||||
create_ocr_image,
|
||||
detect_column_geometry,
|
||||
)
|
||||
from cv_words_first import build_grid_from_words
|
||||
from ocr_pipeline_session_store import (
|
||||
get_session_db,
|
||||
update_session_db,
|
||||
)
|
||||
from ocr_pipeline_common import (
|
||||
_cache,
|
||||
_load_session_to_cache,
|
||||
_get_cached,
|
||||
_append_pipeline_log,
|
||||
)
|
||||
from ocr_pipeline_words_stream import (
|
||||
_word_batch_stream_generator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/ocr-pipeline", tags=["ocr-pipeline"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Word Detection Endpoint (Step 7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/sessions/{session_id}/words")
async def detect_words(
    session_id: str,
    request: Request,
    engine: str = "auto",
    pronunciation: str = "british",
    stream: bool = False,
    skip_heal_gaps: bool = False,
    grid_method: str = "v2",
):
    """Build word grid from columns x rows, OCR each cell.

    Query params:
        engine: 'auto' (default), 'tesseract', 'rapid', or 'paddle'
        pronunciation: 'british' (default) or 'american'
        stream: false (default) for JSON response, true for SSE streaming
        skip_heal_gaps: false (default). When true, cells keep exact row geometry.
        grid_method: 'v2' (default) or 'words_first'

    Raises:
        HTTPException 400: crop/dewarp or row detection not completed.
        HTTPException 404: unknown session.
    """
    # PaddleOCR is full-page remote OCR -> force words_first grid method
    if engine == "paddle" and grid_method != "words_first":
        logger.info("detect_words: engine=paddle requires words_first, overriding grid_method=%s", grid_method)
        grid_method = "words_first"

    if session_id not in _cache:
        logger.info("detect_words: session %s not in cache, loading from DB", session_id)
        await _load_session_to_cache(session_id)
    cached = _get_cached(session_id)

    # Prefer the cropped image; fall back to the dewarped one.
    dewarped_bgr = cached.get("cropped_bgr")
    if dewarped_bgr is None:
        dewarped_bgr = cached.get("dewarped_bgr")
    if dewarped_bgr is None:
        logger.warning("detect_words: no cropped/dewarped image for session %s (cache keys: %s)",
                       session_id, [k for k in cached.keys() if k.endswith('_bgr')])
        raise HTTPException(status_code=400, detail="Crop or dewarp must be completed before word detection")

    session = await get_session_db(session_id)
    if not session:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    column_result = session.get("column_result")
    row_result = session.get("row_result")
    if not column_result or not column_result.get("columns"):
        # No column detection ran — treat the whole page as one text column.
        img_h_tmp, img_w_tmp = dewarped_bgr.shape[:2]
        column_result = {
            "columns": [{
                "type": "column_text",
                "x": 0, "y": 0,
                "width": img_w_tmp, "height": img_h_tmp,
                "classification_confidence": 1.0,
                "classification_method": "full_page_fallback",
            }],
            "zones": [],
            "duration_seconds": 0,
        }
        logger.info("detect_words: no column_result -- using full-page pseudo-column %dx%d", img_w_tmp, img_h_tmp)
    if grid_method != "words_first" and (not row_result or not row_result.get("rows")):
        raise HTTPException(status_code=400, detail="Row detection must be completed first")

    # Convert column dicts back to PageRegion objects
    col_regions = [
        PageRegion(
            type=c["type"],
            x=c["x"], y=c["y"],
            width=c["width"], height=c["height"],
            classification_confidence=c.get("classification_confidence", 1.0),
            classification_method=c.get("classification_method", ""),
        )
        for c in column_result["columns"]
    ]

    # Convert row dicts back to RowGeometry objects.
    # BUGFIX: row_result may be None/empty when grid_method == "words_first"
    # (the guard above deliberately skips the row check in that case), so
    # subscripting row_result["rows"] unconditionally raised TypeError.
    row_geoms = [
        RowGeometry(
            index=r["index"],
            x=r["x"], y=r["y"],
            width=r["width"], height=r["height"],
            word_count=r.get("word_count", 0),
            words=[],
            row_type=r.get("row_type", "content"),
            gap_before=r.get("gap_before", 0),
        )
        for r in ((row_result or {}).get("rows") or [])
    ]

    # Populate word counts from cached words
    word_dicts = cached.get("_word_dicts")
    if word_dicts is None:
        ocr_img_tmp = create_ocr_image(dewarped_bgr)
        geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
        if geo_result is not None:
            _geoms, left_x, right_x, top_y, bottom_y, word_dicts, inv = geo_result
            cached["_word_dicts"] = word_dicts
            cached["_inv"] = inv
            cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    if word_dicts:
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            _lx, _rx, top_y, _by = content_bounds
        else:
            top_y = min(r.y for r in row_geoms) if row_geoms else 0

        # Assign each word to the row whose vertical span contains its center
        for row in row_geoms:
            row_y_rel = row.y - top_y
            row_bottom_rel = row_y_rel + row.height
            row.words = [
                w for w in word_dicts
                if row_y_rel <= w['top'] + w['height'] / 2 < row_bottom_rel
            ]
            row.word_count = len(row.words)

    # Exclude rows that fall within box zones.
    # Use inner box range (shrunk by border_thickness) so that rows at the
    # boundary (overlapping with the box border) are NOT excluded.
    zones = column_result.get("zones") or []
    box_ranges_inner = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box = zone["box"]
            bt = max(box.get("border_thickness", 0), 5)  # minimum 5px margin
            box_ranges_inner.append((box["y"] + bt, box["y"] + box["height"] - bt))

    if box_ranges_inner:
        def _row_in_box(r):
            # A row is "inside" a box when its vertical center lies in the
            # shrunk (inner) vertical range of any box zone.
            center_y = r.y + r.height / 2
            return any(by_s <= center_y < by_e for by_s, by_e in box_ranges_inner)

        before_count = len(row_geoms)
        row_geoms = [r for r in row_geoms if not _row_in_box(r)]
        excluded = before_count - len(row_geoms)
        if excluded:
            # FIX: lazy %-style args for consistency with the other
            # logger calls in this function (avoids eager formatting).
            logger.info("detect_words: excluded %d rows inside box zones", excluded)

    # --- Words-First path ---
    if grid_method == "words_first":
        return await _words_first_path(
            session_id, cached, dewarped_bgr, engine, pronunciation, zones,
        )

    if stream:
        # Batch-then-stream: parallel OCR, then SSE emission of all cells.
        return StreamingResponse(
            _word_batch_stream_generator(
                session_id, cached, col_regions, row_geoms,
                dewarped_bgr, engine, pronunciation, request,
                skip_heal_gaps=skip_heal_gaps,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # --- Non-streaming path (grid_method=v2) ---
    return await _v2_path(
        session_id, cached, col_regions, row_geoms,
        dewarped_bgr, engine, pronunciation, skip_heal_gaps,
    )
|
||||
|
||||
|
||||
async def _words_first_path(
    session_id: str,
    cached: Dict[str, Any],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    zones: list,
) -> dict:
    """Words-first grid construction path.

    Detects individual words first (PaddleOCR, or the cached geometry pass
    for the default engine) and builds the cell grid from word positions via
    ``build_grid_from_words``, instead of OCRing pre-computed grid cells.
    Persists the result on the session (step 8) and caches it in-process.

    Args:
        session_id: Pipeline session identifier used for persistence/logging.
        cached: Per-session in-process cache; read and mutated here
            (``_word_dicts``, ``_inv``, ``_content_bounds``, ``word_result``).
        dewarped_bgr: Dewarped page image (BGR).
        engine: OCR engine selector; ``"paddle"`` takes the Paddle branch,
            anything else the geometry-pass branch.
        pronunciation: Pronunciation mode forwarded to the phonetic fixers.
        zones: Layout zones; ``box`` zones are forwarded to the grid builder.

    Returns:
        dict: ``{"session_id": ..., **word_result}``.

    Raises:
        HTTPException: 400 when no words could be detected.
    """
    t0 = time.time()
    img_h, img_w = dewarped_bgr.shape[:2]

    if engine == "paddle":
        # Paddle word boxes need no bounds conversion below (see the
        # engine != "paddle" branch), so they are cached under a separate key.
        from cv_ocr_engines import ocr_region_paddle
        wf_word_dicts = await ocr_region_paddle(dewarped_bgr, region=None)
        cached["_paddle_word_dicts"] = wf_word_dicts
    else:
        # Reuse word dicts from an earlier geometry pass when available;
        # otherwise run column-geometry detection now and cache its outputs
        # so later pipeline steps can reuse them.
        wf_word_dicts = cached.get("_word_dicts")
        if wf_word_dicts is None:
            ocr_img_tmp = create_ocr_image(dewarped_bgr)
            geo_result = detect_column_geometry(ocr_img_tmp, dewarped_bgr)
            if geo_result is not None:
                _geoms, left_x, right_x, top_y, bottom_y, wf_word_dicts, inv = geo_result
                cached["_word_dicts"] = wf_word_dicts
                cached["_inv"] = inv
                cached["_content_bounds"] = (left_x, right_x, top_y, bottom_y)

    # Covers both "no words found" and "geometry detection returned None".
    if not wf_word_dicts:
        raise HTTPException(status_code=400, detail="No words detected -- cannot build words-first grid")

    # Convert word coordinates to absolute if needed
    # (geometry-pass words are relative to the detected content bounds).
    if engine != "paddle":
        content_bounds = cached.get("_content_bounds")
        if content_bounds:
            lx, _rx, ty, _by = content_bounds
            abs_words = []
            for w in wf_word_dicts:
                abs_words.append({**w, 'left': w['left'] + lx, 'top': w['top'] + ty})
            wf_word_dicts = abs_words

    # Box zones are forwarded so the grid builder can treat boxed words
    # separately from the main grid.
    box_rects = []
    for zone in zones:
        if zone.get("zone_type") == "box" and zone.get("box"):
            box_rects.append(zone["box"])

    cells, columns_meta = build_grid_from_words(
        wf_word_dicts, img_w, img_h, box_rects=box_rects or None,
    )
    duration = time.time() - t0

    # Normalize cells: phonetic bracket fixes plus a default zone index.
    fix_cell_phonetics(cells, pronunciation=pronunciation)
    for cell in cells:
        cell.setdefault("zone_index", 0)

    col_types = {c['type'] for c in columns_meta}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    n_rows = len(set(c['row_index'] for c in cells)) if cells else 0
    n_cols = len(columns_meta)
    used_engine = "paddle" if engine == "paddle" else "words_first"

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "grid_method": "words_first",
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # confidence == 0 is excluded: treated as "no confidence", not low.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocab-style layouts (and plain text columns) additionally get
    # row-wise entry extraction with phonetic bracket repair.
    if is_vocab or 'column_text' in col_types:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist first, then cache in-process for subsequent requests.
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words-first session {session_id}: "
                f"{len(cells)} cells ({duration:.2f}s), {n_rows} rows, {n_cols} cols")

    await _append_pipeline_log(session_id, "words", {
        "grid_method": "words_first",
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
    }, duration_ms=int(duration * 1000))

    return {"session_id": session_id, **word_result}
|
||||
|
||||
|
||||
async def _v2_path(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    skip_heal_gaps: bool,
) -> dict:
    """Cell-First OCR v2 non-streaming path.

    Builds the complete cell grid synchronously, applies phonetic fixes,
    extracts vocab entries for vocab/text layouts, persists the result on
    the session (step 8), and returns it merged with the session id.
    """
    started = time.time()
    img_h, img_w = dewarped_bgr.shape[:2]
    ocr_img = create_ocr_image(dewarped_bgr)

    cells, columns_meta = build_cell_grid_v2(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
        skip_heal_gaps=skip_heal_gaps,
    )
    elapsed = time.time() - started

    # Cells without an explicit zone default to zone 0.
    for c in cells:
        c.setdefault("zone_index", 0)

    column_types = {meta['type'] for meta in columns_meta}
    vocab_layout = bool(column_types & {'column_en', 'column_de'})
    content_row_count = sum(1 for r in row_geoms if r.row_type == 'content')
    column_count = len(columns_meta)
    if cells:
        used_engine = cells[0].get("ocr_engine", "tesseract")
    else:
        used_engine = engine

    fix_cell_phonetics(cells, pronunciation=pronunciation)

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": content_row_count, "cols": column_count, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if vocab_layout else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(elapsed, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # confidence == 0 means "no confidence" and is not counted as low.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    # Vocab/text layouts additionally get row-wise entry extraction.
    if vocab_layout or 'column_text' in column_types:
        entries = _fix_phonetic_brackets(
            _cells_to_vocab_entries(cells, columns_meta),
            pronunciation=pronunciation,
        )
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        summary = word_result["summary"]
        summary["total_entries"] = len(entries)
        summary["with_english"] = sum(1 for e in entries if e.get("english"))
        summary["with_german"] = sum(1 for e in entries if e.get("german"))

    # Persist first, then cache in-process for follow-up requests.
    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(cells)} cells ({elapsed:.2f}s), summary: {word_result['summary']}")

    await _append_pipeline_log(session_id, "words", {
        "total_cells": len(cells),
        "non_empty_cells": word_result["summary"]["non_empty_cells"],
        "low_confidence_count": word_result["summary"]["low_confidence"],
        "ocr_engine": used_engine,
        "layout": word_result["layout"],
        "entry_count": word_result.get("entry_count", 0),
    }, duration_ms=int(elapsed * 1000))

    return {"session_id": session_id, **word_result}
|
||||
303
klausur-service/backend/ocr_pipeline_words_stream.py
Normal file
303
klausur-service/backend/ocr_pipeline_words_stream.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
OCR Pipeline Words Stream — SSE streaming generators for word detection.
|
||||
|
||||
Extracted from ocr_pipeline_words.py.
|
||||
|
||||
Lizenz: Apache 2.0
|
||||
DATENSCHUTZ: Alle Verarbeitung erfolgt lokal.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
|
||||
from cv_vocab_pipeline import (
|
||||
PageRegion,
|
||||
RowGeometry,
|
||||
_cells_to_vocab_entries,
|
||||
_fix_character_confusion,
|
||||
_fix_phonetic_brackets,
|
||||
fix_cell_phonetics,
|
||||
build_cell_grid_v2,
|
||||
build_cell_grid_v2_streaming,
|
||||
create_ocr_image,
|
||||
)
|
||||
from ocr_pipeline_session_store import update_session_db
|
||||
from ocr_pipeline_common import _cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _word_batch_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
    skip_heal_gaps: bool = False,
):
    """SSE generator that runs batch OCR (parallel) then streams results.

    Uses build_cell_grid_v2 with ThreadPoolExecutor for parallel OCR,
    then emits all cells as SSE events.

    Event sequence: ``meta`` -> ``preparing`` -> ``keepalive``* ->
    ``columns`` -> ``cell``* -> ``complete``.  The result is persisted on
    the session (step 8) and cached before the ``complete`` event.

    Args:
        session_id: Pipeline session identifier used for persistence/logging.
        cached: Per-session in-process cache; ``word_result`` is written here.
        col_regions: Detected column regions (skip types are filtered out).
        row_geoms: Detected row geometries; only ``content`` rows count.
        dewarped_bgr: Dewarped page image (BGR).
        engine: Requested OCR engine; actual engine is read back from cells.
        pronunciation: Pronunciation mode forwarded to the phonetic fixers.
        request: Incoming request, polled for client disconnects.
        skip_heal_gaps: Forwarded to build_cell_grid_v2.

    Yields:
        SSE-formatted ``data: ...`` strings.
    """
    import asyncio

    t0 = time.time()
    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Only real content columns/rows contribute to the announced grid shape.
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    n_cols = len([c for c in col_regions if c.type not in _skip_types])
    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})
    total_cells = n_content_rows * n_cols

    # 1. Send meta event immediately
    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    # 2. Send preparing event (keepalive for proxy)
    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR laeuft parallel...'})}\n\n"

    # 3. Run batch OCR in thread pool with periodic keepalive events.
    # Fix: asyncio.get_event_loop() is deprecated inside coroutines;
    # get_running_loop() is the correct API here and never creates a loop.
    loop = asyncio.get_running_loop()
    ocr_future = loop.run_in_executor(
        None,
        lambda: build_cell_grid_v2(
            ocr_img, col_regions, row_geoms, img_w, img_h,
            ocr_engine=engine, img_bgr=dewarped_bgr,
            skip_heal_gaps=skip_heal_gaps,
        ),
    )

    # Send keepalive events every 5 seconds while OCR runs.
    # shield() keeps the wait_for timeout from cancelling the OCR future.
    while not ocr_future.done():
        try:
            cells, columns_meta = await asyncio.wait_for(
                asyncio.shield(ocr_future), timeout=5.0,
            )
            break  # OCR finished
        except asyncio.TimeoutError:
            elapsed = int(time.time() - t0)
            yield f"data: {json.dumps({'type': 'keepalive', 'elapsed': elapsed, 'message': f'OCR laeuft... ({elapsed}s)'})}\n\n"
            if await request.is_disconnected():
                logger.info(f"SSE batch: client disconnected during OCR for {session_id}")
                ocr_future.cancel()
                return
    else:
        # Loop exited without break: future completed between checks.
        cells, columns_meta = ocr_future.result()

    if await request.is_disconnected():
        logger.info(f"SSE batch: client disconnected after OCR for {session_id}")
        return

    # 4. Apply IPA phonetic fixes
    fix_cell_phonetics(cells, pronunciation=pronunciation)

    # 5. Send columns meta
    if columns_meta:
        yield f"data: {json.dumps({'type': 'columns', 'columns_used': columns_meta})}\n\n"

    # 6. Stream all cells
    for idx, cell in enumerate(cells):
        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": idx + 1, "total": len(cells)},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # 7. Build final result and persist
    duration = time.time() - t0
    used_engine = cells[0].get("ocr_engine", "tesseract") if cells else engine

    word_result = {
        "cells": cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(cells),
            "non_empty_cells": sum(1 for c in cells if c.get("text")),
            # confidence == 0 is excluded: "no confidence", not low.
            "low_confidence": sum(1 for c in cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(cells, columns_meta)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE batch: words session {session_id}: "
                f"layout={word_result['layout']}, {len(cells)} cells ({duration:.2f}s)")

    # 8. Send complete event
    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
|
||||
|
||||
async def _word_stream_generator(
    session_id: str,
    cached: Dict[str, Any],
    col_regions: List[PageRegion],
    row_geoms: List[RowGeometry],
    dewarped_bgr: np.ndarray,
    engine: str,
    pronunciation: str,
    request: Request,
):
    """SSE generator that yields cell-by-cell OCR progress.

    Unlike the batch generator, cells are produced incrementally by
    build_cell_grid_v2_streaming and emitted as soon as each one is OCRed.

    Event sequence: ``meta`` -> ``preparing`` -> ``columns`` -> ``cell``* ->
    ``complete``.  After streaming, all-empty rows are dropped, phonetic
    fixes are applied, and the result is persisted on the session (step 8).

    Args:
        session_id: Pipeline session identifier used for persistence/logging.
        cached: Per-session in-process cache; ``word_result`` is written here.
        col_regions: Detected column regions (skip types are filtered out).
        row_geoms: Detected row geometries; only ``content`` rows count.
        dewarped_bgr: Dewarped page image (BGR).
        engine: Requested OCR engine; actual engine is read back from cells.
        pronunciation: Pronunciation mode forwarded to the phonetic fixers.
        request: Incoming request, polled for client disconnects.

    Yields:
        SSE-formatted ``data: ...`` strings.
    """
    t0 = time.time()

    ocr_img = create_ocr_image(dewarped_bgr)
    img_h, img_w = dewarped_bgr.shape[:2]

    # Only real content columns/rows contribute to the announced grid shape.
    n_content_rows = len([r for r in row_geoms if r.row_type == 'content'])
    _skip_types = {'column_ignore', 'header', 'footer', 'margin_top', 'margin_bottom', 'margin_left', 'margin_right'}
    n_cols = len([c for c in col_regions if c.type not in _skip_types])

    col_types = {c.type for c in col_regions if c.type not in _skip_types}
    is_vocab = bool(col_types & {'column_en', 'column_de'})

    columns_meta = None
    total_cells = n_content_rows * n_cols

    meta_event = {
        "type": "meta",
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": total_cells},
        "layout": "vocab" if is_vocab else "generic",
    }
    yield f"data: {json.dumps(meta_event)}\n\n"

    yield f"data: {json.dumps({'type': 'preparing', 'message': 'Cell-First OCR wird initialisiert...'})}\n\n"

    all_cells: List[Dict[str, Any]] = []
    cell_idx = 0
    # Fix: removed dead local `last_keepalive = time.time()` — it was
    # assigned but never read (vestigial keepalive bookkeeping).

    for cell, cols_meta, total in build_cell_grid_v2_streaming(
        ocr_img, col_regions, row_geoms, img_w, img_h,
        ocr_engine=engine, img_bgr=dewarped_bgr,
    ):
        # Stop producing work as soon as the client goes away.
        if await request.is_disconnected():
            logger.info(f"SSE: client disconnected during streaming for {session_id}")
            return

        # Columns metadata arrives with the first cell; forward it once.
        if columns_meta is None:
            columns_meta = cols_meta
            meta_update = {"type": "columns", "columns_used": cols_meta}
            yield f"data: {json.dumps(meta_update)}\n\n"

        all_cells.append(cell)
        cell_idx += 1

        cell_event = {
            "type": "cell",
            "cell": cell,
            "progress": {"current": cell_idx, "total": total},
        }
        yield f"data: {json.dumps(cell_event)}\n\n"

    # All cells done
    duration = time.time() - t0
    if columns_meta is None:
        columns_meta = []

    # Remove all-empty rows
    rows_with_text: set = set()
    for c in all_cells:
        if c.get("text", "").strip():
            rows_with_text.add(c["row_index"])
    before_filter = len(all_cells)
    all_cells = [c for c in all_cells if c["row_index"] in rows_with_text]
    empty_rows_removed = (before_filter - len(all_cells)) // max(n_cols, 1)
    if empty_rows_removed > 0:
        logger.info(f"SSE: removed {empty_rows_removed} all-empty rows after OCR")

    used_engine = all_cells[0].get("ocr_engine", "tesseract") if all_cells else engine

    fix_cell_phonetics(all_cells, pronunciation=pronunciation)

    word_result = {
        "cells": all_cells,
        "grid_shape": {"rows": n_content_rows, "cols": n_cols, "total_cells": len(all_cells)},
        "columns_used": columns_meta,
        "layout": "vocab" if is_vocab else "generic",
        "image_width": img_w,
        "image_height": img_h,
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
        "summary": {
            "total_cells": len(all_cells),
            "non_empty_cells": sum(1 for c in all_cells if c.get("text")),
            # confidence == 0 is excluded: "no confidence", not low.
            "low_confidence": sum(1 for c in all_cells if 0 < c.get("confidence", 0) < 50),
        },
    }

    vocab_entries = None
    has_text_col = 'column_text' in col_types
    if is_vocab or has_text_col:
        entries = _cells_to_vocab_entries(all_cells, columns_meta)
        entries = _fix_character_confusion(entries)
        entries = _fix_phonetic_brackets(entries, pronunciation=pronunciation)
        word_result["vocab_entries"] = entries
        word_result["entries"] = entries
        word_result["entry_count"] = len(entries)
        word_result["summary"]["total_entries"] = len(entries)
        word_result["summary"]["with_english"] = sum(1 for e in entries if e.get("english"))
        word_result["summary"]["with_german"] = sum(1 for e in entries if e.get("german"))
        vocab_entries = entries

    await update_session_db(session_id, word_result=word_result, current_step=8)
    cached["word_result"] = word_result

    logger.info(f"OCR Pipeline SSE: words session {session_id}: "
                f"layout={word_result['layout']}, "
                f"{len(all_cells)} cells ({duration:.2f}s)")

    complete_event = {
        "type": "complete",
        "summary": word_result["summary"],
        "duration_seconds": round(duration, 2),
        "ocr_engine": used_engine,
    }
    if vocab_entries is not None:
        complete_event["vocab_entries"] = vocab_entries
    yield f"data: {json.dumps(complete_event)}\n\n"
|
||||
Reference in New Issue
Block a user